洛谷 P8923 -『MdOI R5』Many Minimizations

发布时间 2023-07-18 19:26:24作者: tzc_wk

怎么 ARC 还能撞题的?只能说 Kubic 牛逼。

首先显然没法保序回归。考虑用类似于凸壳优化 DP 的做法解决原问题(也就是 P4331):

  • \(dp_{i,j}\) 表示考虑前 \(i\) 位,\(x_i=j\) 的最小代价,显然有 \(dp_{i,j}=\min_{k\le j}\{dp_{i-1,k}+|j-a_i|\}\)
  • \(dp\) 值显然是一个折线,用堆维护斜率改变的拐点,那么每次加入一个元素相当于:
    • 加入 \(a_i\),取出最大元素 \(x\),答案加上 \(x-a_i\),再加入 \(a_i\)

考虑转化贡献体,将“答案加上 \(x-a_i\)”,改为”数有多少 \(x\le t<a_i\),每出现一个这样的 \(t\),答案加一“。考虑这个枚举 \(t\),将 \(\le t\) 的看作 \(0\)\(>t\) 的看作 \(1\),那么我们发现堆中的数只有 \(01\) 之分,而答案会加以当且仅当堆中存在 \(1\) 且加入的 \(a_i\)\(0\)

这样有一个暴力 DP 是 \(dp_{i,j}\) 表示考虑前 \(i\) 个元素,目前堆中有 \(j\)\(1\) 有多少种方案数,这样可以做到 \(O(n^2m)\),但是还是不足以通过此题。

考虑优化,根据 DP 的过程可知,一个 \(t\) 对应的贡献之和其实是一个关于 \(t\)\(n+1\) 次多项式,于是不难想到求出这个多项式的系数,然后插值求出答案。那么怎么求出这个多项式的系数呢?重新审视这个 DP 过程,发现类似于网格图路径计数:每次可以向右上走一格或者向右下走一格,如果到 \(y=0\) 以下就回到 \(y=0\),向右上走有 \(m-t\) 的系数,向右下走有 \(t\) 的系数。然后如果向下走并且当前不在 \(y=0\) 就会产生 \(1\) 的答案。

怎么处理这个问题呢?如果我们强制要求不能走到 \(y=0\) 以下,那么问题是经典的反射容斥。但是这里的问题是走到 \(y=0\) 以下不会对答案产生贡献,那么我们很自然地想到那总的向下走的次数减去走到 \(y=0\) 时向下走的次数。后者可以对答案差分贡献:容易证明,一条在 \(y=0\) 时候向下走的次数 \(\ge c\) 的路径,与一条从 \((0,c)\) 出发,每次向右上或右下走,到达 \(x=n\) 且经过了 \(y=0\) 的路径一一对应,而后者也是经典的反射容斥,于是统计下 \(coef_p\) 表示答案多项式里 \(t^p(m-t)^{n-p}\) 的系数即可。

时间复杂度 \(O(n^2)\)

const int MAXN=5000;
int n,m,mod,c[MAXN+5][MAXN+5],coef[MAXN+5],v[MAXN+5];
int qpow(int x,int e){
	int ret=1;
	for(;e;e>>=1,x=1ll*x*x%mod)if(e&1)ret=1ll*ret*x%mod;
	return ret;
}
int calc(int x){return (c[n][n+x>>1]-c[n][n+x+2>>1]+mod)%mod;}
int main(){
	scanf("%d%d%d",&n,&m,&mod);
	for(int i=0;i<=n;i++){
		c[i][0]=1;
		for(int j=1;j<=i;j++)c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
	}
	for(int s=0;s<=n;s++)if((n-s)%2==0){
		int v=1ll*calc(s)*((n-s)>>1)%mod;
		coef[(n-s)/2]=(coef[(n-s)/2]+v)%mod;
		coef[min(s+(n-s)/2+1,n+1)]=(coef[min(s+(n-s)/2+1,n+1)]-v+mod)%mod;
	}
	for(int s=1;s<=n+1;s++)coef[s]=(coef[s]+coef[s-1])%mod;
	int ss=0,res=0;
	for(int i=1;i<=n+2;i++){
		int cpw=qpow(m-i,n),stp=1ll*i*qpow(m-i,mod-2)%mod;
		for(int j=0;j<=n;j++)ss=(ss+1ll*coef[j]*cpw)%mod,cpw=1ll*cpw*stp%mod;
		v[i]=ss;
	}
	for(int i=1;i<=n+2;i++){
		int up=1,dw=1;
		for(int j=1;j<=n+2;j++)if(i!=j)
			up=1ll*up*(m-j+mod)%mod,dw=1ll*dw*(i-j+mod)%mod;
		res=(res+1ll*up*qpow(dw,mod-2)%mod*v[i])%mod;
	}printf("%d\n",res);
	return 0;
}