NOI2023 D2T1 贸易

发布时间 2023-09-07 13:31:38作者: Chy12321

图中不存在横插边,\(u \rightsquigarrow v\) 可拆成 \(u \rightsquigarrow \operatorname{lca}(u, v) \rightsquigarrow v\) 计算。

\(u \rightsquigarrow \operatorname{lca}(u, v)\),不可能走第二类道路,树形 DP 统计每条边被经过的次数并累加答案即可,时间复杂度 \(\mathcal O(2^n)\)

具体地,令 \(sum(u)\) 表示点 \(u\) 对答案的贡献,则有 \(sum(fa_u) \gets sum(u) + sz_u \times a_u\)

瓶颈在 \(\operatorname{lca}(u, v) \rightsquigarrow v\) 上。

\(f(u, k)\) 表示从点 \(u\) 深度为 \(k\) 的祖先出发到 \(u\) 的最短路长度。

对于每一条第二类边 \(x \to y\),它能更新满足 \(v \rightsquigarrow x \to y \rightsquigarrow u(dep_y > dep_u > dep_v > dep_x)\)\(f(u, dep_v)\),即 \(f(u, dep_v) = d_{u, x} + w_{x, y} + d_{y, v}\)

时间复杂度 \(\mathcal O(n^2m)\)

同时有 ,类似于 Floyd,时间复杂度 \(\mathcal O(2^nn^2)\)

统计 \(fa_v \rightsquigarrow u\) 的答案时,有 \(ans \gets f(u, dep_v - 1) \times (sz_v + 1) + sum(v \oplus 1) + sz_{v \oplus 1} \times a_{v \oplus 1}\),其中 \(\oplus\)按位异或,时间复杂度 \(\mathcal O(2^n n)\)

总时间复杂度为 \(\mathcal O(2^nn^2)\)

Bonus

本题中可以利用 \(dep_u = \log_2 u\)\(sz_u = 2^{n - dep_u} - 1\) 来递推处理,去掉 dfs 带来的常数,同时代码也更精简。

代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 18, MOD = 998244353;

int n, m, a[1 << MAXN], dep[1 << MAXN], sz[1 << MAXN];
ll dis[1 << MAXN], sum[1 << MAXN], f[1 << MAXN][MAXN];

int main() {
	ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
	cin >> n >> m; int tot = (1 << n) - 1; sz[1] = (1 << n) - 1;
	for (int i = 2; i <= tot; i++) {
		cin >> a[i], dis[i] = dis[i >> 1] + a[i];
		dep[i] = dep[i >> 1] + 1, sz[i] = (1 << (n - dep[i])) - 1;
	}
	for (int i = tot; i >= 2; i--) sum[i >> 1] += sum[i] + 1ll * sz[i] * a[i];
	memset(f, 0x3f, sizeof(f));
	while (m--) {
		int x, y, w;
		cin >> x >> y >> w;
		for (int v = y; v > x; v >>= 1) {
			for (int u = (v >> 1); u >= x; u >>= 1) {
				f[v][dep[u]] = min(f[v][dep[u]], dis[u] - dis[x] + w + dis[y] - dis[v]);
			}
		}
	}
	for (int u = 1; u <= tot; u++) {
		for (int v = (u >> 1); v; v >>= 1) {
			for (int k = dep[v] - 1; k >= 0; k--) {
				f[u][k] = min(f[u][k], f[u][dep[v]] + f[v][k]);
			}
		}
	}
	ll ans = 0;
	for (int u = tot; u; u--) {
		(ans += sum[u]) %= MOD;
		for (int v = u; v > 1; v >>= 1) {
			if (f[u][dep[v] - 1] < f[0][0]) (ans += f[u][dep[v] - 1] % MOD * (sz[v] + 1) + sum[v ^ 1] + 1ll * sz[v ^ 1] * a[v ^ 1]) %= MOD;
		}
	}
	cout << ans;
	return 0;
}