[HackerRank] Unique Colors

发布时间 2023-09-27 16:43:27作者: FOX_konata

原题

对于一个树上问题,我们显然先考虑链上怎么做

多种颜色链上还是不会做怎么办?考虑只有黑白两种颜色

我们发现这个问题正着难算,我们就考虑用 $n \times m - $ 不满足条件的颜色个数,其中 \(m\) 为颜色种类

我们发现对于一个固定的点,他的答案即为 \(2n -\) 这个点所在的黑色连通块大小 \(-\) 这个点所在的白色连通块大小

i  : 1  2  3  4  5  6  7  8  9  10
ci : 0  0  1  0  0  1  1  1  0  0
sz0: 2  2  0  2  2  0  0  0  2  2
sz1: 0  0  1  0  0  3  3  3  0  0
ans: 18 18 19 18 18 17 17 17 18 18

然后我们再考虑回多个颜色段的问题,我们发现对于每种颜色,他是独立的。意思就是我们枚举当前考虑某个颜色,我们就可以把 \(c_i\) 变成是这个颜色不是这个颜色两种,这就变成了黑白颜色的问题

i  : 1  2  3  4  5  6  7  8  9  10
ci : 0  0  1  2  2  1  0  2  2  1
sz0: 0  0  4  4  4  4  0  3  3  3
sz1: 2  2  0  2  2  0  3  3  3  0
sz2: 3  3  3  0  0  2  2  0  0  1
ans: 25 25 23 24 24 24 25 24 24 26

然后我们用同样的思想,对于固定某个颜色 \(c_i\) ,他在树上的贡献就是不经过颜色为 \(c_i\) 的点时连通块的大小,而产生贡献的部分就是这个连通块中所有的点

然后问题就变成了求这些连通块。我们考虑现在在 \(u\) 点,则我们可以通过 \(dfs\) 求出对于 \(u\) 点的每个儿子,在向下遍历到 \(x\) 时,若路径 \(x \rightarrow u\) 上没有一个颜色 \(c_u\)\(c_x = c_u\) ,则把 \(x\) 加入 \(S_{son_u}\) 这个集合中。

容易发现,我们可以通过 \(dfs\) 序 + 容斥原理来让这些连通块加上一个值。具体的,我们让 \(u\) 点子树加上一个答案,然后再在 \(S_u\) 这个集合中的点的子树内减去这个答案。这可以差分实现

最终还要记得处理以 \(1\) 为根的连通块的情况

最终复杂度 \(O(n)\)

我说的可能不清楚,给下代码:

#include <bits/stdc++.h>
// #pragma GCC optimize(2)
#define pcn putchar('\n')
#define ll long long
#define LL __int128
#define MP make_pair
#define fi first
#define se second
#define gsize(x) ((int)(x).size())
#define Min(a, b) (a = min(a, b))
#define Max(a, b) (a = max(a, b))
#define For(i, j, k) for(int i = (j), END##i = (k); i <= END##i; ++ i)
#define For__(i, j, k) for(int i = (j), END##i = (k); i >= END##i; -- i)
#define Fore(i, j, k) for(int i = (j); i; i = (k))
//#define random(l, r) ((ll)(rnd() % (r - l + 1)) + l)
using namespace std;

namespace IO {
	template <typename T> T read(T &num) {
		num = 0;
		T f = 1;
		char c = ' ';
		while(c < 48 || c > 57) if((c = getchar()) == '-') f = -1;
		while(c >= 48 && c <= 57) num = (num << 1) + (num << 3) + (c ^ 48), c = getchar();
		return num *= f;
	}
	ll read() {
		ll num = 0, f = 1;
		char c = ' ';
		while(c < 48 || c > 57) if((c = getchar()) == '-') f = -1;
		while(c >= 48 && c <= 57) num = (num << 1) + (num << 3) + (c ^ 48), c = getchar();
		return num * f;
	}
	template <typename T> void Write(T x) {
		if(x < 0) putchar('-'), x = -x;
		if(x == 0) {
			putchar('0');
			return ;
		}
		if(x > 9) Write(x / 10);
		putchar('0' + x % 10);
		return ;
	}
	void putc(string s) {
		int len = s.size() - 1;
		For(i, 0, len) putchar(s[ i ]);
	}
	template <typename T> void write(T x, string s = "\0") {
		Write( x ), putc( s );
	}
}
using namespace IO;

/* ====================================== */

const int maxn = 1e5 + 50;

int n;
int a[ maxn ];
struct Graph {
	struct Edge {
		int v, nxt;
	} edge[ maxn << 1 ];
	int head[ maxn ], cnt = 1;

	void addedge(int u, int v) {
		edge[ ++ cnt ] = Edge {v, head[ u ]};
		head[ u ] = cnt;
	}
	void Addedge(int u, int v) {
		addedge(u, v);
		addedge(v, u);
	}
} G;
int dfn[ maxn ], times, siz[ maxn ];
int buc[ maxn ];
vector<int> dwn[ maxn ], col[ maxn ];
ll ans[ maxn ];

void dfs(int u, int father) {
	int t = buc[ a[ u ] ]; // 存储备份信息 
	if(t) dwn[ t ].push_back(u);
	if(!t || u == 1) col[ a[ u ] ].push_back(u); // 特判以1为根 
	dfn[ u ] = ++ times, siz[ u ] = 1;
	Fore(i, G.head[ u ], G.edge[ i ].nxt) {
		int v = G.edge[ i ].v;
		if(v == father) continue;
		buc[ a[ u ] ] = v;// buc[ i ]的定义 
		dfs(v, u);
		siz[ u ] += siz[ v ];
	}
	buc[ a[ u ] ] = t; // 还原备份 
}

void upd(int l, int r, int k) { // 差分修改区间 
	if(l > r) return ;
	ans[ l ] += k;
	ans[ r + 1 ] -= k;
}

void mian() {
	read(n);
	For(i, 1, n) read(a[ i ]);
	int u, v;
	For(i, 2, n) {
		read(u), read(v);
		G.Addedge(u, v);
	}
	buc[ a[ 1 ] ] = 1; // buc[ i ]:存储颜色i的点的儿子 
	dfs(1, 0);
	For(i, 1, n) {
		int res = siz[ i ];
		for(auto j : dwn[ i ]) res -= siz[ j ];
		upd(dfn[ i ], dfn[ i ] + siz[ i ] - 1, res);
		for(auto j : dwn[ i ]) upd(dfn[ j ], dfn[ j ] + siz[ j ] - 1, -res);
		res = siz[ 1 ];
		for(auto j : col[ i ]) res -= siz[ j ];
		upd(1, n, res);
		for(auto j : col[ i ]) upd(dfn[ j ], dfn[ j ] + siz[ j ] - 1, -res);
	}
	For(i, 1, n) ans[ i ] += ans[ i - 1 ];
	For(i, 1, n) write(1ll * n * n - ans[ dfn[ i ] ], "\n");
}

void init() {

}

void treatment() {

}

int main() {

#ifdef ONLINE_JUDGE
#else
	freopen("data.in", "r", stdin);
//	freopen("data.out", "w", stdout);
#endif

	treatment();
	int T = 1;
//	read(T);
	while(T --) {
		init();
		mian();
	}

	return 0;
}