UOJ33 树上 GCD

发布时间 2023-09-05 14:24:20作者: zltzlt

UOJ 传送门

\(f_{u, i}\)\(u\) 子树内深度为 \(i\) 的点的个数,在 \(\operatorname{LCA}\) 处计算答案。但是时间复杂度无法接受。

考虑长剖,计算答案只用枚举到轻链长,先对轻儿子做一遍 \(\text{Dirichlet}\) 后缀和,重儿子的信息直接继承上来。但是我们没法查询深度 \(\bmod k = i\) 的点的个数。

这是经典根号分治,\(\le \sqrt{n}\) 的加入时处理好,\(> \sqrt{n}\) 的暴力枚举。

然后总时间复杂度是 \(O(n \sqrt{n})\) 的。

别忘记计算点对是祖先后代关系的点。

code
// Problem: #33. 【UR #2】树上GCD
// Contest: UOJ
// URL: https://uoj.ac/problem/33
// Memory Limit: 256 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const int B = 300;

int n, pr[maxn], tot;
ll g[maxn], h[maxn];
vector<int> G[maxn];
bool vis[maxn];

int dep[maxn], mxd[maxn], son[maxn];

void dfs(int u) {
	mxd[u] = dep[u];
	for (int v : G[u]) {
		dep[v] = dep[u] + 1;
		dfs(v);
		if (mxd[v] > mxd[u]) {
			mxd[u] = mxd[v];
			son[u] = v;
		}
	}
}

vector<ll> f[maxn];

namespace DS {
	ll f[B + 5][B + 5], g[maxn];
	
	inline void add(int x, ll y) {
		g[x] += y;
		for (int i = 1; i <= B; ++i) {
			f[i][x % i] += y;
		}
	}
	
	inline void clear(int x) {
		g[x] = 0;
		for (int i = 1; i <= B; ++i) {
			f[i][x % i] = 0;
		}
	}
	
	inline ll query(int k, int x) {
		return f[k][x % k];
	}
}

void dfs2(int u) {
	for (int v : G[u]) {
		if (v == son[u]) {
			continue;
		}
		dfs2(v);
		f[v] = vector<ll>(1, 0);
		for (int i = dep[v]; i <= mxd[v]; ++i) {
			f[v].pb(DS::g[i]);
			DS::clear(i);
		}
	}
	if (son[u]) {
		dfs2(son[u]);
	}
	for (int v : G[u]) {
		if (v == son[u]) {
			continue;
		}
		int m = (int)f[v].size() - 1;
		for (int i = 1; i <= tot && pr[i] <= m; ++i) {
			for (int j = m / pr[i]; j; --j) {
				f[v][j] += f[v][pr[i] * j];
			}
		}
		for (int i = 1; i <= m; ++i) {
			if (i <= B) {
				g[i] += f[v][i] * DS::query(i, dep[u]);
			} else {
				ll s = 0;
				for (int j = dep[u]; j <= mxd[u]; j += i) {
					s += DS::g[j];
				}
				g[i] += f[v][i] * s;
			}
		}
		for (int i = 1; i <= tot && pr[i] <= m; ++i) {
			for (int j = 1; j <= m / pr[i]; ++j) {
				f[v][j] -= f[v][pr[i] * j];
			}
		}
		for (int i = 1; i <= m; ++i) {
			DS::add(dep[u] + i, f[v][i]);
		}
	}
	DS::add(dep[u], 1);
}

void solve() {
	scanf("%d", &n);
	for (int i = 2; i <= n; ++i) {
		if (!vis[i]) {
			pr[++tot] = i;
		}
		for (int j = 1; j <= tot && i * pr[j] <= n; ++j) {
			vis[i * pr[j]] = 1;
			if (i % pr[j] == 0) {
				break;
			}
		}
	}
	for (int i = 2, p; i <= n; ++i) {
		scanf("%d", &p);
		G[p].pb(i);
	}
	dfs(1);
	dfs2(1);
	for (int i = 1; i <= tot && pr[i] <= mxd[1]; ++i) {
		for (int j = 1; j * pr[i] <= mxd[1]; ++j) {
			g[j] -= g[j * pr[i]];
		}
	}
	for (int i = 1; i < n; ++i) {
		h[1] += DS::g[i];
		h[i + 1] -= DS::g[i];
	}
	for (int i = 1; i < n; ++i) {
		h[i] += h[i - 1];
		printf("%lld\n", g[i] + h[i]);
	}
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}