CodeForces 1919E Counting Prefixes

发布时间 2024-01-09 18:22:06作者: zltzlt

洛谷传送门

CF 传送门

考虑一个很类似的。我们把正数和负数分开来考虑,最后用 \(0\) 连接一些连续段,形如 \(0 - \text{正} - 0 - \text{正} - 0 - \text{负}\)

先考虑正数。设 \(f_{i, j}\) 为考虑了 \(\ge i\) 的正数,形成了 \(j\) 个连续段的方案数。设 \(i\) 的出现次数为 \(c_i\)

那么之前的每个段两端都需要接一个 \(i\) 下来,两段之间也可以只用一个 \(i\) 连接。

特别地,如果已经考虑到了结尾位置 \(n\),右端不用接数。于是我们状态再记一个 \(f_{i, j, 0/1}\) 表示包含位置 \(n\) 的段是否出现。

那么对于 \(f_{i + 1, j, 0}\) 的转移,新的段数 \(k = c_i - j\) 可以直接被计算出来。转移系数是 \(c_i\) 个数分配给 \(j + 1\) 个空的插板。我们有:

\[f_{i, k, 0} \gets f_{i + 1, j, 0} \times \binom{c_i - 1}{j} \]

\[f_{i, k, 1} \gets f_{i + 1, j, 0} \times \binom{c_i - 1}{j} \]

对于 \(f_{i + 1, j, 1}\) 的转移,新的段数为 \(k = c_i - j + 1\)。有转移:

\[f_{i, k, 1} \gets f_{i + 1, j, 1} \times \binom{c_i - 1}{j - 1} \]

同样地考虑负数,设 \(g_{i, j}\) 为考虑了 \(\le -i\) 的负数,形成了 \(j\) 个连续段的方案数。转移类似。

计算答案,只用枚举正数的段数 \(i\),如果 \(0\) 不为 \(n\) 的前缀和,那么负数的段数为 \(c_0 - i\),否则为 \(c_0 - i - 1\)。讨论一下 \(n\) 的前缀和是给 \(0\),正数还是负数,再乘上给段选择位置的系数。所以:

\[ans = \sum\limits_{i = 0}^{c_0} f_{1, i, 1} \times g_{1, c_0 - i, 0} \times \binom{c_0 - 1}{i - 1} + f_{1, i, 0} \times g_{1, c_0 - i, 1} \times \binom{c_0 - 1}{i} + f_{1, i, 0} \times g_{1, c_0 - i - 1, 0} \times \binom{c_0 - 1}{i} \]

直接计算 \(f, g\) 复杂度为 \(O(n^2)\)。但是注意到 \(f, g\) 初值只有 \(O(1)\) 个位置有值,每次转移一个位置会转移到固定的另一个位置,最后一维只会进行 \(0 \to 1\) 的转移不会进行 \(1 \to 0\) 的转移。所以 dp 数组有值的位置数是 \(\color{red}{O(1)}\) 的。如果我们用 unordered_map 把有值的位置存下来,复杂度将是 \(\color{red}{O(n)}\)

code
// Problem: E. Counting Prefixes
// Contest: Codeforces - Hello 2024
// URL: https://codeforces.com/contest/1919/problem/E
// Memory Limit: 256 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 5050;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, a[maxn], fac[maxn], ifac[maxn], b[maxn], c[maxn];
unordered_map<ll, ll> f[2][2], g[2][2];

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

void solve() {
	scanf("%lld", &n);
	for (int i = 0; i <= n; ++i) {
		b[i] = c[i] = 0;
	}
	ll L = 0, R = 0;
	++b[0];
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
		L = min(L, a[i]);
		R = max(R, a[i]);
		if (a[i] >= 0) {
			++b[a[i]];
		} else {
			++c[-a[i]];
		}
	}
	a[0] = 0;
	sort(a, a + n + 1);
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	for (int i = 0; i < 2; ++i) {
		for (int j = 0; j < 2; ++j) {
			f[i][j].clear();
			g[i][j].clear();
		}
	}
	if (R) {
		f[1][0][b[R]] = f[1][1][b[R]] = 1;
		int o = 0;
		for (int i = R - 1; i; --i, o ^= 1) {
			f[o][0].clear();
			f[o][1].clear();
			for (pii p : f[o ^ 1][0]) {
				ll j = p.fst, v = p.scd;
				if (b[i] >= j + 1) {
					int nj = b[i] - j;
					f[o][0][nj] = (f[o][0][nj] + v * C(b[i] - 1, j)) % mod;
					f[o][1][nj] = (f[o][1][nj] + v * C(b[i] - 1, j)) % mod;
				}
			}
			for (pii p : f[o ^ 1][1]) {
				ll j = p.fst, v = p.scd;
				if (b[i] >= j) {
					int nj = b[i] - j + 1;
					f[o][1][nj] = (f[o][1][nj] + v * C(b[i] - 1, j - 1)) % mod;
				}
			}
		}
	} else {
		f[0][0][0] = 1;
	}
	if (L) {
		L = -L;
		g[1][0][c[L]] = g[1][1][c[L]] = 1;
		int o = 0;
		for (int i = L - 1; i; --i, o ^= 1) {
			g[o][0].clear();
			g[o][1].clear();
			for (pii p : g[o ^ 1][0]) {
				ll j = p.fst, v = p.scd;
				if (c[i] >= j + 1) {
					int nj = c[i] - j;
					g[o][0][nj] = (g[o][0][nj] + v * C(c[i] - 1, j)) % mod;
					g[o][1][nj] = (g[o][1][nj] + v * C(c[i] - 1, j)) % mod;
				}
			}
			for (pii p : g[o ^ 1][1]) {
				ll j = p.fst, v = p.scd;
				if (c[i] >= j) {
					int nj = c[i] - j + 1;
					g[o][1][nj] = (g[o][1][nj] + v * C(c[i] - 1, j - 1)) % mod;
				}
			}
		}
	} else {
		g[0][0][0] = 1;
	}
	ll ans = 0;
	for (int i = 0; i <= b[0]; ++i) {
		ans = (ans + f[R & 1][1][i] * g[L & 1][0][b[0] - i] % mod * C(b[0] - 1, i - 1)) % mod;
		ans = (ans + f[R & 1][0][i] * g[L & 1][1][b[0] - i] % mod * C(b[0] - 1, i)) % mod;
		if (i < b[0]) {
			ans = (ans + f[R & 1][0][i] * g[L & 1][0][b[0] - i - 1] % mod * C(b[0] - 1, i)) % mod;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}