AtCoder Grand Contest 058 F Authentic Tree DP

发布时间 2023-09-14 20:08:16作者: zltzlt

洛谷传送门

AtCoder 传送门

人生中第一道 AtCoder 问号题。

\(P = 998244353\)

注意到 \(f(T)\) 的定义式中,\(\frac{1}{n}\) 大概是启示我们转成概率去做。发现若把 \(\frac{1}{n}\) 换成 \(\frac{1}{n - 1}\) 答案就是 \(1\),所以 \(\frac{1}{n}\) 大概是要转成点数之类的。

考虑把边转成点,若原树存在边 \((u, v)\),就新建点 \(p\),断开 \((u, v)\),连边 \((u, p), (p, v)\),称 \(p\) 点为边点。但是这样点数就变成 \(2n - 1\) 了。

但是!考虑再挂 \(P - 1\) 个叶子到 \(p\) 下面,点数就变成 \(n + (n - 1) \times P\)。模意义下 \(\frac{1}{n} \equiv \frac{1}{n + (n - 1)P} \pmod P\)

我们可以把原问题转化成在新树上的这个问题:

随机生成一个排列 \(p_{1 \sim n + (n - 1)P}\),求所有边点\(p\) 值大于其所有邻居的 \(p\) 值的概率。

证明大概就是考虑不断取树上的最大值,取 \(n - 1\) 次,每次取边点的概率在模意义下等于 \(\frac{1}{n}\),转移式也与原题相同。

考虑给树上的边定向,从小连到大,那么就是要求每条边的起点都是边点的概率。像这样(图来自 kkio):

随便定一个根。发现有些 \(p \to u\) 的边从下连到上看起来不顺眼,考虑容斥。那么所有下连到上的边可以选择上连到下或者下连到上(断掉)。设有 \(k\) 条原本是下连到上的边,容斥系数为 \((-1)^k\)

那么我们可以设 \(f_{u, i}\) 表示,\(u\) 的子树中以 \(u\) 为根的外向树大小模 \(P\) 意义下等于 \(i\),容斥系数乘概率之和。

对于一个边点 \(p\),在它原树上对应的边 \((u, v)\) 上统计贡献。

考虑若边 \((u, p)\) 从上到下,那 \(v\) 子树中以 \(v\) 为根的的外向树可以直接接到 \(u\) 下面,\(p\) 直接树形背包合并,\(f_{u, i + j} \gets -f'_{u, i} \times \frac{f_{v, j}}{j}\)。乘 \(\frac{1}{j}\) 是计入边点作为外向树的根的概率,乘 \(-1\) 是计入容斥系数。若边 \((u, p)\) 断开,那么 \(f_{u, i} \gets f'_{u, i} \times \sum\limits_{j = 1}^{sz_v} \frac{f_{v, j}}{j}\)。最后还要 \(f_{u, i} \gets \frac{f_{u, i}}{i}\),表示 \(u\) 点作为外向树的根的概率。

时间复杂度 \(O(n^2)\)

code
// Problem: F - Authentic Tree DP
// Contest: AtCoder - AtCoder Grand Contest 058
// URL: https://atcoder.jp/contests/agc058/tasks/agc058_f
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 5050;
const ll mod = 998244353;

ll n, f[maxn][maxn], g[maxn], h[maxn], inv[maxn], sz[maxn];
vector<int> G[maxn];

void dfs(int u, int fa) {
	sz[u] = 1;
	f[u][1] = 1;
	for (int v : G[u]) {
		if (v == fa) {
			continue;
		}
		dfs(v, u);
		for (int i = 1; i <= sz[u] + sz[v]; ++i) {
			h[i] = f[u][i];
			f[u][i] = 0;
		}
		for (int i = 1; i <= sz[u]; ++i) {
			for (int j = 1; j <= sz[v]; ++j) {
				f[u][i + j] = (f[u][i + j] - h[i] * f[v][j] % mod * inv[j] % mod + mod) % mod;
			}
		}
		for (int i = 1; i <= sz[u]; ++i) {
			f[u][i] = (f[u][i] + h[i] * g[v] % mod) % mod;
		}
		sz[u] += sz[v];
	}
	for (int i = 1; i <= sz[u]; ++i) {
		f[u][i] = f[u][i] * inv[i] % mod;
		g[u] = (g[u] + f[u][i] * inv[i] % mod) % mod;
	}
}

void solve() {
	scanf("%lld", &n);
	inv[1] = 1;
	for (int i = 2; i <= n; ++i) {
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	}
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	dfs(1, -1);
	ll ans = 0;
	for (int i = 1; i <= n; ++i) {
		ans = (ans + f[1][i]) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}