gym103415A Math Ball

发布时间:2023-12-26 21:17:00 作者:Smallbasic

套路生成函数。

写出答案的式子,设 \(f_i(x)=\sum_{j\ge 0} j^{c_i} x^j\),不难得到答案为:

\[[x^W]{1\over 1-x}\prod_{i=1}^n f_i(x) \]

考虑求 \(f_i(x)\)。看到指数上有 \(c_i\),想到用斯特林数展开:

\[f_i(x)=\sum_{j=0}^{\infty} x^j \sum_{k=0}^{c_i} {c_i\brace k}\binom{j}{k}k! \]

\[=\sum_{k=0}^{c_i} {c_i\brace k} k! \sum_{j=k}^{\infty}\binom{j}{k}x^j \]

注意到后面的式子是组合数一列的生成函数,设其为 \(g_k(x)\),乘上 \(1\over 1-x\) 对其系数做前缀和,由组合数按列求和的公式可知: \({1\over 1-x}g_k(x)={1\over x}g_{k+1}(x),g_k(x)={x^k\over (1-x)^{k+1}}\)。将每个 \(x^k\over (1-x)^{k+1}\) 分母中的一个 \(1-x\) 提到外面,可以将答案写为:

\[[x^W]{1\over(1-x)^{n+1}}\prod_{i=1}^n\sum_{k=0}^{c_i}{c_i\brace k}k! \left({x\over 1-x}\right)^k \]

将 \(t={x\over 1-x}\) 看成一个整体变量,则后面的连乘积是关于 \(t\) 的多项式,次数不超过 \(\sum c_i\),可以用分治 NTT 快速计算。设得到的多项式为 \(F(t)=\sum_k F_k t^k\),那么总答案可以写为:

\[[x^W] \sum_k {F_kx^k\over (1-x)^{n+k+1}} \]

虽然 \(W\) 很大,但是 \(F\) 的非零项数非常少(至多 \(\sum c_i+1\) 项)。我们枚举 \(F\) 的每一项,由牛顿二项式定理可知 \([x^{W-k}]{1\over (1-x)^{n+k+1}}=\binom{W+n}{n+k}\),而 \(\binom{W+n}{n+k}={(W+n)^{\underline{n+k}}\over (n+k)!}\),随 \(k\) 递推维护下降幂即可通过。

#include <bits/stdc++.h>

using namespace std;

// Short type aliases used by the solution. They are real aliases rather than
// object-like macros so that they obey scoping and cannot rewrite unrelated
// identifiers that happen to be spelled `pi` or `poly`.
using pi = std::pair<int, int>;
// `mp` has to stay a macro: it abbreviates a function template, not a type.
#define mp std::make_pair
// Polynomial stored as a coefficient vector; index i holds the x^i coefficient.
using poly = std::vector<int>;

using ll = long long;

// Shared numeric constants.
//   N   - size of the fixed global tables declared with it in this file.
//   mod - prime modulus for all arithmetic below.
//   G   - base from which FFT derives its roots of unity; Gi is the modular
//         inverse of G (3 * Gi == mod + 1, i.e. 1 modulo mod).
const int N = 3e5 + 5, mod = 998244353, G = 3, Gi = (mod + 1) / G;

// Modular sum. Expects 0 <= a, b < mod so that a + b cannot overflow int.
int add(int a, int b) {
	const int s = a + b;
	return s >= mod ? s - mod : s;
}

// Modular difference. Expects 0 <= a, b < mod; the result is in [0, mod).
int sub(int a, int b) {
	const int d = a - b;
	return d < 0 ? d + mod : d;
}

// Modular product; the multiplication is widened to 64 bits before reducing.
int mul(int a, int b) {
	return static_cast<int>(1ll * a * b % mod);
}

// Computes a^b modulo mod by binary exponentiation.
// b is expected to be non-negative. A negative exponent is not supported: the
// loop simply does not run and 1 is returned (the previous `while (k)` form
// never terminated in that case).
int power(int a, int b) {
	int result = 1;
	int base = a;
	for (int e = b; e > 0; e >>= 1) {
		if (e & 1) result = mul(result, base);
		base = mul(base, base);
	}
	return result;
}

// Reads one decimal integer from stdin and returns it as int.
// Every non-digit character before the number is skipped; the sign is taken
// from the last skipped character (negative only if that character is '-').
// Returns 0 if end of input is reached before any digit is seen.
inline int read() {
	int value = 0, sign = 1;
	// Keep the character as int: getchar() returns EOF out of band, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) {
		sign = (ch == '-') ? -1 : 1;
		ch = getchar();
	}
	while (isdigit(ch)) {
		value = value * 10 + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}

// 64-bit counterpart of read(): reads one decimal integer from stdin.
// Every non-digit character before the number is skipped; the sign is taken
// from the last skipped character (negative only if that character is '-').
// Returns 0 if end of input is reached before any digit is seen.
inline ll readll() {
	ll value = 0, sign = 1;
	// Keep the character as int: getchar() returns EOF out of band, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) {
		sign = (ch == '-') ? -1 : 1;
		ch = getchar();
	}
	while (isdigit(ch)) {
		value = value * 10 + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}

