AtCoder Beginner Contest 214 G Three Permutations

发布时间 2023-06-01 18:50:48作者: zltzlt

洛谷传送门

AtCoder 传送门

比较平凡的一个容斥。

考虑把问题转化成,求 \(\forall i \in [1, n], r_i \ne i \land r_i \ne p_i\)\(r\) 方案数。考虑到不弱于错排,所以容斥。设钦定 \(i\)\(r_i\) 取了 \(i, p_i\) 中的一个的方案数为 \(f_i\),其余任意,那么:

\[ans = \sum\limits_{i = 0}^n (-1)^i f_i (n - i)! \]

考虑求 \(f_i\)。连边 \(i \to p_i\),对每个环单独考虑。设第 \(i\) 个环点数为 \(s_i\)。这个东西抽象到环上就相当于,每一个点,可以不选,可以选择选它自己,也可以选择选它在环上的下一个点。设 \(h_{i, j, 0/1}\) 表示当前考虑到环上第 \(i\) 个点,有 \(j\) 个点选了,这个点是否选择第 \(i + 1\) 个点。枚举第一个点的状态,然后直接做即可。合并到 \(f_i\),就是做一个加法卷积,暴力即可。

总时间复杂度 \(O(\sum s_i^2 + n \sum s_i) = O(n^2)\)

code
// Problem: G - Three Permutations
// Contest: AtCoder - AtCoder Beginner Contest 214
// URL: https://atcoder.jp/contests/abc214/tasks/abc214_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 3030;
const ll mod = 1000000007;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, a[maxn], b[maxn], c[maxn], fac[maxn], ifac[maxn], fa[maxn], sz[maxn], f[maxn], g[maxn], h[maxn][maxn][2];

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline void merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
		sz[y] += sz[x];
	}
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		fa[i] = i;
		sz[i] = 1;
	}
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
		c[a[i]] = i;
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &b[i]);
		b[i] = c[b[i]];
		merge(i, b[i]);
	}
	int s = 0;
	f[0] = 1;
	for (int _ = 1; _ <= n; ++_) {
		if (fa[_] == _) {
			int m = sz[_];
			for (int i = 0; i <= n; ++i) {
				g[i] = f[i];
				f[i] = 0;
			}
			if (m == 1) {
				f[0] = g[0];
				for (int i = 1; i <= n; ++i) {
					f[i] = (g[i] + g[i - 1]) % mod;
				}
				++s;
				continue;
			}
			for (int x = 0; x <= 2; ++x) {
				for (int i = 0; i <= m; ++i) {
					for (int j = 0; j <= i; ++j) {
						for (int k = 0; k < 2; ++k) {
							h[i][j][k] = 0;
						}
					}
				}
				h[1][x >= 1][x == 2] = 1;
				for (int i = 2; i <= m; ++i) {
					for (int j = 0; j < i; ++j) {
						for (int p = 0; p <= 1; ++p) {
							for (int q = 0; q <= 2; ++q) {
								if (p && q == 1) {
									continue;
								}
								int nj = j + (q >= 1), np = (q == 2);
								h[i][nj][np] = (h[i][nj][np] + h[i - 1][j][p]) % mod;
							}
						}
					}
				}
				for (int i = 0; i <= m; ++i) {
					for (int j = 0; j <= s; ++j) {
						ll val = h[m][i][0];
						if (x != 1) {
							val = (val + h[m][i][1]) % mod;
						}
						f[i + j] = (f[i + j] + val * g[j] % mod) % mod;
					}
				}
			}
			s += m;
		}
	}
	ll ans = 0;
	for (int i = 0; i <= n; ++i) {
		ans = (ans + ((i & 1) ? mod - 1 : 1) * f[i] % mod * fac[n - i] % mod) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}