Codeforces 1909I - Short Permutation Problem

发布时间 2024-01-02 21:48:36作者: tzc_wk

介绍一下 k 老师教我的容斥做法。

考虑固定 \(m\) 对所有 \(k\) 求答案。先考虑 \(k=n-1\) 怎么做。我们将所有元素按照 \(\min(i,m-i)\) 为第一关键字,\(-i\) 为第二关键字从小到大插入,即按照 \(n,n-1,n-2,\cdots,m+1,m,1,m-1,2,\cdots\) 这样的顺序插入所有元素。那么你发现,最后得到的排列每对相邻元素之和都 \(\ge m\) 的充要条件是,每次插入一个元素时,它两端的元素都必须 \(\ge\dfrac{m}{2}\)。这是因为,根据我们插入的顺序可以知道每次插入产生的相邻元素之和都 \(\ge m\),而如果某一次插入出现了两个数 \(a<b\) 满足 \(a+b<m\),那么你如果后面这两个数中间没插元素那么直接 GG,否则你后面插的元素 \(c\) 必然满足 \(a<c<b\),此时 \(a+c<m\),到后面肯定必然会出现寄了的位置。这样的话我们动态维护一个 \(t\) 表示现在能插的位置个数,碰到 \(\ge\dfrac{m}{2}\) 的就加一,碰到 \(<\dfrac{m}{2}\) 的就减一,然后把所有时刻的 \(t\) 乘起来就行了。因为这个插入顺序的特殊性,预处理阶乘以后可以 \(O(1)\) 计算。

接下来考虑 \(k\ne n-1\) 的情况,考虑容斥。钦定 \(k\) 个相邻元素之和 \(\ge m\) 的位置,计算方案数以后二项式反演一遍即可得到原方案数。这等价于构建 \(k\) 条相邻元素之和均 \(\ge m\) 的链的方案数。但是你直接这么考虑的话就不像 \(k=n-1\) 时候能直接把一堆东西乘起来了。这时候我们做个微调:在每条链两端都加个 \(+\infty\),这样每次还是碰到 \(\ge\dfrac{m}{2}\) 的就加一,碰到 \(<\dfrac{m}{2}\) 的就减一,只不过初值变成了 \(k\)。但是这样又有一个问题:可能出现空的链。这时候我们再进行一次容斥,钦定若干条链为空,这样就能算得钦定 \(k\) 条相邻元素之和均 \(\ge m\) 的链的方案数。朴素地做是 \(O(n^3)\) 的,但是两部分容斥都可以用 NTT 优化到 \(O(n\log n)\),所以总复杂度 \(O(n^2\log n)\)

const int MOD=998244353;
const int MOD1=1e9+7;
const int MAXN=8000;
const int MAXP=8192;
const int pr=3;
const int ipr=332748118;
int qpow(int x,int e,int mod){int ret=1;for(;e;e>>=1,x=1ll*x*x%mod)if(e&1)ret=1ll*ret*x%mod;return ret;}
int n,x,fac[MAXN+5],ifac[MAXN+5],res;
void init_fac(int n){
	for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++)ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=1;i<=n;i++)ifac[i]=1ll*ifac[i-1]*ifac[i]%MOD,fac[i]=1ll*fac[i-1]*i%MOD;
}
int binom(int n,int k){return 1ll*fac[n]*ifac[k]%MOD*ifac[n-k]%MOD;}
int rev[MAXP+5];
void NTT(vector<int>&a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i,MOD);
		static int w[MAXP+5];
		for(int j=(w[0]=1);j<=(i>>1);j++)w[j]=1ll*w[j-1]*W%MOD;
		for(int j=0;j<len;j+=i){
			for(int k=0;k<(i>>1);k++){
				int X=a[j+k],Y=1ll*w[k]*a[(i>>1)+j+k]%MOD;
				a[j+k]=(X+Y>=MOD)?(X+Y-MOD):(X+Y);
				a[(i>>1)+j+k]=(X<Y)?(X-Y+MOD):(X-Y);
			}
		}
	}
	if(type==-1){
		int iv=qpow(len,MOD-2,MOD);
		for(int i=0;i<len;i++)a[i]=1ll*a[i]*iv%MOD;
	}
}
vector<int>conv(vector<int>a,vector<int>b){
	int LEN=1;while(LEN<a.size()+b.size())LEN<<=1;
	a.resize(LEN,0);b.resize(LEN,0);NTT(a,LEN,1);NTT(b,LEN,1);
	for(int i=0;i<LEN;i++)a[i]=1ll*a[i]*b[i]%MOD;
	NTT(a,LEN,-1);return a;
}
int main(){
	scanf("%d%d",&n,&x);init_fac(MAXN);
	for(int m=3;m<=n+1;m++){
		static int f[MAXN+5],g[MAXN+5],h[MAXN+5];
		memset(f,0,sizeof(f));memset(g,0,sizeof(g));
		for(int c=1;c<=n;c++){
			int prd=1ll*fac[n-m+c]*ifac[c-1]%MOD*
			qpow(1ll*(n-m+1+c)*(n-m+2+c)%MOD,(m-1)/2,MOD)%MOD;
			if(m%2==0)prd=1ll*prd*(n-m+1+c)%MOD;
			h[c]=prd;
		}
		vector<int>A(n+1),B(n+1),C;
		for(int i=0;i<=n;i++){
			A[i]=1ll*h[i]*ifac[i]%MOD;
			B[i]=1ll*((i&1)?(MOD-1):1)*ifac[i]%MOD;
		}C=conv(A,B);
		for(int k=0;k<n;k++)f[k]=1ll*fac[n-k]*C[n-k]%MOD;
		for(int i=0;i<=n;i++)A[n-i]=1ll*f[i]*fac[i]%MOD;
		C=conv(A,B);
		for(int k=0;k<n;k++)g[k]=1ll*C[n-k]*ifac[k]%MOD;
		for(int k=0;k<n;k++)res=(res+1ll*g[k]*qpow(x,n*m+k,MOD1))%MOD1;
	}printf("%d\n",res);
	return 0;
}