《拉格朗日插值》小记

发布时间 2023-10-01 17:02:19作者: daduoli

随便学学,主要是又被卡科技了。

参考文章:

\(Alex\_Wei\) 的拉格朗日插值与多项式乘法

\(Alex\_Wei\) 的多项式 I:拉格朗日插值与快速傅里叶变换

\(yyc\) 的从拉插到快速插值求值

算法介绍

公式口糊

主要用来对于一个给定的 \(n\) 次多项式,用 \(n+1\) 个点值在 \(O(n^2)\) 的时间复杂度内求出多项式的各项系数。这就是插值。

其实这东西和中国剩余定理的思路是很像的,就是对于一个点值而言当 \(k=x_i\) 时,只有 \(y_i\) 是显出来的。其他都是不显的。

所以我们的式子中得有 \(k-x_i\) 的因数。然后我们还要再乘上一个 \(y_i\)

那么我们式子就变成了 \(f(k)=\sum\limits_{i=1}^n y_i\prod\limits_{i\not =j} {k-x_j}\)

但是我们发现这样仍然不对,我们还要再除以一个 \(x_i-x_j\) 才对,所以最后公式长这样。

\(f(k)=\sum\limits_{i=1}^n y_i\prod\limits_{i\not =j} \frac {k-x_j}{x_i-x_j}\)

拉插的公式还是比较好理解的。

实际上拉插的思路就是构造出一个等价于原函数的函数,以此求值。

如何求解原函数的各项系数呢。

我们可以记 \(D(k)=\prod\limits_{i=1}^n k-x_i,c_i=y_i\prod\limits_{i\not =j}^n \frac 1{x_i-x_j}\)\(D(k)\) 是一个多项式而非点值)

那么原式变成 \(\sum\limits_{i=1}^n y_i\frac {D(k)}{k-x_i}\)

我们记 \(A_i(k)\)\(\frac {D(k)}{k-x_i}\) (这是一个多项式,而不是一个点值)

那么我们记答案函数为 \(g\)

\(g=c_1A_i(k)+c_2A_2(k)+...+c_nA_n(k)\)

首先对于 \(D(k)\) 我们可以在 \(O(n^2)\) 的时间复杂度内预处理出来(注意 \(k\) 是什么不影响,因为我们求的是系数,而非真正的值,所以可以看成 \(x\) ,就是说没有具体的值)

而对于每个 \(A_i(k)\) 可以在 \(O(n)\) 的时间内用 \(D(k)\) 除以一次多项式的方法得到,具体实现和反背包类似。

然后把所有系数加起来就得到了我们 \(g\) 的答案。

时间复杂度 \(O(n^2)\)

点击查看代码
#include<bits/stdc++.h>
typedef long long LL;

using namespace std;
const int MAXN=5e3+10,MODD=998244353;
int n;
LL x[MAXN],y[MAXN],d[MAXN],D[MAXN],g[MAXN],f[MAXN];
LL ksm(LL x,LL y) {
	LL ret=1;
	while(y) {
		if(y&1) ret=ret*x%MODD;
		x=x*x%MODD;
		y>>=1;
	}
	return ret;
}
int main () {
	scanf("%d",&n);
	for(int i=1;i<=n;++i) {
		scanf("%d%d",&x[i],&y[i]);
	}
	for(int i=1;i<=n;++i) {
		d[i]=1;
		for(int j=1;j<=n;++j) {
			if(i==j) continue;
			d[i]=(d[i]*(x[i]-x[j]+MODD)%MODD)%MODD;
		}
		d[i]=ksm(d[i],MODD-2)*y[i]%MODD;
	}
	D[0]=1;
	for(int i=1;i<=n;++i) {
		for(int j=n-1;j>=1;--j) {
			D[j]=(D[j-1]+D[j]*(MODD-x[i])%MODD)%MODD;
		}
		D[0]=D[0]*(MODD-x[i])%MODD;
	}
	for(int i=1;i<=n;++i) {
		LL cs=ksm(MODD-x[i],MODD-2);
		g[0]=D[0]*cs%MODD;
		for(int j=1;j<n;++j) {
			g[j]=(D[j]-g[j-1]+MODD)*cs%MODD;
		}
		for(int j=0;j<n;++j) {
			f[j]=(f[j]+d[i]*g[j]%MODD)%MODD;
		}
	}
	for(int i=0;i<n;++i) {
		printf("%lld ",f[i]);
	}
	return 0;
}

