AtCoder Beginner Contest 313 Ex Group Photo

发布时间 2023-09-21 18:04:38作者: zltzlt

洛谷传送门

AtCoder 传送门

考虑若重排好了 \(a\),如何判断可不可行。显然只用把 \(b\) 排序,把 \(\min(a_{i - 1}, a_i)\) 也排序(定义 \(a_0 = a_{n + 1} = +\infty\)),按顺序逐个判断是否大于即可。

这启示我们将 \(\min(a_{i - 1}, a_i)\) 排序考虑。考虑从大到小加入 \(a_i\),那么加入一个 \(a_i\),和它相邻的 \(\min(a_{i - 1}, a_i)\) 一定是 \(a_i\)。每个时刻已经加入的 \(a_i\) 形成了若干个连续段,考虑连续段 dp:

\(f_{i, j}\) 为考虑了 \(a\) 中前 \(i\) 大的数,当前形成了 \(j\) 个连续段。那么有 \(i - j\) 对相邻的数。

初值为 \(f_{2, 2} = 1\),表示一开始加入了 \(a_0\)\(a_{n + 1}\)

有转移:

  • \(a_i\) 新开一个段,这个段可以在两段之间的空隙中:\(f_{i, j + 1} \gets f_{i - 1, j} \times (j - 1)\)
  • \(a_i\) 接入一个段的开头或末尾,需要满足 \(a_i < b_{i - j}\),此时有:\(f_{i, j} \gets f_{i - 1, j} \times (2j - 2)\)
  • \(a_i\) 塞进两个相邻段的空隙中,然后合并这两个段,需要满足 \(a_i < b_{i - j + 1}\),此时有:\(f_{i, j - 1} \gets f_{i - 1, j} \times (j - 1)\)

答案为 \(f_{n + 2, 0}\)

时间复杂度 \(O(n^2)\)

code
// Problem: Ex - Group Photo
// Contest: AtCoder - AtCoder Beginner Contest 313
// URL: https://atcoder.jp/contests/abc313/tasks/abc313_h
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 5050;
const ll mod = 998244353;

ll n, a[maxn], b[maxn], f[maxn][maxn];

inline void upd(ll &x, ll y) {
	((x += y) >= mod) && (x -= mod);
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
	}
	for (int i = 1; i <= n + 1; ++i) {
		scanf("%lld", &b[i]);
	}
	sort(a + 1, a + n + 1, greater<ll>());
	sort(b + 1, b + n + 2, greater<ll>());
	f[2][2] = 1;
	for (int i = 3; i <= n + 2; ++i) {
		for (int j = 1; j < i; ++j) {
			if (!f[i - 1][j]) {
				continue;
			}
			upd(f[i][j + 1], f[i - 1][j] * (j - 1) % mod);
			if (a[i - 2] < b[i - j]) {
				upd(f[i][j], f[i - 1][j] * (j * 2 - 2) % mod);
			}
			if (i - j <= n && a[i - 2] < b[i - j + 1]) {
				upd(f[i][j - 1], f[i - 1][j] * (j - 1) % mod);
			}
		}
	}
	printf("%lld\n", f[n + 2][1]);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}