HDU7401 流量监控

发布时间 2023-11-05 21:41:57作者: came11ia

给定一颗 \(n\) 个节点的树。求:

  1. 有多少种匹配 \((a_1,b_{1}),\cdots,(a_{\frac{n}{2}},b_{\frac{n}{2}})\),使得对于每一对匹配 \((u,v)\),点 \(u\) 是点 \(v\) 的祖先。
  2. 对于一组合法匹配,定义其权值为这些匹配的交点个数。对于两组匹配 \((a,b),(c,d)\),它们之间有一个交点当且仅当这四个点在同一条到叶节点的链上且交错排列。求所有匹配的权值和。

答案对 \(998244353\) 取模,\(1 \leq n \leq 2000\)


第一问是简单的:考虑树形 DP,设 \(f_{u,i}\) 表示当前考虑到点 \(u\),子树内还有 \(i\) 个点未被匹配的方案数,答案就是 \(f_{1,0}\)。转移时我们先对 \(u\) 所有儿子的 \(f\) 背包合并,再考虑 \(u\) 是否往下匹配:

  • \(u\) 往下匹配,那么可以在 \(i\) 个点中任选一个,但这样未匹配的点数会减少 \(1\),即 \(f_{u,i-1} \gets f_{u,i} \times i\)
  • \(u\) 不往下匹配,那么未匹配的点数增加 \(1\),即 \(f_{u,i+1} \gets f_{u,i}\)

时间复杂度 \(\mathcal{O}(n^2)\)

对于第二问,我们考虑把贡献拆成比较方便计算的方式。对于一个交点,我们在其靠近叶子的匹配处统计它的贡献。具体来说,对于当前考虑的点 \(u\),如果我们选择往下匹配某个点 \(v\),那么 \((u,v)\) 路径上还没有被匹配的点就一定会和 \(u\) 的某个祖先匹配,这样会和 \((u,v)\) 产生 \(1\) 的贡献。即若 \(u\)\(v\) 匹配,产生的贡献是 \((u,v)\) 路径上未匹配的点数。

这样就可以 DP 了,设 \(g_{u,i}\) 表示当前考虑到点 \(u\),子树内还有 \(i\) 个点未被匹配,的所有方案的权值和。同时 \(u\) 往下找匹配的时候,我们还需要知道贡献和,因此再设 \(h_{u,i}\) 表示当前考虑到点 \(u\),子树内还有 \(i\) 个点未被匹配的所有方案中,这 \(i\) 个点到根的路径上未被匹配的点数和。

转移时都先将所有儿子的 DP 值背包合并,然后考虑 \(u\) 是否往下匹配:

\(g\)

  • \(u\) 往下匹配,那么可以在 \(i\) 个点中任选一个,所有方案的贡献和是 \(i \times g_{u,i} + h_{u,i}\),即 \(g_{u,i-1} \gets i \times g_{u,i} + h_{u,i}\)
  • \(u\) 不往下匹配,那么贡献不变,即 \(g_{u,i+1} \gets g_{u,i}\)

\(h\)

  • \(u\) 往下匹配。考虑对于某个固定的方案,贡献是如何变化的:\(u\) 在这 \(i\) 个点中任选一个 \(v\) 匹配,减少的贡献可以分成两部分,即 \(v\) 本身产生的贡献,以及对于 \(v\) 子树内所有未被匹配的节点,它们的贡献都会减少 \(1\)。第一类贡献的和显然是 \(h_{u,i}\),而根据子树大小和等于深度和,第二类贡献的和也是 \(h_{u,i}\)。即转移为 \(h_{u,i - 1} \gets h_{u,i} \times (i-2)\)
  • \(u\) 不往下匹配,那么对于每一种方案,所有 \(i\) 个点的贡献都会增加 \(1\),即 \(h_{u,i+1} \gets h_{u,i} + i \times f_{u,i}\)

注意一下转移顺序即可。总时间复杂度 \(\mathcal{O}(n^2)\)