对于点值取值连续的处理

\(pre_i=\prod\limits_{j=1}^i k-i,suf_i=\prod\limits_{j=i}^n k-i\)

\[\begin{aligned} f(k)&=\sum\limits_{i=1}^n y_i\prod\limits_{i\not =j} \frac {k-x_j}{x_i-x_j}=\sum\limits_{i=1}^n y_i\prod\limits_{j=1}^{i-1} \frac {(k-j)}{i-j} (-1)^{n-i} \prod\limits_{j=i+1}^{n}\frac {(k-j)}{j-i} \\ &=\sum\limits_{i=1}^n y_i\frac {pre_{i-1}}{(i-1)!}\frac {suf_{i+1}}{(n-i)!}(-1)^{n-i} \end{aligned} \]

上面的东西可以快速 \(O(n)\) 预处理出来。

所以对于点值连续的我们可以做到 \(O(n)\) 完成插值。

拉格朗日插值求解 \(dp\)

有时候在做一些 \(dp\) 题的时候,我们可以得到一个 \(dp_{i,j}\) 的转移但是第一维 \(j\) 很大,或第一维 \(i\) 很大,而我们通过证明答案是一个次数与另一维很小的变量有关的一个多项式,那么我们可以带若干个点值进去,然后对于无法求解的答案,通过插值获取。

The Sum of the k-th Powers

首先答案是一个 \(k+1\) 次多项式,虽然我不会证明。

然后我们求出前 \(k+1\) 个的点值,然后由于 \(k\) 比较大,所以考虑点值取值连续,然后 \(O(n)\) 插值即可。

时间复杂度关于 \(k\) 的先行对数,瓶颈在于求 \(k+1\) 个点值,用线性筛只筛质数,可以做到线性复杂度。

点击查看代码
#include<bits/stdc++.h>
typedef long long LL;

using namespace std;
const int MAXN=1e6+10,MODD=1e9+7;
LL n,k;
LL s[MAXN],pre[MAXN],suf[MAXN],pw[MAXN];
LL ksm(LL x,LL y) {
	LL ret=1;
	while(y) {
		if(y&1) ret=ret*x%MODD;
		x=x*x%MODD;
		y>>=1;
	}
	return ret;
}
int main () {
	scanf("%lld%lld",&n,&k);
	for(int i=1;i<=k+2;++i) {
		s[i]=(s[i-1]+ksm(i,k))%MODD;
	}
	pre[0]=1;
	for(int i=1;i<=k+2;++i) {
		pre[i]=pre[i-1]*(n-i)%MODD;
	}
	suf[k+3]=1;
	for(int i=k+2;i>=1;--i) {
		suf[i]=suf[i+1]*(n-i)%MODD;
	}
	pw[0]=1;
	for(int i=1;i<=k+2;++i) pw[i]=pw[i-1]*i%MODD;
	LL ans=0;
	for(int i=1;i<=k+2;++i) {
		ans=(ans+((k+2-i)&1?-1:1)*s[i]*pre[i-1]%MODD*suf[i+1]%MODD*ksm(pw[i-1],MODD-2)%MODD*ksm(pw[k+2-i],MODD-2)%MODD+MODD)%MODD;
	}
	printf("%lld\n",ans);
	return 0; 
}

P4463 [集训队互测 2012] calc

首先直接暴力 \(O(nk)\) 是很简单的,拆贡献,记 \(f_{i,j}\)\(i\) 选了 \(j\) 个数的所有方案各自乘积的和,然后一起乘上 \(i\) 即可。

最后再乘一个 \(n!\) 就可以得到答案。

考虑如何进一步求解。

由于 \(i\) 很大,我们可以猜测 \(f_{i,j}\) 是关于 \(i\)\(kj+b\) 次多项式( \(k,b\) 是常数)

\(f_{i,j}=f_{i-1,j}+f_{i-1,j-1}\times i\)

\(\color{red}\text{一个经典套路,差分证明次数}\)