// In-place iterative number-theoretic transform of a[0..len) modulo mod.
//   a   - coefficient array, overwritten with the transform.
//   len - number of entries; the butterfly indexing below assumes a power of
//         two (the caller in this file always passes 1 << k).
//   typ - 1 uses powers of G as roots of unity; any other value uses Gi.
//         The final scaling by the inverse of len is applied only when typ == 0.
inline void FFT(int *a, int len, int typ) {
	// Bit-reversal permutation: j tracks the bit-reversed counterpart of i.
	for (int i = 0, j = 0; i < len; ++i) {
		if (i < j) swap(a[i], a[j]);
		int k = len >> 1;
		while (k & j) {
			j ^= k;
			k >>= 1;
		}
		j ^= k;
	}
	// Butterfly passes over blocks of size 2 * mid.
	for (int mid = 1; mid < len; mid <<= 1) {
		const int wn = power(typ == 1 ? G : Gi, (mod - 1) / (mid << 1));
		for (int j = 0; j < len; j += mid << 1) {
			int bas = 1;  // current power of wn inside this block
			for (int k = 0; k < mid; ++k, bas = ::mul(bas, wn)) {
				const int x = a[j + k];
				const int y = ::mul(bas, a[j + mid + k]);
				a[j + k] = ::add(x, y);
				a[j + mid + k] = ::sub(x, y);
			}
		}
	}
	if (!typ) {
		// Inverse transform: divide every entry by len.
		const int iv = power(len, mod - 2);
		for (int i = 0; i < len; ++i)
			a[i] = ::mul(a[i], iv);
	}
}

// Returns the larger of the two ints (b when they compare equal).
inline int max_(int a, int b) {
	if (a > b) {
		return a;
	}
	return b;
}

// Returns the smaller of the two ints (b when they compare equal).
inline int min_(int a, int b) {
	if (a < b) {
		return a;
	}
	return b;
}

// Returns a copy of a with the scalar b added (modulo mod) to every coefficient.
// b is passed straight to the scalar ::add, so it must satisfy its range contract.
inline poly add(poly a, int b) {
	for (int &coef : a) coef = ::add(coef, b);
	return a;
}

// Returns a copy of a with the scalar b subtracted (modulo mod) from every coefficient.
// b is passed straight to the scalar ::sub, so it must satisfy its range contract.
inline poly sub(poly a, int b) {
	for (int &coef : a) coef = ::sub(coef, b);
	return a;
}

// Returns a copy of a with every coefficient multiplied by the scalar b modulo mod.
inline poly mul(poly a, int b) {
	for (int &coef : a) coef = ::mul(coef, b);
	return a;
}

// Returns a copy of a with every coefficient divided by the scalar b modulo mod.
// The division is a multiplication by power(b, mod - 2), so b must be
// invertible modulo mod for the result to be meaningful.
inline poly div(poly a, int b) {
	const int inv = ::power(b, mod - 2);
	for (int &coef : a) coef = ::mul(coef, inv);
	return a;
}

// Coefficient-wise sum of two polynomials modulo mod.
// The result has the length of the longer operand.
inline poly add(poly a, poly b) {
	// Grow a in size_t terms; routing the sizes through an int helper narrowed them.
	if (a.size() < b.size()) a.resize(b.size());
	for (size_t i = 0; i < b.size(); ++i) a[i] = ::add(a[i], b[i]);
	return a;
}

// Coefficient-wise difference a - b of two polynomials modulo mod.
// The result has the length of the longer operand.
inline poly sub(poly a, poly b) {
	// Grow a in size_t terms; routing the sizes through an int helper narrowed them.
	if (a.size() < b.size()) a.resize(b.size());
	for (size_t i = 0; i < b.size(); ++i) a[i] = ::sub(a[i], b[i]);
	return a;
}

// Product of two polynomials modulo mod, computed with FFT.
// Returns a.size() + b.size() - 1 coefficients, or an empty polynomial when
// either operand is empty.
inline poly mul(poly a, poly b) {
	// An empty operand would make the result length -1 and break the sizing below.
	if (a.empty() || b.empty()) return poly();
	const int p = static_cast<int>(a.size() + b.size()) - 1;
	// Smallest power of two >= p, found with integers instead of ceil(log2(p)).
	int len = 1;
	while (len < p) len <<= 1;
	a.resize(len);
	b.resize(len);
	FFT(&a[0], len, 1);
	FFT(&b[0], len, 1);
	for (int i = 0; i < len; ++i)
		a[i] = ::mul(a[i], b[i]);
	FFT(&a[0], len, 0);
	a.resize(p);
	return a;
}

// Factorial tables modulo mod, filled by init():
// fac[i] holds i! and ifac[i] holds the modular inverse of i!.
int fac[N], ifac[N];

