[NOI2023] 贸易

发布时间 2023-12-18 19:19:49作者: 小超手123

题意:

给定一棵深度为 \(n\) 的完美二叉树,根节点为 \(1\),对于所有非 \(1\) 的点,都有一条连到其父亲的边权为 \(a_i\) 的单向边,除此之外,还给定了 \(m\) 条单向边(\(u \rightarrow v)\),边权为 \(w\),保证 \(u\)\(v\) 的祖先,求 \(\sum_{i=1}^{2^n-1}\sum_{j=1}^{2^n-1}dis(i,j)\),其中 \(dis(i,j)\) 表示 \(i\)\(j\) 的最短路,如果无法到达,则 \(dis(i,j)=0\)

\(n \le 18\)

分析:

考虑枚举每个点 \(x\),然后快速统计到达其的最短路之和,即 \(\sum dis(y,x)\)

  • \(y\)\(x\) 的子树内,这是简单的,显然 \(y\) 只会走树边。记 \(g(u)\) 表示 \(u\) 子树内的所有点到 \(u\) 的最短路之和。不难发现 \(g(u) = g(ls) \times a_{ls} \times siz_{ls} + g(rs) \times a_{rs} \times siz_{rs}\)
  • \(y\) 不在 \(x\) 的子树内,此时 \(y\) 只能通过 \(y \rightarrow lca(x,y) \rightarrow x\) 的路径到达,因此在 \(lca\) 处统计答案。而 \(y \rightarrow lca(x,y)\) 只会走树边,可以用 \(g\) 来计算;而 \(lca(x,y) \rightarrow x\) 则是一类边与二类边混合起来走,看起来不好处理,但这棵树的深度最多只有 \(18\),这就意味着我们暴力跑 dijstra 是不会超时的,因为每个点最多只会被遍历 \(18\) 次。

时间复杂度 \(O(n^22^n)\)

代码:

#include<bits/stdc++.h>
#define int long long
#define N 1000006 
#define mod 998244353
using namespace std;
int read() {
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') { if(ch == '-') f = -1; ch = getchar(); }
	while(ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
	return x * f;
}
void write(int x) {
	if(x < 0) putchar('-'), x = -x;
	if(x > 9) write(x / 10);
	putchar('0' + x % 10);
}
int n, m, ans;
int a[N];
struct edge{
	int to, w;
};
vector<edge>G[N];
int h[N], f[20][N], g[N];
bool vis[N];
int siz[N], Begin[N], End[N], dep[N], z[N], cnt; 
void dfs1(int x) {
	Begin[x] = ++cnt; z[cnt] = x;
	siz[x] = 1; dep[x] = dep[x / 2] + 1;
	int ls = x * 2, rs = x * 2 + 1;
	if(ls <= n) {
		dfs1(ls);
		g[x] = (g[x] + (g[ls] + siz[ls] * a[ls] % mod) % mod) % mod;
		siz[x] += siz[ls];
	}
	if(rs <= n) {
		dfs1(rs);
		g[x] = (g[x] + (g[rs] + siz[rs] * a[rs] % mod) % mod) % mod;
		siz[x] += siz[rs];
	}
	End[x] = cnt;
}
struct node{
	int v, w;
	friend bool operator < (node x, node y) {
		return x.w > y.w;
	}
};
void dijstra(int s, int dis[]) {
	for(int i = Begin[s]; i <= End[s]; i++) {
		vis[z[i]] = 0;
		dis[z[i]] = 1e18;
		dis[z[i]] = min(dis[z[i]], f[dep[s / 2]][z[i]] + a[s]);
	}
	priority_queue<node>Q;
	dis[s] = 0; Q.push((node){s, 0});
	while(!Q.empty()) {
		int x = Q.top().v;
		Q.pop();
		if(vis[x]) continue;
		vis[x] = 1;
		for(auto r : G[x]) {
			int y = r.to;
			if(dis[y] > dis[x] + r.w) {
				dis[y] = dis[x] + r.w;
				Q.push((node){y, dis[y]});
			}
		}
		int y = x / 2;
		if(y < 1) continue;
		if(dis[y] > dis[x] + a[x]) {
			dis[y] = dis[x] + a[x];
			Q.push((node){y, dis[y]});
		}
	}
}
void dfs2(int x) {
	int now = x;
	while(now != 1) {
		if(f[dep[now / 2]][x] == 1e18) break;
		int P = (g[now / 2] - g[now] + mod - siz[now] * a[now] % mod + mod) % mod + f[dep[now / 2]][x] * (siz[now / 2] - siz[now]) % mod;
		h[x] = (h[x] + P % mod) % mod;
		now /= 2;
	}
	dijstra(x, f[dep[x]]);
	int ls = x * 2, rs = x * 2 + 1;
	if(ls <= n) dfs2(ls);
	if(rs <= n) dfs2(rs);
}
signed main() {
    n = read(), m = read();
    n = (1 << n) - 1;
    for(int i = 1; i <= n; i++) f[0][i] = 1e18;
    for(int i = 2; i <= n; i++) a[i] = read();
    for(int i = 1, u, v, w; i <= m; i++) {
    	u = read(), v = read(), w = read();
    	G[u].push_back((edge){v, w});
	}
	dfs1(1);
	dfs2(1);
	for(int i = 1; i <= n; i++) ans = (ans + (h[i] + g[i]) % mod) % mod;
	write(ans);
	return 0;
}