\(f_{i,j}-f_{i-1}{j}=i\times f_{i-1,j-1}\)

我们记 \(g_{i,j}\) 表示 \(f_{i,j}\) 是关于 \(i\) 的一个 \(g_{i,j}\) 次多项式。

那么根据多项式差分有左边式子的次数是 \(g_{i,j}-1\) 。然后右边式子的次数显然是 \(g_{i-1,j-1}+1\)

所以 \(g_{i,j}-1=g_{i-1,j-1}+1\)

而当 \(j=0\)\(g_{i,0}=0\) ,所以 \(g_{i,j}\) 是一个 \(2j\) 次多项式。

然后带 \(2n+1\) 个点值进去然后插一下就好了。

时间复杂度 \(O(n^2)\)

点击查看代码
#include<bits/stdc++.h>
typedef long long LL;

using namespace std;
const int MAXN=1010;
LL k,n,P;
LL f[MAXN][MAXN];
LL ksm(LL x,LL y) {
	LL ret=1;
	while(y) {
		if(y&1) ret=ret*x%P;
		x=x*x%P;
		y>>=1;
	}
	return ret;
}
int main () {
	scanf("%lld%lld%lld",&k,&n,&P);
	f[0][0]=1;
	for(int i=1;i<=2*n+1;++i) {
		for(int j=0;j<=2*n+1;++j) {
			f[i][j]=f[i-1][j];
			if(j) f[i][j]=(f[i][j]+f[i-1][j-1]*i)%P;
		}
	}
	LL ans=0;
	for(int i=1;i<=2*n+1;++i) {
		LL fz=f[i][n],fm=1;
		for(int j=1;j<=2*n+1;++j) {
			if(i==j) continue;
			fz=fz*(k-j)%P; fz=(fz+P)%P;
			fm=(fm*(i-j)%P+P)%P;
		}
		ans=(ans+fz*ksm(fm,P-2)%P)%P;
	}
	for(int i=1;i<=n;++i) ans=(ans*i)%P;
	printf("%lld\n",ans);
	return 0; 
}

Cowmpany Cowmpensation

证明一下答案是 \(n\) 次多项式即可。

然后简单地树形 \(dp\)

时间复杂度 \(O(n^2)\)

点击查看代码
#include<bits/stdc++.h>
typedef long long LL;

using namespace std;
const int MAXN=3010,P=1e9+7;
int n,D;
vector<int> e[MAXN];
void adline(int f,int t) {
	e[f].push_back(t);
}
LL add(LL x,LL y) {
	return (x+y>=P?x+y-P:x+y);
}
LL f[MAXN][MAXN],g[MAXN][MAXN],sz[MAXN],cz[MAXN];
void dfs(int u) {
	for(int i=1;i<=n+1;++i) {
		f[u][i]=1;
	}
	sz[u]=1;
	for(auto t:e[u]) {
		dfs(t); sz[u]+=sz[t];
		for(int j=1;j<=n+1;++j) {
			f[u][j]=f[u][j]*g[t][j]%P;
		}
	}
	for(int i=1;i<=n+1;++i) {
		g[u][i]=add(g[u][i-1],f[u][i]);
	}
}
LL ksm(LL x,LL y) {
	LL ret=1;
	while(y) {
		if(y&1) ret=ret*x%P;
		x=x*x%P;
		y>>=1;
	}
	return ret;
}
void lglrcz(LL *a,LL n,LL k) {
	LL ans=0;
	for(int i=1;i<=n;++i) {
		LL fz=a[i],fm=1;
		for(int j=1;j<=n;++j) {
			if(i==j) continue;
			fz=(fz*(k-j))%P;
			fm=(fm*(i-j))%P;
		}
		fz=(fz+P)%P; fm=(fm+P)%P;
		ans=(ans+fz*ksm(fm,P-2)%P)%P;
	}
	printf("%lld\n",ans);
}
int main () {
	scanf("%d%d",&n,&D);
	for(int i=2;i<=n;++i) {
		int x;
		scanf("%d",&x);
		adline(x,i);
	}
	dfs(1);
	for(int i=1;i<=n+1;++i) {
		cz[i]=g[1][i];
	}
	lglrcz(cz,n+1,D);
	return 0; 
}