题解 ABC294G【Distance Queries on a Tree】

发布时间:2023-03-22 21:11:27 作者:rui_er

DFS 序树状数组。

不妨以 \(1\) 为根,设 \(\operatorname{dep}(u)\) 表示 \(u\) 到根路径的边权和,\(\operatorname{dis}(u,v)\) 表示 \(u,v\) 间路径的边权和,\(\operatorname{LCA}(u,v)\) 表示 \(u,v\) 的最近公共祖先。则显然有:\(\operatorname{dis}(u,v)=\operatorname{dep}(u)+\operatorname{dep}(v)-2\times\operatorname{dep}(\operatorname{LCA}(u,v))\)。因此只需要维护每个点的 \(\operatorname{dep}\) 即可。

给一条边重新赋权值等价于给它的边权加上 \(\Delta w\)(即新权值减去旧权值),这一操作会使这条边较深的端点(即子节点)的子树内每个点的 \(\operatorname{dep}\) 增加 \(\Delta w\)。使用 DFS 序即可将子树转化为区间;区间加、单点查只需对差分数组用树状数组维护即可。

时间复杂度 \(\mathcal O((n+q)\log n)\)。

// Problem: G - Distance Queries on a Tree
// Contest: AtCoder - AtCoder Beginner Contest 294
// URL: https://atcoder.jp/contests/abc294/tasks/abc294_g
// Memory Limit: 1024 MB
// Time Limit: 4000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

//By: OIer rui_er
#include <bits/stdc++.h>
// Inclusive ascending loop: declares an ll counter x that runs from y up to z.
// The bound z is re-evaluated on every iteration.
#define rep(x,y,z) for(ll x=(y);x<=(z);x++)
// Inclusive descending loop: declares an ll counter x that runs from y down to z.
#define per(x,y,z) for(ll x=(y);x>=(z);x--)
// printf-style diagnostics written to stderr.
// The named variadic parameter `format...` is a GNU preprocessor extension.
#define debug(format...) fprintf(stderr, format)
// Redirects stdin to "<s>.in" and stdout to "<s>.out".
// s must be a string literal because it is pasted next to ".in" / ".out".
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
// Branch-prediction hints built on the compiler builtin __builtin_expect;
// !!(exp) normalises the condition to 0 or 1 first.
#define likely(exp) __builtin_expect(!!(exp), 1)
#define unlikely(exp) __builtin_expect(!!(exp), 0)
using namespace std;
// Shorthand for long long, used for every integer in this file.
typedef long long ll;

// File-wide Mersenne Twister engine, seeded once at start-up from the
// system clock's time since epoch expressed in nanoseconds.
mt19937 rnd(
	std::chrono::duration_cast<std::chrono::nanoseconds>(
		std::chrono::system_clock::now().time_since_epoch()
	).count()
);

// Draws one integer uniformly from the closed interval [L, R] using rnd.
ll randint(ll L, ll R) {
	return uniform_int_distribution<ll>(L, R)(rnd);
}

// Lowers x to y when x is strictly greater than y; otherwise x is left untouched.
template<typename T>
void chkmin(T& x, T y) {
	if(!(x > y)) return;
	x = y;
}
// Raises x to y when x is strictly less than y; otherwise x is left untouched.
template<typename T>
void chkmax(T& x, T y) {
	if(!(x < y)) return;
	x = y;
}

// Capacity of every fixed-size array below (2e5 plus a little slack).
const ll N = 2e5+5;

// n         - vertex count read in main; also the upper index bound used by BIT::add
// m         - number of queries still to process (counted down in main)
// fa[u][i]  - 2^i-th ancestor of u as filled in by dfs; index 0 stays all-zero and
//             acts as the "above the root" sentinel
// dis[u]    - depth of u counted in vertices (dis[u] = dis[parent] + 1, dis[0] = 0);
//             used only by LCA, it is NOT the weighted distance
// sz[u]     - number of vertices in the subtree of u
// bottom[i] - for edge i, the endpoint that dfs reached as the child
// dfn[u]    - 1-based visiting order of u in dfs
// tms       - counter that hands out dfn values
// val[i]    - current weight of edge i
ll n, m, fa[N][19], dis[N], sz[N], bottom[N], dfn[N], tms, val[N];
// e[u] - adjacency list of u; each entry is (neighbour, edge weight, edge index).
vector<tuple<ll, ll, ll>> e[N];

// Fenwick (binary indexed) tree over positions 1..n, where n is the file-level
// vertex count. The program stores a difference array in it, so a prefix sum
// ask(x) yields the current value at position x.
struct BIT {
	ll c[N];

	// Lowest set bit of x.
	ll lowbit(ll x) {return x & (-x);}

