Berlekamp_Massey与常系数齐次线性递推

发布时间 2023-06-11 22:41:12作者: ruizhangj

Berlekamp_Massey

BM线性递推,看了这篇博客才学会的:[link]((29条消息) [学习笔记]Berlekamp-Massey算法_cz_xuyixuan的博客-CSDN博客) 。这里稍微记录一下做法,符号和下表尽量精细。

假设有序列 \(\{a_1,a_2,\dots,a_n\}\) ,称序列 \(\{r_1,r_2,\dots ,r_m\}\) 为线性递推式,当且仅当 \(\forall m+1\leq i\leq n,a_i=\sum_{j=1}^{m}r_ja_{i-j}\)

BM算法就是要找到一个最短的线性递推式。

考虑使用增量法,假设目前已经构造了 \(\{a_1,a_2,\dots ,a_{i-1}\}\) 的线性递推式,然后加入 \(a_i\) 构造新的线性递推式。由于线性递推式会修改若干次,那么把每一次修改得到的线性递推式记录下来,记作 \(\{R_0,\dots ,R_{c}\}\) ,其中 \(R_{c}\) 就是最新的线性递推式,初值是 \(c=0,R_0=\{\}\) ,记 \(m_i\) 表示 \(R_i\) 的长度。

具体的构造方案如下:

\(\Delta_{c,i}=a_i-\sum_{j=1}^{m_c}R_{c,j}a_{i-j}\)

首先检验 \(R_{c}\)\(a_i\) 是否满足线性递推,如果 \(\Delta_{c,i}=0\) 那么说明满足,不需要进行修改。

如果 \(\Delta_{c,i}\ne 0\) ,说明需要修正了,如果 \(c=0\) 说明 \(a_i\) 是第一个非零数,那么直接令 \(c=1,R_1=\{0,\dots,0\}\) ,其中 \(0\) 的个数为 \(i\) ,这样 \(m_1\geq i\) ,不需要在任何位置判断。

