树上启发式合并(dsu on tree)

发布时间 2023-12-08 14:46:34作者: 小超手123

dsu on Tree(树上启发式合并)

当遇到处理子树询问,并且无修改时。可以考虑树上启发式合并。

算法流程:

  • step1:处理出每个点的 \(siz_x\) 以及重儿子 \(son_x\)
void dfs(int x, int fa) {
	siz[x] = 1;
	int Maxson = 0;
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa) continue;
		dfs(y, x);
		siz[x] += siz[y]; 
		if(siz[y] > Maxson) {
			Maxson = siz[y];
			son[x] = y;
		}
	}
}
  • step2:进入它的所有轻儿子,并删除贡献。
  • step3:进入它的重儿子,这里不删除贡献。
  • step4:将轻儿子加入贡献。
  • step5:如果需要删除贡献,就删除贡献。
void add(int x, int fa, int val) { //加入/删除贡献
	cnt[a[x]] += val;
	if(cnt[a[x]] > Mx) {
		Mx = cnt[a[x]];
		sum = a[x]; 
	}
	else if(cnt[a[x]] == Mx) {
		sum += a[x];
	}
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa || y == Son) continue;
		add(y, x, val); 
	}
}
void Get(int x, int fa, bool opt) {
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa || y == son[x]) continue;
		Get(y, x, 1);
	}
	if(son[x]) Get(son[x], x, 0), Son = son[x];
	add(x, fa, 1); Son = 0;
	ans[x] = sum;
	if(opt) add(x, fa, -1), sum = 0, Mx = 0;
}

时间复杂度为 \(n \log n\)

这就是CF600E的代码。

#include<bits/stdc++.h>
#define int long long
#define N 100005
using namespace std;
int n;
int a[N], son[N], siz[N], cnt[N];
vector<int>p[N];
void dfs(int x, int fa) {
	siz[x] = 1;
	int Maxson = 0;
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa) continue;
		dfs(y, x);
		siz[x] += siz[y]; 
		if(siz[y] > Maxson) {
			Maxson = siz[y];
			son[x] = y;
		}
	}
}
int Son = 0, sum = 0, Mx = 0;
int ans[N];
void add(int x, int fa, int val) {
	cnt[a[x]] += val;
	if(cnt[a[x]] > Mx) {
		Mx = cnt[a[x]];
		sum = a[x]; 
	}
	else if(cnt[a[x]] == Mx) {
		sum += a[x];
	}
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa || y == Son) continue;
		add(y, x, val); 
	}
}
void Get(int x, int fa, bool opt) {
	for(int i = 0; i < p[x].size(); i++) {
		int y = p[x][i];
		if(y == fa || y == son[x]) continue;
		Get(y, x, 1);
	}
	if(son[x]) Get(son[x], x, 0), Son = son[x];
	add(x, fa, 1); Son = 0;
	ans[x] = sum;
	if(opt) add(x, fa, -1), sum = 0, Mx = 0;
}
signed main() {
    scanf("%lld", &n);
    for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
    for(int i = 1, u, v; i <= n - 1; i++) {
    	scanf("%lld %lld", &u, &v);
    	p[u].push_back(v);
		p[v].push_back(u); 
	} 
	dfs(1, 0);
	Get(1, 0, 0);
	for(int i = 1; i <= n; i++) {
		printf("%lld ", ans[i]);
	}
	return 0;
}

后面是一些我做过的简单题:

CF375D

CF208E

CF1009F

CF246E