[PKUWC2018] 猎人杀

发布时间 2023-07-28 07:24:47作者: 灰鲭鲨

题目描述

猎人杀是一款风靡一时的游戏“狼人杀”的民间版本,他的规则是这样的:

一开始有 \(n\) 个猎人,第 \(i\) 个猎人有仇恨度 \(w_i\) ,每个猎人只有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。

然而向谁开枪也是有讲究的,假设当前还活着的猎人有 \([i_1\ldots i_m]\),那么有 \(\frac{w_{i_k}}{\sum_{j = 1}^{m} w_{i_j}}\) 的概率是向猎人 \(i_k\) 开枪。

一开始第一枪由你打响,目标的选择方法和猎人一样(即有 \(\frac{w_i}{\sum_{j=1}^{n}w_j}\) 的概率射中第 \(i\) 个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在 \(1\) 号猎人想知道它是最后一个死的的概率。

答案对 \(998244353\) 取模。

对于 \(100\%\) 的数据,有 \(w_i>0\),且 \(1\leq \sum\limits_{i=1}^{n}w_i \leq 100000\)

题解

非常妙的第一步。

现在这个分母会改变,很不好算。所以把题目进行一个转换:假设死了的猎人尸体仍然在那里,然后在打枪的时候如果达到一个尸体就再打一次。那么这样对于任何一个猎人,他的死亡概率仍然是一样的。做了这个转换后,一个人被打到的概率就很固定了。

然后直接算不好算,考虑容斥。定义 \(f(S)\) 为在 1 前面被打到的集合恰好为 \(S\) 的概率,\(g(S)\) 为在 1 前面死的集合包含于 \(S\) 的概率。然后 \(g(S)\) 是可算的。枚举在第 \(i+1\) 次打到 1,那么在之前的打枪过程中一定是打到 \(S\) 中的猎人,那么概率 为 \((\frac{\sum\limits_{i\in S}w_i}{\sum\limits_{i=1}^nw_i})^i\),枚举的次数可以到无限,所以等比数列求和。最后乘上一个选择 1 的概率,也就是 \(\frac {w_1}{\sum\limits_{i=1}^nw_i}\)

计算 \(f(S)=\sum\limits_{T\in S}g(T)(-1)^{|T|-|S|}\)即可。

发现式子中之和集合中 \(w\) 的和有关,所以可以用背包跑出所有 \(w\) 的和,但是这样会超时,将背包改成分治 FFT 就好了

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5,P=998244353;
int n,w[N],ret,ans,rr[N];
vector<int>g[N];
int pown(int x,int y)
{
	if(!y)
		return 1;
	int t=pown(x,y>>1);
	if(y&1)
		return 1LL*t*t%P*x%P;
	return 1LL*t*t%P;
}
int read()
{
	int s=0;
	char ch=getchar();
	while(ch<'0'||ch>'9')
		ch=getchar();
	while(ch>='0'&&ch<='9')
		s=s*10+ch-48,ch=getchar();
	return s;
}
void ntt(vector<int>&a,int op)
{
	for(int i=0;i<a.size();i++)
		if(rr[i]<i)
			swap(a[i],a[rr[i]]);
	for(int md=1;md<a.size();md<<=1)
	{
		int g=pown(op? 3:332748118,(P-1)/(md<<1));
		for(int i=0;i<a.size();i+=md<<1)
		{
			int pw=1;
			for(int j=0;j<md;j++,pw=1LL*pw*g%P)
			{
				int k=a[i+j+md]*1LL*pw%P;
				a[i+j+md]=(a[i+j]-k+P)%P;
				(a[i+j]+=k)%=P;
			}
		}
	}
	if(!op)
	{
		int pw=pown(a.size(),P-2);
		for(int i=0;i<a.size();i++)
			a[i]=1LL*a[i]*pw%P;
	}
}
vector<int> merge(int l,int r)
{
	if(l==r)
		return g[l];
	int md=l+r>>1;
	vector<int>a=merge(l,md),b=merge(md+1,r);
	int s=1<<(int)log2(a.size()+b.size()-2)+1;
	for(int i=1;i<s;i++)
		rr[i]=rr[i>>1]>>1|(i&1)*s/2;
	a.resize(s);
	b.resize(s);
	ntt(a,1);
	ntt(b,1);
	for(int i=0;i<s;i++)
		a[i]=1LL*a[i]*b[i]%P;
	ntt(a,0);
	return a;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
	{
		w[i]=read();
		g[i].push_back(1);
		for(int j=1;j<w[i];j++)
			g[i].push_back(0);
		g[i].push_back(P-1);
		ret+=w[i];
	}
	vector<int>a=merge(2,n);
	for(int i=1;i<ret;i++)
		(ans+=1LL*i*pown(ret-i,P-2)%P*a[i]%P)%=P;
	if(n%2==0)
		ans=1LL*ans*(P-1)%P;
	printf("%lld\n",ans*1LL*w[1]%P*pown(ret,P-2)%P);
}