code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
constexpr int N = 2e3 + 5, mod = 998244353;
bool Mbe;
void add(int &x, int y) {
	x = x + y >= mod ? x + y - mod : x + y;
}
int n;
vector <int> e[N];
int f[N][N], g[N][N], h[N][N], sz[N];
void dfs(int u, int ff) {
	if (ff && e[u].size() == 1) {
		sz[u] = f[u][1] = 1;
		return;
	}
	sz[u] = 0;
	int c = 0;
	for (int i = 0; i < e[u].size(); i++) {
		int v = e[u][i];
		if (v == ff) continue;
		dfs(v, u);
		c++;
		if (c == 1) {
			for (int i = 0; i <= sz[v]; i++) f[u][i] = f[v][i], g[u][i] = g[v][i], h[u][i] = h[v][i];
		} else {
			static int tmp[N];
			for (int i = 0; i <= sz[u]; i++)
				for (int j = 0; j <= sz[v]; j++) {
					add(tmp[i + j], 1LL * g[u][i] * f[v][j] % mod);
					add(tmp[i + j], 1LL * f[u][i] * g[v][j] % mod);
				}
			for (int i = 0; i <= sz[u] + sz[v]; i++) {
				g[u][i] = tmp[i];
				tmp[i] = 0;
			}
			
			for (int i = 0; i <= sz[u]; i++) 
				for (int j = 0; j <= sz[v]; j++) {
					add(tmp[i + j], 1LL * h[u][i] * f[v][j] % mod);
					add(tmp[i + j], 1LL * f[u][i] * h[v][j] % mod);
				}
			for (int i = 0; i <= sz[u] + sz[v]; i++) {
				h[u][i] = tmp[i];
				tmp[i] = 0;
			}
			
			for (int i = 0; i <= sz[u]; i++) 
				for (int j = 0; j <= sz[v]; j++) 
					add(tmp[i + j], 1LL * f[u][i] * f[v][j] % mod);
			
			for (int i = 0; i <= sz[u] + sz[v]; i++) {
				f[u][i] = tmp[i];
				tmp[i] = 0;
			}
		}
		sz[u] += sz[v];
	}

	static int tmp[N];
	for (int i = 0; i <= sz[u]; i++) {
		if (i >= 1) {
			add(tmp[i - 1], h[u][i]);
			add(tmp[i - 1], 1LL * g[u][i] * i % mod);
		}
		add(tmp[i + 1], g[u][i]);
	}
	for (int i = 0; i <= sz[u] + 1; i++) {
		g[u][i] = tmp[i];
		tmp[i] = 0;
	}
	
	for (int i = 0; i <= sz[u]; i++) {
		if (i >= 3) add(tmp[i - 1], 1LL * h[u][i] * (i - 2) % mod);
		add(tmp[i + 1], h[u][i]);
		add(tmp[i + 1], 1LL * f[u][i] * i % mod);
	}
	for (int i = 0; i <= sz[u] + 1; i++) {
		h[u][i] = tmp[i];
		tmp[i] = 0;
	}
	
	for (int i = 0; i <= sz[u]; i++) {
		if (i >= 1) add(tmp[i - 1], 1LL * f[u][i] * i % mod);
		add(tmp[i + 1], f[u][i]);
	}
	for (int i = 0; i <= sz[u] + 1; i++) {
		f[u][i] = tmp[i];
		tmp[i] = 0;
	}
	sz[u] += 1; 
}
void solve() {
	cin >> n;
	for (int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		e[u].emplace_back(v);
		e[v].emplace_back(u);
	}
	dfs(1, 0);
	cout << f[1][0] << " " << g[1][0] << "\n";
	for (int i = 1; i <= n; i++) {
		e[i].clear();
		sz[i] = 0;
		for (int j = 0; j <= n; j++) f[i][j] = g[i][j] = h[i][j] = 0;
 	}
} 
bool Med;
int main() {
//	fprintf(stderr, "%.9lf\n", 1.0 * (&Mbe - &Med) / 1048576.0);
	ios :: sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	int t; cin >> t;
	while (t--) solve();
//	cerr << 1e3 * clock() / CLOCKS_PER_SEC << "ms\n";
	return 0; 
}