[PKUWC2018]猎人杀

发布时间 2023-05-02 19:23:55 作者: Smallbasic

概率的分母在不断变化，处理起来很麻烦。我们不妨允许猎人打到已死的人（打到已死的人就重新开枪）：由于还活着的人之间被打中的概率之比没有变，这显然不会影响答案。

考虑容斥。设 \(p(S)\)（\(S\subseteq\{2,\dots,n\}\)）表示集合 \(S\) 中的人全部在 \(1\) 之后被打中的概率，那么答案就是 \(\sum_{S}(-1)^{|S|}p(S)\)。\(p(S)\) 实际上就是：无限开枪，在打到 \(1\) 之前每一枪都没有打到 \(S\cup \{1\}\) 中的人，然后打到 \(1\) 的概率。枚举打到 \(1\) 之前开了多少枪，令 \(sum=\sum w_i\)，则容易得到 \(p(S)=\sum_{i=0}^{\infty}\left({sum-w_1-\operatorname{sum}(S)\over sum}\right)^i\cdot {w_1\over sum}\)。这是个等比数列，求和得 \(p(S)={w_1\over w_1+\operatorname{sum}(S)}\)。

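再看答案的式子：枚举 \(S\) 是指数级的，不好算。观察数据范围可以发现 \(sum\) 只有 \(10^5\) 级别，不妨直接枚举 \(\operatorname{sum}(S)\) 的取值。令 \(f(i)=\sum\limits_{\operatorname{sum}(S)=i}(-1)^{|S|}\)，则答案为 \(\sum_{i} f(i)\cdot{w_1\over w_1+i}\)。关键在于求 \(f(i)\)，考虑生成函数。显然 \(f(i)=[x^i] \prod\limits_{j=2}^{n} \left(1-x^{w_j}\right)\)。由于 \(\ln(1-x^{w})=-\sum_{k\ge 1}{x^{wk}\over k}\)，把相同的 \(w\) 合并后按调和级数在 \(O(sum\log sum)\) 内求出各项 \(\ln\) 之和，再做一次 \(\exp\) 即可。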

#include <bits/stdc++.h>

using namespace std;

// Shorthand aliases used by this solution.
// NOTE(review): `pi` and `mp` are not referenced in the visible code; a macro
// named `pi` silently rewrites any later identifier spelled `pi`.
#define pi pair<int,int>
#define mp make_pair
// A polynomial / power series is stored as its coefficient vector:
// element i holds the coefficient of x^i.
#define poly vector<int>

// NTT-friendly prime modulus and related constants:
//   G  = 3 is the generator used for the forward transform,
//   Gi = G^{-1} (mod mod), used for the inverse transform,
//   I  is a square root of -1 modulo mod (I * I == mod - 1).
const int N = 4e5 + 5, mod = 998244353, G = 3, I = 86583718, Gi = 332748118;

// Modular addition; both operands are expected to lie in [0, mod).
int add(int a, int b) { return (a + b) >= mod ? a + b - mod : a + b; }
// Modular subtraction; both operands are expected to lie in [0, mod).
int sub(int a, int b) { return a < b ? a - b + mod : a - b; }
// Modular multiplication through a 64-bit intermediate so a * b cannot overflow.
int mul(int a, int b) { return static_cast<int>(1ll * a * b % mod); }

// Binary exponentiation: returns a^b modulo `mod` for b >= 0.
int power(int a, int b) {
	int result = 1, base = a;
	for (int k = b; k; k >>= 1) {
		if (k & 1) result = mul(result, base);
		base = mul(base, base);
	}
	return result;
}

// Modular inverses of 2 and of 2*I via Fermat's little theorem (mod is prime).
const int inv2 = power(2, mod - 2), inv2I = power(mul(2, I), mod - 2);

// Reads one (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. A '-' only counts if it immediately precedes
// the first digit. Returns 0 if end of input is reached before any digit.
inline int read() {
	int s = 0, f = 1;
	// Keep the character as int: this distinguishes EOF from a real byte and
	// keeps the argument to isdigit in its defined range (unsigned char or EOF).
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) {
		f = (ch == '-' ? -1 : 1);
		ch = getchar();
	}
	while (ch != EOF && isdigit(ch)) {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * f;
}

// In-place iterative number-theoretic transform of a[0..len) modulo `mod`.
// `len` must be a power of two dividing mod - 1.
//   typ == 1  : forward transform (roots of unity generated by G).
//   otherwise : inverse transform (roots generated by Gi), including the
//               final multiplication by 1/len.
inline void FFT(int *a, int len, int typ) {
	// Bit-reversal permutation: j tracks the bit-reversed index of i.
	for (int i = 0, j = 0, k; i < len; ++i) {
		if (i < j) swap(a[i], a[j]);
		for (k = len >> 1; k & j; k >>= 1) j ^= k;
		j ^= k;
	}
	// Butterflies over blocks of size 2*mid; wn is a (2*mid)-th root of unity.
	for (int mid = 1; mid < len; mid <<= 1) {
		int wn = power(typ == 1 ? G : Gi, (mod - 1) / (mid << 1));
		for (int j = 0; j < len; j += mid << 1) {
			int bas = 1;
			for (int k = 0; k < mid; ++k, bas = mul(bas, wn)) {
				int x = a[j + k], y = ::mul(bas, a[j + mid + k]);
				a[j + k] = ::add(x, y);
				a[j + mid + k] = ::sub(x, y);
			}
		}
	}
	// Normalize the inverse transform. Use the same predicate as the root
	// selection above so any non-1 `typ` (e.g. the common -1) is scaled too.
	if (typ != 1) {
		const int iv = power(len, mod - 2);
		for (int i = 0; i < len; ++i)
			a[i] = ::mul(a[i], iv);
	}
}

// Returns the larger of two ints.
inline int max_(int a, int b) {
	if (a > b) return a;
	return b;
}

// Returns the smaller of two ints.
inline int min_(int a, int b) {
	if (a < b) return a;
	return b;
}

// Coefficient-wise scalar operations on a polynomial. Note: these apply the
// scalar to EVERY coefficient, not only to the constant term.

// Returns a copy of `a` with b added (mod) to each coefficient.
inline poly add(poly a, int b) {
	for (int &coef : a) coef = ::add(coef, b);
	return a;
}

// Returns a copy of `a` with b subtracted (mod) from each coefficient.
inline poly sub(poly a, int b) {
	for (int &coef : a) coef = ::sub(coef, b);
	return a;
}

// Returns a copy of `a` with each coefficient multiplied (mod) by b.
inline poly mul(poly a, int b) {
	for (int &coef : a) coef = ::mul(coef, b);
	return a;
}

// Returns a copy of `a` with each coefficient multiplied by b^{-1} (mod).
inline poly div(poly a, int b) {
	const int inv_b = ::power(b, mod - 2);
	for (int &coef : a) coef = ::mul(coef, inv_b);
	return a;
}

// Coefficient-wise sum of two polynomials; the result has the longer length.
inline poly add(poly a, poly b) {
	if (a.size() < b.size()) a.resize(b.size());
	for (size_t idx = 0; idx != b.size(); ++idx)
		a[idx] = ::add(a[idx], b[idx]);
	return a;
}

// Coefficient-wise difference a - b; the result has the longer length.
inline poly sub(poly a, poly b) {
	if (a.size() < b.size()) a.resize(b.size());
	for (size_t idx = 0; idx != b.size(); ++idx)
		a[idx] = ::sub(a[idx], b[idx]);
	return a;
}

// Polynomial product via NTT. Returns a vector of length
// a.size() + b.size() - 1, or an empty vector if either factor is empty.
inline poly mul(poly a, poly b) {
	// An empty factor would make p == 0; log2(0) is -inf, and casting it to int is UB.
	if (a.empty() || b.empty()) return poly();
	const int p = a.size() + b.size() - 1;
	// Smallest power of two >= p, computed exactly in integer arithmetic.
	int len = 1;
	while (len < p) len <<= 1;
	a.resize(len); b.resize(len);
	FFT(&a[0], len, 1); FFT(&b[0], len, 1);
	for (int i = 0; i < len; ++i)
		a[i] = ::mul(a[i], b[i]);
	FFT(&a[0], len, 0);
	a.resize(p);
	return a;
}

// Power-series inverse by Newton iteration: returns the first `len`
// coefficients of 1 / a(x). Requires a[0] != 0 and a.size() >= len.
inline poly inv(poly a, int len) {
	if (len == 1) return poly(1, power(a[0], mod - 2));
	// Inverse modulo x^{ceil(len/2)}, to be refined by one Newton step.
	poly g = inv(a, (len + 1) / 2);
	// Transform size: twice the smallest power of two that is >= len.
	int sz = 1;
	while (sz < len) sz <<= 1;
	sz <<= 1;
	poly head(a.begin(), a.begin() + len);
	g.resize(sz);
	head.resize(sz);
	FFT(g.data(), sz, 1);
	FFT(head.data(), sz, 1);
	// Newton step g <- g * (2 - a * g), evaluated pointwise.
	for (int i = 0; i < sz; ++i)
		g[i] = ::mul(g[i], ::sub(2, ::mul(g[i], head[i])));
	FFT(g.data(), sz, 0);
	g.resize(len);
	return g;
}

// Inverse of `a` truncated to a.size() coefficients.
inline poly inv(poly a) {
	const int len = a.size();
	return inv(a, len);
}

// Returns the coefficients of `a` in reverse order.
inline poly rev(poly a) {
	return poly(a.rbegin(), a.rend());
}

// Polynomial quotient of a by b, computed with the reversal trick:
// rev(q) = rev(a) * rev(b)^{-1} mod x^{deg a - deg b + 1}.
// Returns {0} when a is shorter than b. Relies on inv(), so b.back() must be
// nonzero for the reversed divisor to be invertible.
inline poly div(poly a, poly b) {
	if (a.size() < b.size()) return poly(1, 0);
	const int q_len = a.size() - b.size() + 1;
	poly a_rev(a.rbegin(), a.rend());
	poly b_rev(b.rbegin(), b.rend());
	a_rev.resize(q_len);
	b_rev.resize(q_len);
	poly q_rev = mul(a_rev, inv(b_rev));
	q_rev.resize(q_len);
	return rev(q_rev);
}

// Polynomial remainder a mod b with trailing zero coefficients stripped.
// Returns `a` unchanged when it is already shorter than b, and {0} when the
// remainder is identically zero.
inline poly remainder(poly a, poly b) {
	if (a.size() < b.size()) return a;
	const poly quotient = div(a, b);
	poly rest = sub(a, mul(b, quotient));
	while (!rest.empty() && rest.back() == 0) rest.pop_back();
	return rest.empty() ? poly(1, 0) : rest;
}

// Formal derivative: returns a'(x), one coefficient shorter than `a`.
// The derivative of an empty (zero) polynomial is returned as empty.
inline poly det(poly a) {
	const int n = a.size();
	// Without this guard resize(n - 1) would convert -1 to SIZE_MAX and throw.
	if (n == 0) return a;
	for (int i = 1; i < n; ++i) a[i - 1] = ::mul(a[i], i);
	a.resize(n - 1);
	return a;
}

// Formal integral with zero constant term: returns a polynomial one
// coefficient longer than `a`, with result[0] = 0 and
// result[i] = a[i-1] / i (mod) for i >= 1.
inline poly inter(poly a) {
	poly result(a.size() + 1, 0);
	for (size_t i = 1; i < result.size(); ++i)
		result[i] = ::mul(a[i - 1], power(static_cast<int>(i), mod - 2));
	return result;
}

// Power-series logarithm: returns the first n = a.size() coefficients of
// ln a(x), computed as the integral of a'(x) / a(x). The constant term of the
// result is always 0, so this assumes a[0] == 1.
inline poly ln(poly a) {
	const int n = a.size();
	// For n <= 1 the derivative is empty (or would underflow), and the
	// truncated logarithm is just the zero series of the same length.
	if (n <= 1) return poly(n, 0);
	a = inter(mul(det(a), inv(a)));
	a.resize(n);
	return a;
}

// Power-series exponential by Newton iteration: returns the first `len`
// coefficients of exp a(x). Assumes a[0] == 0 and a.size() >= len.
inline poly exp(poly a, int len) {
	if (len == 1) return poly(1, 1);
	// Approximation correct modulo x^{ceil(len/2)}, extended to length len.
	poly f = exp(a, (len + 1) / 2);
	f.resize(len);
	// Newton step: f <- f * (1 + a - ln f)  (mod x^len).
	poly step = ln(f);
	for (int i = 0; i < len; ++i) step[i] = ::sub(a[i], step[i]);
	step[0] = step[0] + 1;
	f = mul(f, step);
	f.resize(len);
	return f;
}

// exp of the whole series, truncated to a.size() coefficients.
inline poly exp(poly a) {
	const int len = a.size();
	return exp(a, len);
}

// Global state used by main():
//   n        - number of people; w[1..n] are their weights, sum = w[1] + ... + w[n].
//   tim[v]   - how many of people 2..n have weight v.
//   inver[i] - modular inverse of i, filled for 1 <= i <= sum.
//   a        - first holds ln(prod_{j=2..n} (1 - x^{w_j})) and, after exp, the
//              product itself (length sum + 1).
// NOTE(review): arrays are sized N, so this assumes n < N and every weight < N.
int n, inver[N], tim[N], sum = 0, w[N];
poly a;

int main() {
	n = read();
	for (int i = 1; i <= n; ++i)
		sum += w[i] = read();
	a.resize(sum + 1);
	for (int i = 2; i <= n; ++i) ++tim[w[i]];
	for (int i = 1; i <= sum; ++i) inver[i] = power(i, mod - 2);
	for (int i = 1; i <= sum; ++i)
		if (tim[i])
			for (int j = 1; j <= sum / i; ++j)
				a[j * i] = sub(a[j * i], ::mul(inver[j], tim[i]));
	a = exp(a); int res = 0;
	for (int i = 0; i <= sum; ++i) {
		res += 1ll * a[i] * (1ll * w[1] * power(w[1] + i, mod - 2) % mod) % mod;
		if (res >= mod) res -= mod;
	} printf("%d\n", res);
	return 0;
}