树上启发式合并

发布时间 2023-08-25 20:21:10作者: 徐子洋

树上启发式合并

与其说树上启发式合并是一种算法,不如说是一种思想。它在于通过”小的并入大的“保证复杂度,从而解决很多看似无法做的问题。

论纯用树上启发式合并的题很少,但是很多题却可以用树上启发式合并去解决。

模板

求解的问题往往具有如下性质:

  • 每颗子树都有要记录的信息,信息的数量和子树大小有关。
  • 一个父亲的信息包含它儿子的信息。

(若觉得抽象,不妨先看例题,再回来看模板)。

这种方法和轻重链剖分一样是找出重儿子,然后把其它儿子的信息逐个合并到重儿子上。

重儿子:子树大小最大的儿子(集合大小往往和子树大小有关)。

/*
sz[u]:u的子树大小
dep[u]:u的深度
son[u]:u的重儿子
e[u]: u的所有儿子
*/
void Union(int u, int v){
    //把v的信息合并到u上
}
void dfs(int u){//进行预处理,包括重儿子、子树大小、深度等等
    sz[u] = 1;
    //额外的一些处理
    for(int v: e[u]){
        dep[v] = dep[u] + 1, dfs(v);//遍历所有儿子
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}
void solve(int u){
    for(int v: e[u]){
        if(v == son[u]) continue;
        dfs(v);
    }
    if(son[u]) dfs(son[u]);
    //Union+对询问进行处理
}
//上述的"对询问进行处理"是一种离线的方法,即用一个vector存每个点(子树)相关的询问。之所以要这样,是因为树上启发式合并的空间和一些线段树合并一样,对一个节点处理完之后是释放掉的。

当然,还存在一种复杂度一样的写法:

/*
dep[u]:u的深度
son[u]:u的重儿子
e[u]: u的所有儿子
*/
void Union(int u, int v){
    //把v的信息合并到u上
}
void solve(int u){
    //处理一些信息
    for(int v: e[u]){
        dep[v] = dep[u] + 1, dfs(v);
        Union(u, v);
    }
    //对询问进行处理
}

时间复杂度

考虑共有 \(n\) 个元素,一个元素 \(i\) 所在集合被合并到另一集合时才会产生时间花销。而我们每次把小的集合合并到大的,\(i\) 所在集合大小至少 \(\times 2\)。而 \(\times 2\) 必定不超过 \(\log n\) 次。故时间复杂度为 \(O(n \times \log n)=O(n\log n)\)

例题

CF208E Blood Cousins

这道题可以通过倍增访问到 \(p\) 级祖先,然后 \(p\) 即表亲就等于 \(p\) 级祖先的 \(p\) 级子孙个数 \(-1\)

显然求一个点 \(u\)\(p\) 级子孙个数等价于求 \(u\) 子树里深度为 \(dep_u+p\) 那一层的节点数。

这个东西用树上启发式合并求就行了。对于每颗子树维护两个东西:深度的集合以及每种深度出现次数(用 \(\text{STL}\) 的动态开点哈希 unordered_map)。

#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
const int N = 1e5 + 10;
struct Q{int d, i, s;};
int n, m, tot, rt[N], dep[N], fa[N][20], ans[N];
vector<int> root, e[N], s[N];
vector<Q> q[N];
unordered_map<int, int> mp[N];
void Union(int u, int v){
	if(s[u].size() < s[v].size())
		swap(s[u], s[v]), swap(mp[u], mp[v]);
	for(int &x: s[v]){
		if(!mp[u][x]) s[u].push_back(x);
		mp[u][x] += mp[v][x]; mp[v].erase(x);
	}
	s[v].clear();
}
void dfs_lca(int u, int root){
	dep[u] = dep[fa[u][0]] + 1, rt[u] = root;
	FL(i, 1, 16) fa[u][i] = fa[fa[u][i - 1]][i - 1];
	for(int &v: e[u]) fa[v][0] = u, dfs_lca(v, root); 
}
void solve(int u){
	for(int &v: e[u]) solve(v), Union(u, v);
	s[u].push_back(dep[u]), mp[u][dep[u]]++;
	for(Q &t: q[u]) ans[t.i] = mp[u][t.d] - 1;
} 
int main(){
	scanf("%d", &n);
	FL(i, 1, n){
		int r; scanf("%d", &r);
		if(!r) root.push_back(i);
		else e[r].push_back(i);
	}
	for(int &u: root) dfs_lca(u, u);
	scanf("%d", &m);
	FL(i, 1, m){
		int v, p, r;
		scanf("%d%d", &v, &p), r = p - 1;
		FR(j, 16, 0) if((1 << j) <= r)
			r -= (1 << j), v = fa[v][j];
		if(!fa[v][0]) continue;
		q[fa[v][0]].emplace_back((Q){dep[fa[v][0]] + p, i, v});
	}
	for(int &u: root) solve(u);
	FL(i, 1, m) printf("%d ", ans[i]);
	return 0;
}

CF570D Tree Requests

能重组成回文串仅当只存在至多一种字符的出现次数为奇数。

\(\text{Solution 1}\)

这道题其实是在上一题的启发式合并的基础上,哈希加了一维字母。

查询时遍历所有字母。

时间复杂度 \(O(n\log n + 26n)\),前者为启发式合并的复杂度,后者为查询的复杂度。

\(\text{Solution 2}\)

用一个二进制状态来表示每种字母出现是奇数次还是偶数次。

把上题哈希中存的出现次数换成这个二进制状态就行了。

正睿OI 908

本题是启发式合并加上动规计数。

动规计数:对于所有点分别算出作为 \(o_1,o_2\) 凑成 \(7\) 的方案数,然后再求出答案即可。

但是求 \(o_1,o_2\) 的过程中需要求 \(x\) 子树里与 \(x\) 距离为 \(d\) 的点的个数。

这个等价于求 \(x\) 子树里深度为 \(dep_x+d\) 那一层有多少个节点。直接树上启发式合并就行了。

#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
typedef long long ll;
const int N = 1e5 + 10, mod = 1e9 + 7;
int n, A, B, C, D, tot, id[N], id2[N], dep[N];
ll ans, sum, f[N][2];
vector<int> e[N], s[N];
unordered_map<int, int> mp[N];
void Union(int u, int v){
	if(s[u].size() < s[v].size()) swap(s[u], s[v]), swap(mp[u], mp[v]);
	for(int &x: s[v])
		s[u].push_back(x), mp[u][x] += mp[v][x], mp[v].erase(x);
	s[v].clear(); 
}
void dfs(int u){
	id[u] = ++tot;
	for(int &v: e[u]){
		dep[v] = dep[u] + 1, dfs(v);
		(f[u][0] += 1ll * mp[u][dep[u] + A] * mp[v][dep[u] + B]) %= mod;
		(f[u][1] += 1ll * mp[u][dep[u] + C] * mp[v][dep[u] + D]) %= mod;
		Union(u, v);
	}
	s[u].push_back(dep[u]), mp[u][dep[u]]++;
}
int main(){
	scanf("%d", &n), dep[1] = 1;
	FL(i, 2, n){
		int u, v; scanf("%d%d", &u, &v);
		e[u].push_back(v);
	}
	scanf("%d%d%d%d", &A, &B, &C, &D);
	dfs(1);
	FL(i, 1, n) id2[id[i]] = i;
	FL(i, 1, n) (ans += sum * f[id2[i]][1]) %= mod, (sum += f[id2[i]][0]) %= mod;
	printf("%lld\n", ans);
	return 0;
}

CF600E Lomsat gelralCF1009F Dominant Indices

两题做法基本一致。

显然,用维护深度的办法在维护颜色的同时,记录一下最大值以及编号和就行了。

CF246E Blood Cousins Return

把 CF208E 中 unordered_map 里存的东西替换为深度对应的不同子串个数。

合并时每往大的集合里加一个元素,就看看是否出现过(显然再开一个哈希就行了)。

询问离线到每个节点处理。

CF375D Tree and Queries

树上启发式合并维护颜色数的同时:

\(cnt_i\) 为颜色 \(i\) 的出现次数,\(sum_i\) 为出现次数大于等于 \(i\) 的颜色数。

和莫队类似的修改函数:

void add(int u){sum[++cnt[c[u]]]++;}
void del(int u){sum[cnt[c[u]]--]++;}

巧妙在 \(cnt\)\(0\) 开始加起,所以 \(\sum_i^{\le cnt_{c_u}} sum_i\) 均加了 \(1\),也正好与 \(sum\) 的定义相呼应。

CF741D Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths

判断字符集能否重构成回文串的方法同上

能重组成回文串仅当只存在至多一种字符的出现次数为奇数。

我们令 \(a_u\) 表示 \(1\to u\) 路径上的字符集的二进制状态。具体的,从右往左数第 \(1\) 位表示字符 \(a\) 的出现次数是否为奇数;从右往左第 \(2\) 位表示字符 \(b\) 的出现次数是否为奇数……以此类推。

我们发现,祖先 \(p\)\(u\) 路径上的二进制状态等价于 \(a_p\bigoplus a_u\)。也就是任意点对 \((u,v)\) 路径上的二进制状态等价于 \((a_u\bigoplus a_{lca})\bigoplus (a_v\bigoplus a_{lca})=a_u\bigoplus a_v\)

这时我们就有方法统计答案的最大值了。点 \(u\) 的答案等价于经过 \(u\) 的最长合法路径的长度,以及其子节点的答案的最大值。维护经过 \(u\) 的最长合法路径只需要维护所有的 \(a_i\),之后直接树上启发式合并即可。

这题的启发式合并过程中,先遍历轻儿子,最后重儿子。轻儿子的信息清空,重儿子的不清空。对于一颗子树先查询再修改。查询:由于允许至多一种字符出现次数为奇数,所以就枚举哪种字符出现次数为奇数(或者没有字符出现次数为奇数)。

#include <bits/stdc++.h>
#define FL(i, a, b) for(int i = (a); i <= (b); i++)
#define FR(i, a, b) for(int i = (a); i >= (b); i--)
using namespace std;
const int N = 5e5 + 10, INF = 1e9;
int n, a[N], sz[N], son[N], ans[N], dep[N], cnt[1 << 22];
vector<pair<int, int> > e[N];
void dfs(int u){
	sz[u] = 1;
	for(auto &p: e[u]){
		int v = p.first, w = p.second;
		a[p.first] = a[u] ^ (1 << w);
		dep[v] = dep[u] + 1, dfs(v), sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
}
void Add(int u){
	cnt[a[u]] = max(cnt[a[u]], dep[u]);
	for(auto &p: e[u]) Add(p.first);
}
void Del(int u){
	cnt[a[u]] = -INF;
	for(auto &p: e[u]) Del(p.first);
}
int calc(int u, int rt){
	int ret = max(0, dep[u] + cnt[a[u]]);
	FL(i, 0, 21) ret = max(ret, dep[u] + cnt[a[u] ^ (1 << i)]);
	if(u == rt) cnt[a[u]] = max(cnt[a[u]], dep[u]);
	for(auto &p: e[u]) if(p.first != son[rt]){
		ret = max(ret, calc(p.first, rt));
		if(u == rt) Add(p.first);
	}
	return ret;
}
void solve(int u, int h){
	for(auto &p: e[u])
		if(p.first != son[u]) solve(p.first, 0);
	if(son[u]) solve(son[u], 1);
	ans[u] = calc(u, u), ans[u] = max(0, ans[u] - dep[u] * 2);
	for(auto &p: e[u]) ans[u] = max(ans[u], ans[p.first]);
	if(!h) Del(u);
}
int main(){
	scanf("%d", &n);
	FL(i, 0, (1 << 22) - 1) cnt[i] = -INF;
	FL(i, 2, n){
		int p; char c;
		scanf("%d %c", &p, &c);
		e[p].push_back({i, c - 'a'});
	}
	dfs(1), solve(1, 0);
	FL(i, 1, n) printf("%d ", ans[i]);
	return 0;
}