AtCoder Beginner Contest 288 Ex A Nameless Counting Problem

发布时间 2023-10-03 18:21:01作者: zltzlt

洛谷传送门

AtCoder 传送门

考虑到规定单调不降比较难搞。先设 \(g_t\) 为长度为 \(t\) 的满足条件的序列个数(可重且有顺序)。求这个可以设个 dp,\(f_{d, i}\) 表示考虑到从高到低第 \(d\) 位,当前 \(t\) 个数中有 \(i\) 个仍然顶上界,并且之前的位都满足异或分别等于 \(X\) 的限制。转移枚举从这一位开始不顶上界的数的个数即可。

现在考虑去重。设 \(f_t\) 为长度为 \(t\) 的满足条件的序列个数(不可重且无顺序)。注意到 \(x \oplus x = 0\),所以考虑枚举出现奇数次的数的总出现次数 \(i\),出现奇数次的数的种类数 \(j\),出现偶数次的数的种类数 \(k\)。设 \(\text{odd}(i, j)\)\(i\) 个可区分的元素划分到 \(j\) 个不可区分的集合中,且每个集合的大小都是奇数的方案数,对称地设一个 \(\text{even}(i, j)\)\(i\) 个可区分的元素划分到 \(j\) 个不可区分的集合中,且每个集合的大小都是偶数的方案数。那么我们有:

\[f_t = \sum\limits_{i = 0}^t \sum\limits_{j = 0}^i \binom{t}{i} \text{odd}(i, j) \sum\limits_{k = 0}^{t - i} \text{even}(t - i, k) (m + 1 - j)^{\underline k} g_j \]

从小到大算 \(f_t\) 即可。容易提前计算 \(h_{i, j} = \sum\limits_{k = 0}^i \text{even}(i, k) (m + 1 - j)^{\underline k}\) 做到 \(O(n^3)\) 计算。

现在我们知道了 \(f_t\) 了。因为 \(f_t\) 是有顺序的,所以计算答案时要先除一个 \(t!\) 转成无序。因为 \(f_t\) 不可重,所以我们枚举添加了 \(2i\) 个数,从 \(m\) 个数中可重地选出 \(2i\) 个数的方案数通过插板法可以算出是 \(\binom{m + i}{i}\),然后有:

\[\sum\limits_{i = 0}^{\left\lfloor\frac{n}{2}\right\rfloor} \frac{f_{n - 2i}}{(n - 2i)!} \binom{m + i}{i} \]

时间复杂度 \(O(n^3 \log V)\),瓶颈在一开始的 dp。

code
// Problem: Ex - A Nameless Counting Problem
// Contest: AtCoder - Toyota Programming Contest 2023 Spring Qual A(AtCoder Beginner Contest 288)
// URL: https://atcoder.jp/contests/abc288/tasks/abc288_h
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 210;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, K, fac[maxn], ifac[maxn], f[maxn], g[maxn][maxn], h[maxn][maxn], p[maxn][maxn], pw2[maxn];

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

inline void upd(ll &x, ll y) {
	((x += y) >= mod) && (x -= mod);
}

void solve() {
	scanf("%lld%lld%lld", &n, &m, &K);
	pw2[0] = 1;
	for (int i = 1; i <= n; ++i) {
		pw2[i] = pw2[i - 1] * 2 % mod;
	}
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[n] = qpow(fac[n], mod - 2);
	for (int i = n - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	for (int t = 1; t <= n; ++t) {
		mems(g, 0);
		g[30][t] = 1;
		for (int d = 29; ~d; --d) {
			if (m & (1LL << d)) {
				for (int i = 0; i <= t; ++i) {
					for (int j = 0; j <= i; ++j) {
						int k = j & 1;
						k ^= ((K >> d) & 1);
						if (k && i == t) {
							continue;
						}
						upd(g[d][j], g[d + 1][i] * pw2[max(0, t - i - 1)] % mod * C(i, j) % mod);
					}
				}
			} else {
				for (int i = 0; i <= t; ++i) {
					if (i == t && (K & (1LL << d))) {
						continue;
					}
					g[d][i] = g[d + 1][i] * pw2[max(0, t - i - 1)] % mod;
				}
			}
		}
		for (int i = 0; i <= t; ++i) {
			upd(f[t], g[0][i]);
		}
	}
	mems(g, 0);
	g[0][0] = h[0][0] = 1;
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= n; ++j) {
			for (int k = 1; k <= i; ++k) {
				if (k & 1) {
					upd(g[i][j], g[i - k][j - 1] * C(i - 1, k - 1) % mod);
				} else {
					upd(h[i][j], h[i - k][j - 1] * C(i - 1, k - 1) % mod);
				}
			}
		}
	}
	for (int i = 0; i <= n; ++i) {
		for (int j = 0; j <= n; ++j) {
			ll mul = 1;
			for (int k = 0; k <= i; ++k) {
				upd(p[i][j], h[i][k] * mul % mod);
				if (m + 1 - j - k <= 0) {
					break;
				}
				mul = mul * (m + 1 - j - k) % mod;
			}
		}
	}
	f[0] = (K == 0);
	for (int i = 1; i <= n; ++i) {
		for (int j = 0; j <= i; ++j) {
			for (int k = 0; k <= j; ++k) {
				if (k < i) {
					upd(f[i], mod - C(i, j) * g[j][k] % mod * p[i - j][k] % mod * f[k] % mod);
				} else {
					ll coef = C(i, j) * g[j][k] % mod * p[i - j][k] % mod;
					f[i] = f[i] * qpow(coef, mod - 2) % mod;
				}
			}
		}
	}
	ll ans = 0;
	for (int i = 0; i * 2 <= n; ++i) {
		ll c = ifac[i];
		for (int j = m + i; j > m; --j) {
			c = c * j % mod;
		}
		upd(ans, f[n - i * 2] * ifac[n - i * 2] % mod * c % mod);
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}