比较平凡的一个容斥。
考虑把问题转化成,求 \(\forall i \in [1, n], r_i \ne i \land r_i \ne p_i\) 的 \(r\) 方案数。考虑到不弱于错排,所以容斥。设钦定 \(i\) 个 \(r_i\) 取了 \(i, p_i\) 中的一个的方案数为 \(f_i\),其余任意,那么:
\[ans = \sum\limits_{i = 0}^n (-1)^i f_i (n - i)!
\]
考虑求 \(f_i\)。连边 \(i \to p_i\),对每个环单独考虑。设第 \(i\) 个环点数为 \(s_i\)。这个东西抽象到环上就相当于,每一个点,可以不选,可以选择选它自己,也可以选择选它在环上的下一个点。设 \(h_{i, j, 0/1}\) 表示当前考虑到环上第 \(i\) 个点,有 \(j\) 个点选了,这个点是否选择第 \(i + 1\) 个点。枚举第一个点的状态,然后直接做即可。合并到 \(f_i\),就是做一个加法卷积,暴力即可。
总时间复杂度 \(O(\sum s_i^2 + n \sum s_i) = O(n^2)\)。
code
// Problem: G - Three Permutations
// Contest: AtCoder - AtCoder Beginner Contest 214
// URL: https://atcoder.jp/contests/abc214/tasks/abc214_g
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 3030;
const ll mod = 1000000007;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, a[maxn], b[maxn], c[maxn], fac[maxn], ifac[maxn], fa[maxn], sz[maxn], f[maxn], g[maxn], h[maxn][maxn][2];
inline ll C(ll n, ll m) {
if (n < m || n < 0 || m < 0) {
return 0;
} else {
return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
}
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
inline void merge(int x, int y) {
x = find(x);
y = find(y);
if (x != y) {
fa[x] = y;
sz[y] += sz[x];
}
}
void solve() {
scanf("%lld", &n);
for (int i = 1; i <= n; ++i) {
fa[i] = i;
sz[i] = 1;
}
fac[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
c[a[i]] = i;
}
for (int i = 1; i <= n; ++i) {
scanf("%lld", &b[i]);
b[i] = c[b[i]];
merge(i, b[i]);
}
int s = 0;
f[0] = 1;
for (int _ = 1; _ <= n; ++_) {
if (fa[_] == _) {
int m = sz[_];
for (int i = 0; i <= n; ++i) {
g[i] = f[i];
f[i] = 0;
}
if (m == 1) {
f[0] = g[0];
for (int i = 1; i <= n; ++i) {
f[i] = (g[i] + g[i - 1]) % mod;
}
++s;
continue;
}
for (int x = 0; x <= 2; ++x) {
for (int i = 0; i <= m; ++i) {
for (int j = 0; j <= i; ++j) {
for (int k = 0; k < 2; ++k) {
h[i][j][k] = 0;
}
}
}
h[1][x >= 1][x == 2] = 1;
for (int i = 2; i <= m; ++i) {
for (int j = 0; j < i; ++j) {
for (int p = 0; p <= 1; ++p) {
for (int q = 0; q <= 2; ++q) {
if (p && q == 1) {
continue;
}
int nj = j + (q >= 1), np = (q == 2);
h[i][nj][np] = (h[i][nj][np] + h[i - 1][j][p]) % mod;
}
}
}
}
for (int i = 0; i <= m; ++i) {
for (int j = 0; j <= s; ++j) {
ll val = h[m][i][0];
if (x != 1) {
val = (val + h[m][i][1]) % mod;
}
f[i + j] = (f[i + j] + val * g[j] % mod) % mod;
}
}
}
s += m;
}
}
ll ans = 0;
for (int i = 0; i <= n; ++i) {
ans = (ans + ((i & 1) ? mod - 1 : 1) * f[i] % mod * fac[n - i] % mod) % mod;
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}
- Permutations Beginner AtCoder Contest Threepermutations beginner atcoder contest permutations atcoder regular contest contest programming beginner atcoder beginner atcoder contest 296 beginner atcoder contest 295 beginner atcoder contest abcde beginner atcoder contest 335 beginner atcoder contest 328 beginner atcoder contest 334 beginner atcoder contest 332