题解 P9233【[蓝桥杯 2023 省 A] 颜色平衡树】

发布时间 2023-08-02 20:50:37作者: rui_er

看到树上数颜色,想到树上启发式合并(dsu on tree)。

这题几乎就是树上启发式合并板子了,感觉讲一下算法的原理比较好。

暴力解法显然是对每棵子树 dfs 一遍,求出子树大小 \(\operatorname{size}\)、子树颜色出现次数的桶 \(\operatorname{cnt}\),以及颜色出现次数的出现次数的桶 \(\operatorname{ccnt}\),判断 \(\operatorname{cnt}(C_u)\times\operatorname{ccnt}(\operatorname{cnt}(C_u))\stackrel{?}{=}\operatorname{size}(u)\) 即可判断这棵子树是不是颜色平衡树。时间复杂度 \(O(n^2)\)

注意到很多棵子树之间是包含关系。例如一条链的时候,明明可以只 dfs 一遍就能统计完答案。能不能利用这一点优化复杂度呢?

想到启发式合并。我们进行重链剖分,求出每个节点的重儿子 \(\operatorname{son}\),于是希望每个节点 \(u\) 能够从 \(\operatorname{son}(u)\) 处继承 \(\operatorname{cnt}\)\(\operatorname{ccnt}\) 的信息。如果你不会重链剖分也无所谓,重儿子的定义是子树大小最大的儿子,轻儿子的定义是除了重儿子以外的所有儿子,重边的定义是该节点与重儿子之间的边,轻边的定义是该节点与轻儿子之间的边。

定义 \(\operatorname{add}(u,\Delta)\) 表示将 \(u\) 子树的节点以 \(\Delta\) 的贡献加入到 \(\operatorname{cnt}\)\(\operatorname{ccnt}\) 中,其中 \(\Delta=\pm 1\)。于是有算法流程 \(\operatorname{calc}(u,save)\),其中 \(u\) 是当前递归到的节点,\(save\) 是一个是否保存当前贡献的开关,一会会用到:

  1. 对于所有轻儿子 \(v\),递归 \(\operatorname{calc}(v,\textrm{false})\),也就是递归求出轻儿子子树的答案,并擦除这棵子树的贡献。
  2. 如果有重儿子,递归 \(\operatorname{calc}(\operatorname{son}(u),\textrm{true})\),也就是递归求出重儿子子树的答案,并保留这棵子树的贡献。
  3. 将当前节点 \(u\) 贡献到 \(\operatorname{cnt}\)\(\operatorname{ccnt}\) 中。
  4. 对于所有轻儿子 \(v\),调用 \(\operatorname{add}(v,+1)\),将轻儿子子树贡献计入。此时 \(\operatorname{cnt}\)\(\operatorname{ccnt}\) 中的信息是 \(u\) 子树的。
  5. 统计 \(u\) 子树的答案。
  6. 如果 \(save=\textrm{false}\),调用 \(\operatorname{add}(v,-1)\) 擦除贡献。

算法正确性是显然的。由重儿子的定义,容易证明:根节点到任意节点路径上的轻边不超过 \(O(\log n)\) 条。

一个点会被暴力统计贡献,只有在 \(\operatorname{calc}\) 搜到这个点,或者搜到这个点的某个作为轻儿子的祖先时才会发生。于是每个点被暴力到的次数为 \(O(\log n)\),总复杂度为 \(O(n\log n)\)

// Problem: P9233 [蓝桥杯 2023 省 A] 颜色平衡树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P9233
// Memory Limit: 256 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(int x=(y);x<=(z);x++)
#define per(x,y,z) for(int x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
#define likely(exp) __builtin_expect(!!(exp), 1)
#define unlikely(exp) __builtin_expect(!!(exp), 0)
using namespace std;
typedef long long ll;

mt19937 rnd(std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
int randint(int L, int R) {
	uniform_int_distribution<int> dist(L, R);
	return dist(rnd);
}

template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}

const int N = 2e5+5;

int n, c[N], f[N], sz[N], son[N], cnt[N], ccnt[N], ans;
vector<int> e[N];

void dfs(int u) {
	sz[u] = 1;
	for(int v : e[u]) {
		dfs(v);
		sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
}

void add(int u, int dt) {
	--ccnt[cnt[c[u]]];
	cnt[c[u]] += dt;
	++ccnt[cnt[c[u]]];
	for(int v : e[u]) add(v, dt);
}

void calc(int u, bool save) {
	for(int v : e[u]) if(v != son[u]) calc(v, false);
	if(son[u]) calc(son[u], true);
	--ccnt[cnt[c[u]]];
	++cnt[c[u]];
	++ccnt[cnt[c[u]]];
	for(int v : e[u]) if(v != son[u]) add(v, 1);
	if(cnt[c[u]] * ccnt[cnt[c[u]]] == sz[u]) ++ans;
	if(!save) add(u, -1);
}

int main() {
	scanf("%d", &n);
	rep(i, 1, n) {
		scanf("%d%d", &c[i], &f[i]);
		if(f[i]) e[f[i]].push_back(i);
	}
	dfs(1);
	calc(1, true);
	printf("%d\n", ans);
	return 0;
}