【题解】[ABC248G] GCD cost on the tree

发布时间 2023-04-21 03:47:12作者: kymru

「八云紫」无数次痛苦地询问,为什么我们还活着?
……而「古明地恋」从不会回答。

恋恋闭上了觉之眼。

思路

容斥 + dp.

\(\gcd\) 相关,考虑 \(\mu\) 反演或者 \(\varphi\) 反演。

本质上都和容斥差不多,不如直接一步到位考虑容斥。

把权值拆成 \(\gcd\) 和对应的方案数两部分,考虑求对应的方案数。

\(f[v]\) 表示 \(\gcd = v\) 时的合法路径总数,\(g[v]\) 表示 \(v \mid \gcd\) 时的合法路径总数,显然有 \(f[v] = g[v] - \sum\limits_{i = 2, iv \leq \max(a_i)} f[iv]\).

于是我们只需要求出 \(g\).

一个想法是考虑保留所有为 \(v\) 倍数的边,然后在形成的连通块中做一个 dp 统计答案。

类似长剖一样,先令 \(res[u]\) 表示 \(u\) 子树中所有非孤点路径的总点数。

然后考虑每个点对于答案的贡献:长度乘以贡献次数。用 \(len[u]\) 表示 \(u\) 子树中所有一端为 \(u\) 的非孤点路径总点数,\(cnt[u]\) 表示所有一端为 \(u\) 的非孤点路径数量。

转移比较好推:

考虑所有 \(u\) 的子结点 \(v\).

分别算出两边的贡献累加:\(res[u] = res[u] + cnt[u] \times len[v] + len[u] \times cnt[v]\).

\(v\) 相连的路径可以在顶端加上 \(u\),此时这 \(cnt[v]\) 条路径的长度都会加一,于是有 \(len[u] = len[u] + len[v] + cnt[v]\).

同理可得 \(cnt[u] = cnt[u] + cnt[v]\).

考虑枚举 \(\gcd\) 的时候给合法的结点打标记,这样每次只需要遍历连通块 dp。每个结点只会被标记 \(\sigma(a_i)\) 次,所以时间复杂度是 \(O(V \log V + n \max(\sigma(a_i)))\).

代码

#include <cstdio>
#include <vector>
using namespace std;

#define il inline

const int maxn = 1e5 + 5;
const int mod = 998244353;

int n, m;
int a[maxn], f[maxn], g[maxn];
int len[maxn], cnt[maxn], res[maxn];
bool vis[maxn];
vector<int> gr[maxn], nd[maxn];

il int max(const int &a, const int &b) { return (a >= b ? a : b); }

il int read()
{
	int res = 0;
	char ch = getchar();
	while ((ch < '0') || (ch > '9')) ch = getchar();
	while ((ch >= '0') && (ch <= '9')) res = res * 10 + ch - '0', ch = getchar();
	return res;
}

void dfs(int u, int val)
{
	cnt[u] = len[u] = 1, res[u] = 0, vis[u] = false;
	for (int v : gr[u])
	{
		if (vis[v])
		{
			dfs(v, val);
			res[u] = (res[u] + 1ll * cnt[u] * len[v] % mod + 1ll * len[u] * cnt[v] % mod) % mod;
			len[u] = (1ll * len[u] + len[v] + cnt[v]) % mod;
			cnt[u] = (cnt[u] + cnt[v]) % mod;
		}
	}
	g[val] = (g[val] + res[u]) % mod;
}

int main()
{
	n = read();
	for (int i = 1; i <= n; i++) a[i] = read(), m = max(m, a[i]);
	for (int i = 1, u, v; i <= n - 1; i++)
	{
		u = read(), v = read();
		gr[u].push_back(v), gr[v].push_back(u);
	}
	for (int i = 1; i <= n; i++)
	{
		for (int j = 1; j * j <= a[i]; j++)
		{
			if (a[i] % j == 0)
			{
				nd[j].push_back(i);
				if (j * j != a[i]) nd[a[i] / j].push_back(i);
			}
		}
	}
	for (int i = m; i; i--)
	{
		if (nd[i].size())
		{
			for (int v : nd[i]) vis[v] = true;
			for (int v : nd[i])
				if (vis[v]) dfs(v, i);
			f[i] = g[i];
			for (int j = 2 * i; j <= m; j += i) f[i] = (f[i] - f[j] + mod) % mod;
		}
	}
	int ans = 0;
	for (int i = 1; i <= m; i++) ans = (ans + 1ll * i * f[i] % mod) % mod;
	printf("%d\n", ans);
	return 0;
}