[题解] P5901 [IOI2009] Regions

发布时间 2023-11-09 21:41:17作者: definieren

P5901 [IOI2009] Regions

给你一棵树,每个点有颜色 \(h_i\)
多次询问,每次询问有多少对 \((u, v)\) 满足 \(u\)\(v\) 的祖先且 \(u\) 的颜色是 \(r_1\)\(v\) 的颜色是 \(r_2\)
\(n, q \le 2 \times 10^5, h_i \le 2.5 \times 10^4\)

总颜色数一定,考虑对颜色的出现次数根号分治。记阈值为 \(B\)

\(r_1\)\(r_2\) 颜色出现次数都 \(< B\) 时,我们直接枚举 \(r_1\) 中的每个点,数它的子树内的颜色为 \(r_2\) 的数的个数。这个记一下 dfs 序之后就是一维偏序,可以提前排好序,预处理时间复杂度 \(O(n \log n)\),单次查询时间复杂度 \(O(B)\)

当其中一个出现次数 \(\ge B\) 时,由于颜色数较少,考虑预处理出答案。我们可以枚举每个颜色然后 dfs 预处理,时间复杂度 \(O(\frac{n^2}{B})\)

\(B = \sqrt n\) 时,时间复杂度 \(O(n \sqrt n)\)

constexpr int MAXN = 2e5 + 9, MAXM = 2.5e4 + 9, B = 400;
int n, m, q, h[MAXN], id[MAXN], dfn[MAXN], siz[MAXN],
	cnt[MAXM];
ll ans1[B][MAXM], ans2[B][MAXM];
vector<int> G[MAXN], s1, s2, a[MAXM];
vector<pair<int, int> > qry[MAXM];

void dfs0(int u) {
	static int dfc = 0;
	dfn[u] = ++ dfc, siz[u] = 1;
	for (auto v : G[u])
		dfs0(v), siz[u] += siz[v];
	return;
}
void dfs1(int u, int col, int cnt) {
	ans1[id[col]][h[u]] += cnt;
	cnt += (h[u] == col);
	for (auto v : G[u]) dfs1(v, col, cnt);
	cnt -= (h[u] == col);
	return;
}
int dfs2(int u, int col) {
	int cnt = 0;
	for (auto v : G[u])
		cnt += dfs2(v, col);
	ans2[id[col]][h[u]] += cnt;
	cnt += (h[u] == col);
	return cnt;
}

void slv() {
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	cin >> n >> m >> q >> h[1];
	for (int i = 2; i <= n; i ++) {
		int s; cin >> s >> h[i];
		G[s].emplace_back(i);
	}
	for (int i = 1; i <= n; i ++) cnt[h[i]] ++;
	for (int i = 1; i <= m; i ++)
		if (cnt[i] < B) id[i] = s1.size(), s1.emplace_back(i);
		else id[i] = s2.size(), s2.emplace_back(i);
	dfs0(1);
	for (int i = 1; i <= n; i ++) if (cnt[h[i]] < B)
		a[id[h[i]]].emplace_back(i);
	for (int i = 0; i < s1.size(); i ++) {
		sort(a[i].begin(), a[i].end(), [&](int x, int y) {
			return dfn[x] < dfn[y];
		});
		for (auto j : a[i]) {
			qry[i].emplace_back(dfn[j] - 1, -1);
			qry[i].emplace_back(dfn[j] + siz[j] - 1, 1);
		}
		sort(qry[i].begin(), qry[i].end(), [&](pii x, pii y) {
			return x.fir < y.fir;
		});
		for (auto &j : a[i]) j = dfn[j];
	}
	for (auto i : s2) dfs1(1, i, 0), dfs2(1, i);
	while (q --) {
		int r1, r2; cin >> r1 >> r2;
		if (cnt[r1] > B) cout << ans1[id[r1]][r2] << endl;
		else if (cnt[r2] > B) cout << ans2[id[r2]][r1] << endl;
		else {
			r1 = id[r1], r2 = id[r2];
			int cnt = -1; ll ans = 0;
			for (auto [u, op] : qry[r1]) {
				while (cnt + 1 < a[r2].size() && a[r2][cnt + 1] <= u)
					cnt ++;
				ans += op * (cnt + 1);
			}
			cout << ans << endl;
		}
	}
	return;
}