// Builds n + 1 coefficients as the truncated product of two sequences:
//   lhs[i] = i^n * ifac[i]   and   rhs[i] = (-1)^i * ifac[i]   (all modulo mod).
// Per the accompanying write-up this is row n of the Stirling numbers of the
// second kind; main() multiplies entry k by fac[k] afterwards.
inline poly stir2R(int n) {
	poly lhs, rhs;
	for (int i = 0; i <= n; ++i) {
		const int scaled = (1ll * power(i, n) * ifac[i]) % mod;
		lhs.push_back(scaled);

		// Alternating sign: odd indices store the negated inverse factorial.
		int signedInv = ifac[i];
		if (i & 1) signedInv = mod - signedInv;
		if (signedInv >= mod) signedInv -= mod;
		rhs.push_back(signedInv);
	}
	poly row = mul(lhs, rhs);
	row.resize(n + 1);
	return row;
}

// Binomial coefficient "n choose m" modulo mod from the fac/ifac tables.
// Returns 0 when m is negative or larger than n. n must be below N, the
// size of the tables, and init() must have been called first.
inline int C(int n, int m) {
	// The m < 0 check keeps ifac[m] from being read out of bounds.
	if (m < 0 || n < m) return 0;
	return 1ll * fac[n] * (1ll * ifac[m] * ifac[n - m] % mod) % mod;
}

// Fills fac[0..N) with factorials and ifac[0..N) with their modular inverses.
inline void init() {
	fac[0] = 1;
	for (int i = 1; i < N; ++i) fac[i] = (1ll * i * fac[i - 1]) % mod;
	// One modular exponentiation for the last entry, then walk downwards using
	// 1/(i-1)! = i * 1/i!.
	ifac[N - 1] = power(fac[N - 1], mod - 2);
	for (int i = N - 1; i > 0; --i) ifac[i - 1] = (1ll * i * ifac[i]) % mod;
}

// Writes x to stdout digit by digit, most significant first.
// XOR with 48 ('0') turns a digit value 0..9 into its character. For a
// negative x the single-character branch is taken and x ^ 48 is written as is.
inline void otp(int x) {
	if (x >= 10) {
		otp(x / 10);
		putchar((x % 10) ^ 48);
	} else {
		putchar(x ^ 48);
	}
}

// Input state filled by main():
//   n      - first value read from the input.
//   c[i]   - the i-th of the n following values (1-based).
//   cnt[v] - number of c[i] equal to v; main() consumes these counts while building F.
//   mx     - largest c[i] seen.
int n, c[N], cnt[N], mx = 0;
// Second value read from the input (64-bit).
ll W;
// One polynomial per input item, in the order main() pushes them.
vector<poly> F;
// Product of all polynomials in F, as returned by calc().
poly g;

// Returns the product F[l] * F[l+1] * ... * F[r] by splitting the index range
// in half and multiplying the two partial products.
// An empty range (l > r) yields the constant polynomial 1, the empty product;
// without that guard such a call would recurse without bound.
inline poly calc(int l, int r) {
	if (l > r) return poly(1, 1);
	if (l == r) return F[l];
	const int mid = l + (r - l) / 2;
	return mul(calc(l, mid), calc(mid + 1, r));
}

// Reads n, W and the n exponents c[i], multiplies the per-item polynomials
// into g, and prints  sum_k g[k] * x_k * ifac[n + k]  modulo mod, where x_k is
// the running product (W + n) * (W + n - 1) * ... * (W - k + 1) of n + k factors.
int main() {
	init();
	n = read(); W = readll();
	for (int i = 1; i <= n; ++i) {
		c[i] = read();
		++cnt[c[i]];
		mx = max(mx, c[i]);
	}
	// Items with the same exponent share one polynomial: build it once per
	// distinct value and push one copy per occurrence.
	for (int i = 0; i <= mx; ++i) {
		if (!cnt[i]) continue;
		poly s = stir2R(i);
		for (int j = 0; j <= i; ++j)
			s[j] = 1ll * s[j] * fac[j] % mod;
		while (cnt[i]--) F.push_back(s);
	}
	// With no items F.size() - 1 would wrap around; the empty product is 1.
	if (F.empty()) g = poly(1, 1);
	else g = calc(0, static_cast<int>(F.size()) - 1);

	int x = 1, res = 0;
	// Start with the n factors (W + n) ... (W + 1).
	for (int i = 0; i < n; ++i) x = 1ll * ((W + n - i) % mod) * x % mod;
	// Once x becomes 0 it stays 0, so the remaining terms contribute nothing.
	for (int k = 0; k < static_cast<int>(g.size()) && x; ++k) {
		res += 1ll * g[k] * (1ll * x * ifac[n + k] % mod) % mod;
		if (res >= mod) res -= mod;
		x = 1ll * ((W - k) % mod) * x % mod;  // append the next factor (W - k)
	}
	printf("%d\n", res);
	return 0;
}