Atcoder Regular Contest 118 F - Growth Rate

发布时间 2023-07-14 21:20:37作者: tzc_wk

想到插值其实就挺套路的了吧……

\(dp_{i,j}\) 表示有多少种方法确定 \(a_i\sim a_n\) 使得 \(a_i=j\)。那么有 \(dp_{i,j}=\sum\limits_{k\ge ja_i}dp_{i+1,k}\)。边界条件是 \(dp_{n+1,1\sim m}=1\)。不难发现复杂度与值域有关,一眼过不去的样子。

考虑优化,记 \(lim_i\) 满足 \(lim_{n+1}=m\)\(lim_i=\lfloor\dfrac{lim_{i+1}}{a_i}\rfloor\)。那么显然 \(dp_i\) 只有前 \(lim_i\) 位是有用的,而如果我们设 \(sdp_{i,j}=\sum\limits_{k=1}^jdp_{i,k}\),那么 \(dp_{i,j}=sdp_{i+1,lim_{i+1}}-sdp_{i+1,ja_i-1}(j\le lim_i)\)。可以递归证明 \(dp_{i,j}\) 是关于 \(j\)\(n-i+1\) 次多项式,\(sdp_{i,j}\) 是关于 \(j\)\(n-i+2\) 次多项式,再结合 \(n\) 比较小不难想到插值。如果直接用“存系数”的方法维护一个多项式显然实现起来比较复杂,比较简单的方法是直接维护前 \(n-i+1\) 项的点值,由于插值可以线性,所以从 \(dp_{i+1}\) 推到 \(dp_i\) 只需要带入 \(n-i+1\) 个点值,这样复杂度是 \(O(n^3)\) 的。很 trivial 的优化的 \(a\) 中非一元素个数是 \(O(\log V)\) 的,因此对每个 \(a_i\ne 1\) 才需要暴力跑 \(O(n)\) 次插值,对于 \(a_i=1\) 只需求出 \(sdp_{i+1,lim_{i+1}}\),后面 \(sdp_{i+1,j-1}\) 查表即可。

时间复杂度 \(O(n^2\log V)\)

const int MAXN=1005;
const int MOD=998244353;
int n,a[MAXN+5],fac[MAXN+5],ifac[MAXN+5];ll M,lim[MAXN+5];
void init_fac(int n){
	for(int i=(fac[0]=ifac[0]=ifac[1]=1)+1;i<=n;i++)ifac[i]=1ll*ifac[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=1ll*ifac[i-1]*ifac[i]%MOD;
}
vector<int>f[MAXN+5],sf[MAXN+5];
int calc(int p,ll v){
	v%=MOD;int c=n+2-p,res=0;if(v<=c)return sf[p][v];
	static int pre[MAXN+5],suf[MAXN+5];pre[0]=v;suf[c+1]=1;
	for(int i=1;i<=c;i++)pre[i]=1ll*pre[i-1]*(v-i+MOD)%MOD;
	for(int i=c;~i;i--)suf[i]=1ll*suf[i+1]*(v-i+MOD)%MOD;
	for(int i=0;i<=c;i++){
		int prd=1ll*ifac[i]*ifac[c-i]%MOD*sf[p][i]%MOD;
		if(i!=0)prd=1ll*prd*pre[i-1]%MOD;prd=1ll*prd*suf[i+1]%MOD;
		if((c-i)&1)prd=MOD-prd;res=(res+prd)%MOD;
	}return res;
}
int main(){
	scanf("%d%lld",&n,&M);for(int i=1;i<=n;i++)scanf("%d",&a[i]);init_fac(MAXN);
	for(int i=1;i<=n+1;i++)f[i].resize(n+3-i),sf[i].resize(n+3-i);
	lim[n+1]=M;for(int i=n;i;i--)lim[i]=lim[i+1]/a[i];
	f[n+1][1]=sf[n+1][1]=1;
	for(int i=n;i;i--){
		int mx=n+2-i,val_lim=calc(i+1,lim[i+1]);
		for(int j=1;j<=mx;j++)f[i][j]=(val_lim-calc(i+1,(1ll*j*a[i]-1+MOD)%MOD)+MOD)%MOD;
		for(int j=1;j<=mx;j++)sf[i][j]=(sf[i][j-1]+f[i][j])%MOD;
	}printf("%d\n",calc(1,lim[1]));
	return 0;
}