Atcoder ABC221F Diameter set

发布时间 2023-06-09 08:10:19作者: lhzawa

不难。

第一步肯定是求出直径 \(d\)

然后能发现 \(d\bmod 2 = 0\) 时很好求。
可以先任意找到一条直径,再找到这个直径的中点,则容易知道以这个中点为根,其中的每个子树的节点与中点经过的边数最大值为 \(\frac{d}{2}\)
所以能够得到每个子树内选两个点距离最大值也为 \(d - 2\),所以合法的方案只能为每个子树内选一个 \(\frac{d}{2}\) 距离的点或者不选,这样两段拼成 \(d\)
设一共有 \(k\) 个子树 ,于是可以对于 \(1\le i\le k\) 的每个子树于求出子树里面与中点距离为 \(\frac{d}{2}\) 的点的个数 \(cnt_i\)
则很容易求出总方案数,即每个子树都可以选一个或不选:\(\prod\limits_{i = 1}^k (cnt_i + 1)\);不合法的方案数也很好求,即只选了一个点或不选:\(\prod\limits_{i = 1}^k cnt_i + 1\);所以合法方案数也很好求啦:\(\prod\limits_{i = 1}^k (cnt_i + 1) - \sum\limits_{i - 1}^k cnt_i - 1\)

考虑 \(d\bmod 2 = 1\) 怎么求,因为这时候直径的中点在边上,刚刚找点就不行了。
直径中点在边上,那直接对每个边开一个虚点,既没改变树的形态距离也满足了 \(d\bmod 2 = 0\),且虚点也不会被算入答案。

// lhzawa(https://www.cnblogs.com/lhzawa/)
#include<bits/stdc++.h>
using namespace std;
const int N = 4e5 + 10;
const long long mod = 998244353;
int n;
vector<int> ev[N];
int dep[N];
void dfsdep(int u, int fa) {
	for (int v : ev[u]) {
		// printf("%d -> %d\n", u, v);
		if (v == fa) {
			continue;
		}
		dep[v] = dep[u] + 1;
		dfsdep(v, u);
	}
	return ;
}
int stk[N], top;
int fd = 0, d;
void dfsd(int u, int fa, int t) {
	if (! fd) {
		stk[++top] = u;
	}
	if (u == t) {
		fd = 1;
	}
	for (int v : ev[u]) {
		if (v == fa) {
			continue;
		}
		dfsd(v, u, t);
	}
	if (! fd) {
		top--;
	}
	return ;
}
int cnt[N];
void dfsdpu(int u, int fa, int top) {
	cnt[top] += (dep[u] == d / 2);
	for (int v : ev[u]) {
		if (v == fa) {
			continue;
		}
		dep[v] = dep[u] + 1;
		dfsdpu(v, u, top);
	}
	return ;
}
int main() {
	scanf("%d", &n);
	function<void (int, int)> add = [](int u, int v) -> void {
		ev[u].push_back(v);
		return ;
	};
	int m = n;
	for (int i = 1; i < n; i++) {
		int x, y;
		scanf("%d%d", &x, &y);
		m++, add(m, x), add(x, m), add(m, y), add(y, m);
	}
	dfsdep(1, 0);
	int s = 0;
	for (int i = 1; i <= m; i++) {
		// printf("%d ", dep[i]);
		s = (dep[i] > dep[s] ? i : s);
	}
	// printf("\n");
	dep[s] = 0;
	dfsdep(s, 0);
	int t = 0;
	for (int i = 1; i <= m; i++) {
		t = (dep[i] > dep[t] ? i : t);
	}
	dfsd(s, 0, t);
	d = dep[t];
	// printf("%d <-> %d = %d\n", s, t, d);
	int rt = stk[(top + 1) >> 1];
	// printf("rt = %d\n", rt);
	for (int u : ev[rt]) {
		dep[u] = 1;
		dfsdpu(u, rt, u);
	}
	long long c = 1, h = 1;
	for (int u : ev[rt]) {
		c = 1ll * c * (cnt[u] + 1) % mod, h += cnt[u];
	}
	printf("%lld\n", (c - h + mod) % mod);
	return 0;
}