CodeForces 715E Complete the Permutations

发布时间 2023-09-24 21:07:09作者: zltzlt

洛谷传送门

CF 传送门

最小交换次数等于 \(n - \text{环数}\)。所以题目要我们统计把 \(p, q\) 补全成排列,连边 \(p_i \to q_i\),环数 \(= i\) 的方案数。

考虑把边根据 \(p_i, q_i\) 的是否已知状态分成四类:

  1. \(p \to q\)
  2. \(p \to 0\)
  3. \(0 \to q\)
  4. \(0 \to 0\)

注意若存在 \(p \to 0, 0 \to q\)\(p = q\),我们把它们合并成一个 \(4\) 类边。

对于 \(1\) 类边直接把它缩成一个点,记录一下是否形成环即可。

对于剩下的边,设 \(2\) 类边数量为 \(n_1\)\(3\) 类边数量为 \(n_2\)\(4\) 类边数量为 \(n_3\)

对于 \(2\) 类边,我们考虑钦定一些边形成环,剩下的边接到 \(4\) 类边去,因为 \(0 \to 0\)\(p \to 0\) 合并还是一条 \(0 \to 0\)

\(f_i\)\(2\) 类边形成 \(i\) 个环的方案数。枚举钦定了 \(j\) 条边形成环,有:

\[f_i = \sum\limits_{j = i}^{n_1} \binom{n_1}{j} \begin{bmatrix} j \\ i \end{bmatrix} (n1 - j + n3 - 1)^{\underline{n1 - j}} \]

最后乘的下降幂意义是,一条边 \(p_1 \to 0\) 可以和之后的还没处理的边 \(p_2 \to 0\) 合并变成 \(p_1 \to 0\),也可以直接和一条 \(0 \to 0\) 边合并。

需要特判 \(n_3 = 0\),此时 \(f_i = \begin{bmatrix} n_1 \\ i \end{bmatrix}\)

再设 \(g_i\)\(3\) 类边形成 \(i\) 个环的方案数。计算方法和上面一样,把 \(n_1\) 改成 \(n_2\) 即可。

最后设 \(h_i\)\(4\) 类边形成 \(i\) 个环的方案数,有:

\[h_i = \begin{bmatrix} n_3 \\ i \end{bmatrix} n_3! \]

最后乘 \(n_3!\) 是因为我们填排列的时候可以任意排列这些边的顺序。

\(f, g, h\) 做加法卷积即可。

时间复杂度 \(O(n^2)\)

code
// Problem: E. Complete the Permutations
// Contest: Codeforces - Codeforces Round 372 (Div. 1)
// URL: https://codeforces.com/problemset/problem/715/E
// Memory Limit: 256 MB
// Time Limit: 5000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 260;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res; 
}

ll n, a[maxn], b[maxn], c[maxn], fa[maxn], fac[maxn], ifac[maxn], S[maxn][maxn];
ll f[maxn], g[maxn], h[maxn], d[maxn], ans[maxn];

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline bool merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
		return 1;
	} else {
		return 0;
	}
}

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
		fa[i] = i;
	}
	int cnt = 0;
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &b[i]);
		if (a[i] && b[i] && !merge(a[i], b[i])) {
			++cnt;
		}
	}
	for (int i = 1; i <= n; ++i) {
		if (a[i]) {
			a[i] = find(a[i]);
		}
		if (b[i]) {
			b[i] = find(b[i]);
		}
	}
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	S[0][0] = 1;
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= i; ++j) {
			S[i][j] = (S[i - 1][j] * (i - 1) % mod + S[i - 1][j - 1]) % mod;
		}
	}
	int n1 = 0, n2 = 0, n3 = 0;
	for (int i = 1; i <= n; ++i) {
		if (a[i] && b[i]) {
			continue;
		}
		if (a[i] && !b[i]) {
			++c[a[i]];
			++n1;
		}
		if (!a[i] && b[i]) {
			++c[b[i]];
			++n2;
		}
		if (!a[i] && !b[i]) {
			++n3;
		}
	}
	for (int i = 1; i <= n; ++i) {
		if (c[i] == 2) {
			--n1;
			--n2;
			++n3;
		}
	}
	if (n3) {
		for (int i = 0; i <= n1; ++i) {
			for (int j = i; j <= n1; ++j) {
				f[i] = (f[i] + C(n1, j) * S[j][i] % mod * fac[n1 - j + n3 - 1] % mod * ifac[n3 - 1] % mod) % mod;
			}
		}
		for (int i = 0; i <= n2; ++i) {
			for (int j = i; j <= n2; ++j) {
				g[i] = (g[i] + C(n2, j) * S[j][i] % mod * fac[n2 - j + n3 - 1] % mod * ifac[n3 - 1] % mod) % mod;
			}
		}
	} else {
		for (int i = 0; i <= n1; ++i) {
			f[i] = S[n1][i];
		}
		for (int i = 0; i <= n2; ++i) {
			g[i] = S[n2][i];
		}
	}
	for (int i = 0; i <= n3; ++i) {
		h[i] = S[n3][i] * fac[n3] % mod;
	}
	// for (int i = 0; i <= n1; ++i) {
		// printf("%lld ", f[i]);
	// }
	// putchar('\n');
	// for (int i = 0; i <= n2; ++i) {
		// printf("%lld ", g[i]);
	// }
	// putchar('\n');
	// for (int i = 0; i <= n3; ++i) {
		// printf("%lld ", h[i]);
	// }
	// putchar('\n');
	for (int i = 0; i <= n1; ++i) {
		for (int j = 0; j <= n2; ++j) {
			d[i + j] = (d[i + j] + f[i] * g[j] % mod) % mod;
		}
	}
	for (int i = 0; i <= n1 + n2; ++i) {
		for (int j = 0; j <= n3; ++j) {
			ans[i + j] = (ans[i + j] + d[i] * h[j] % mod) % mod;
		}
	}
	// for (int i = 0; i <= n1 + n2 + n3; ++i) {
		// printf("%lld ", ans[i]);
	// }
	// putchar('\n');
	for (int i = 0; i < n; ++i) {
		int k = n - i;
		if (k < cnt) {
			printf("0 ");
		} else {
			printf("%lld ", ans[k - cnt]);
		}
	}
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}