	// Adds k at position x. Positions greater than n are skipped entirely,
	// which lets callers pass "one past the end" of a range without a check.
	void add(ll x, ll k) {
		while(x <= n) {
			c[x] += k;
			x += lowbit(x);
		}
	}

	// Returns the sum of positions 1..x.
	ll ask(ll x) {
		ll k = 0;
		while(x) {
			k += c[x];
			x -= lowbit(x);
		}
		return k;
	}
}bit;

// Traverses the tree from u (whose parent is f; main calls dfs(1, 0)) and fills in:
//   fa     - binary-lifting ancestor table
//   dis    - depth counted in vertices
//   dfn    - pre-order visiting index (1-based, taken from tms)
//   sz     - subtree sizes
//   bottom - for every tree edge, the endpoint reached as the child
// It also seeds the Fenwick tree: for each edge (parent, v) with weight w it adds w
// on the index range [dfn[v], dfn[v]+sz[v]-1], i.e. on the whole subtree of v.
//
// The traversal uses an explicit heap-allocated stack instead of recursion.
// Recursion depth would equal the tree height, which reaches n (up to 2e5) on a
// path-shaped tree and can overflow a default-sized call stack. Adjacency lists are
// scanned in the same order as a recursive DFS would, so dfn values are unchanged.
void dfs(ll u, ll f) {
	// One frame per vertex on the current start-to-vertex path.
	struct Frame {
		ll node;           // vertex being expanded
		std::size_t next;  // index of the next entry of e[node] to look at
		ll w, id;          // weight / index of the edge that led into node (unused for the start vertex)
	};
	std::vector<Frame> stk;

	// Pre-order work for vertex x entered from parent p through edge (w, id).
	auto enter = [&](ll x, ll p, ll w, ll id) {
		fa[x][0] = p;
		rep(i, 1, 18) fa[x][i] = fa[fa[x][i-1]][i-1];
		dis[x] = dis[p] + 1;
		dfn[x] = ++tms;
		sz[x] = 1;
		stk.push_back(Frame{x, 0, w, id});
	};

	enter(u, f, 0, 0);
	while(!stk.empty()) {
		const ll x = stk.back().node;
		if(stk.back().next < e[x].size()) {
			// Still has adjacency entries: advance the cursor, then descend unless
			// the entry points back to the parent. The frame is not touched after
			// enter(), because push_back may reallocate the stack.
			const auto& edge = e[x][stk.back().next++];
			const ll v = std::get<0>(edge);
			if(v != fa[x][0]) enter(v, x, std::get<1>(edge), std::get<2>(edge));
			continue;
		}
		// All neighbours handled: do the post-order work for x.
		const Frame done = stk.back();
		stk.pop_back();
		if(stk.empty()) break; // the start vertex has no incoming edge to record
		const ll p = stk.back().node;
		sz[p] += sz[done.node];
		bottom[done.id] = done.node;
		// Range add of w over the subtree of done.node, expressed on the difference array.
		bit.add(dfn[done.node], done.w);
		bit.add(dfn[done.node]+sz[done.node], -done.w);
	}
}

// Lowest common ancestor of u and v by binary lifting over fa, using the
// vertex-count depths in dis. Requires dfs to have filled both tables.
ll LCA(ll u, ll v) {
	// Make u the deeper of the two.
	if(dis[u] < dis[v]) std::swap(u, v);

	// Lift u until it is level with v. The sentinel vertex 0 has depth 0, which is
	// below every real depth, so a jump past the root is never taken.
	for(ll k = 18; k >= 0; --k) {
		const ll up = fa[u][k];
		if(dis[up] >= dis[v]) u = up;
	}
	if(u == v) return u;

	// Lift both in lock-step while their ancestors still differ; afterwards
	// u and v are distinct children of the answer.
	for(ll k = 18; k >= 0; --k) {
		if(fa[u][k] == fa[v][k]) continue;
		u = fa[u][k];
		v = fa[v][k];
	}
	return fa[u][0];
}

int main() {
	scanf("%lld", &n);
	rep(i, 1, n-1) {
		ll u, v, w;
		scanf("%lld%lld%lld", &u, &v, &w);
		e[u].emplace_back(v, w, i);
		e[v].emplace_back(u, w, i);
		val[i] = w;
	}
	dfs(1, 0);
	for(scanf("%lld", &m); m; m--) {
		ll op, u, v;
		scanf("%lld%lld%lld", &op, &u, &v);
		if(op == 1) {
			ll bot = bottom[u];
			bit.add(dfn[bot], v-val[u]);
			bit.add(dfn[bot]+sz[bot], val[u]-v);
			val[u] = v;
		}
		else {
			ll lca = LCA(u, v);
			printf("%lld\n", bit.ask(dfn[u])+bit.ask(dfn[v])-2*bit.ask(dfn[lca]));
		}
	}
	return 0;
}