否则记录错位位置 \(p_c=i\) ,表示线性递推式 \(R_c\) 在位置 \(p_c\) 第一次出错了,然后尝试使用构造一个修正用的线性递推式 \(r'\) ,使其满足 \(\forall m'+1\leq k<i,\sum_{j=1}^{m'}r'_ja_{k-j}=0\) 并且 \(\sum_{j=1}^{m'}r'_ja_{i-j}=\Delta_{c,i}\) ,就是说把修正线性递推式加上之后不影响前面的递推值,并且能把 \(\Delta_{c,i}\) 补回来。

尝试使用 \(R_{c-1}\) 来构造 \(r'\) ,可以证明是必定有解的。

\(v=\frac{\Delta_{c,i}}{\Delta_{c-1,p_{c-1}}}\) ,构造 \(r'=\{0,\dots,0,v,-vR_{c-1,1},\dots ,-cR_{c,m_c}\}\) ,其中 \(0\) 的个数为 \(i-p_{c-1}-1\) ,后面的部分是 \(v\{1,-R_{c-1}\}\)

简单计算一下这个 \(r'\) 在各个位置上的贡献。对于 \(\Delta_i\) 来说,前面的 \(i-p_{c-1}-1\)\(0\)\(a_{p_{c-1}+1,\dots i-1}\) 的贡献给抹掉了,后面其实是 \(v\Delta_{c-1,p_{c-1}}\) ,那么就是说 \(\Delta_i=v\Delta_{c-1,p_{c-1}}=\Delta_{c,i}\) ,正好就补上了。再考虑另外一个限制,由于 \(r'\) 中有 \(i-p_{c-1}-1\) 个前导 \(0\) ,并且 \(i\) 已经判断过了,所以实际上要判断的是 \(v\{1,-R_{c-1}\}\) 能否在 \(k\leq p_{c-1}-1\) 的位置上满足 \(=0\) 的要求,这个显然是满足的,因为 \(k<p_{c-1}\) 的位置都可以用 \(R_{c-1}\) 来递推。

所以就构造出了一个可行的 \(r'\) ,再令 \(R_{c+1}=r'+R_c\) ,其中 \(+\) 是对位加,就构造出了能够线性递推 \(\{a_1,\dots,a_i\}\) 的线性递推式 \(R_{c+1}\)

想要最短可以观察 \(r'\) 的长度和什么有关,可以发现 \(m'=i-p_t-1+1+m_t=i-p_t+m_t\) ,那么只需要取 \(m'\) 最小的 \(t\) 即可,并不需要强制 \(t=c-1\) ,只需要 \(t>0\) 即可。

常系数齐次线性递推

思路是把每个 \(a_n\) 都用 \(\{a_1,\dots,a_k\}\) 来表示,即把 \(\{a_1,\dots ,a_k\}\) 看作一个基底,根据上面 BM 的知识,可以知道 \(k=m_c\)

常系数齐次线性递推需要用到这样的一个性质:若 \(a_n=\sum_{i=1}^{k}p_ia_i\) ,那么 \(a_{n+m}=\sum_{i=1}^{k}p_ia_{i+m}\) ,这个是比较显然的,证明直接换元即可。

那么考虑如果已经知道了 \(a_n=\sum_{i=1}^{k}p_ia_i,a_m=\sum_{i=1}^{k}q_ia_i\) ,通过这些信息计算 \(a_{n+m}\)

具体如下:

\[a_{n+m}=\sum_{i=1}^{k}p_ia_{i+m}=\sum_{i=1}^{k}p_i\sum_{j=1}^{k}q_ja_{i+j} \]

先用一次卷积计算出 \(a_{1,\dots ,2k}\) 的系数,然后用线性递推式把 \(a_{k+1,\dots ,2k}\) 的系数还原到 \(a_{1,\dots ,k}\) 上即可。

知道这个之后直接倍增做就好了,若要求 \(a_w\) 时间复杂度为卷积的复杂度乘 \(\log w\) ,即 \(O(k^2\log w)\)\(O(k\log k \log w)\)


洛谷上的版题的代码如下:

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
inline void add(int &x,int y){if ((x+=y)>=mod) x-=mod;}
inline void sub(int &x,int y){if ((x-=y)<0) x+=mod;}
inline int Mod(int x){return x>=mod?x-mod:x;}
inline int ksm(int x,int y){int res=1;for (;y;y>>=1,x=1ll*x*x%mod) if (y&1) res=1ll*res*x%mod;return res;}
int R[10005][5005],m[10005],p[10005],r[5005],d[10005],c=0;
int a[10005];
int n,w;
int ans[5005],base[5005],t[10005];
void mul(int *a,int *b,int *c,int k){
	for (int i=1;i<=2*k;++i) t[i]=0;
	for (int i=1;i<=k;++i)
		for (int j=1;j<=k;++j)
			add(t[i+j],1ll*a[i]*b[j]%mod);
	for (int i=2*k;i>k;--i)
		for (int j=1;j<=k;++j)
			add(t[i-j],1ll*t[i]*r[j]%mod);
	for (int i=1;i<=k;++i) c[i]=t[i];
}
void ksm(int *ans,int *base,int y,int k){
	ans[1]=base[1]=1;
	while (y){
		if (y&1) mul(ans,base,ans,k);
		mul(base,base,base,k);
		y>>=1;
	}
}
int main(){
	scanf("%d%d",&n,&w);
	for (int i=1;i<=n;++i) scanf("%d",&a[i]);
	for (int i=1;i<=n;++i){
		int delta=a[i];
		for (int j=1;j<=m[c];++j) sub(delta,1ll*R[c][j]*a[i-j]%mod);
		if (!delta) continue;
		p[c]=i,d[c]=delta;
		if (!c){
			m[1]=i;
			for (int j=1;j<=i;++j) R[1][j]=0;
			++c;
		}
		else{
			int t=c-1;
			for (int j=c-2;j>=1;--j)
				if (-p[j]+m[j]<-p[t]+m[t]) t=j;
			int v=1ll*delta*ksm(d[t],mod-2)%mod;
			for (int j=1;j<=i-p[t]-1;++j) r[j]=0;
			r[i-p[t]]=v;
			for (int j=1;j<=m[t];++j) r[i-p[t]+j]=Mod(mod-1ll*R[t][j]*v%mod);
			m[c+1]=max(i-p[t]+m[t],m[c]);
			for (int j=1;j<=m[c+1];++j) R[c+1][j]=Mod(R[c][j]+r[j]);
			++c;
		}
	}
	for (int i=1;i<=m[c];++i) printf("%d ",R[c][i]);puts("");
	for (int i=1;i<=m[c];++i) r[i]=R[c][i];
	ksm(ans,base,w,m[c]);
	int sum=0;
	for (int i=1;i<=m[c];++i) add(sum,1ll*ans[i]*a[i]%mod);
	printf("%d\n",sum);
	return 0;
}