分治FFT

发布时间 2023-07-13 21:09:22作者: NuclearReactor

考虑一下递推式:

$$
f_{i}=\sum\limits_{j=1}^i g_jf_{i-j}
$$

如果要暴力计算的话复杂度是 $n^2$,考虑到后面的是卷积的形式,但是唯一的问题是 $f$ 自己得出自己。考虑分治。设当前分治区间为 $l,r$,首先分治计算 $l,mid$ 的所有 $f$ 值,然后考虑 $l,mid$ 给 $mid+1,r$ 的 $f$ 值的贡献,具体来说,我们把 $l,mid$ 的 $f$ 和 $0,r-l$ 的 $g$ 卷一下得到给后面的贡献。

考虑这个题也可以求逆来做,设 $F(x)=\sum f_ix^i$,设 $G(x)=\sum g_ix^i$。由于 $f_0=1$,$g_0=0$,考虑有 $F(x)=F(x)G(x)+f_0\Rightarrow F(x)\equiv \frac{f_0}{1-G(x)} \bmod x^n$。求逆即可。

分治 NTT:

int n,g[N],tr[N],F[N],G[N],H[N],f[N];

inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1)res=1ll*a*res%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline int inv(int a){return ksm(a,mod-2,mod);}
inline void Gettr(int n){
    for(int i=0;i<n;i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0);
}
inline void NTT(int *f,int n,int op){
    rep(i,0,n-1) if(i<tr[i]) swap(f[i],f[tr[i]]);
    for(int i=2;i<=n;i<<=1){
        int x=ksm(gg,(mod-1)/i,mod),l=i>>1;if(op==0) x=inv(x);
        for(int j=0;j<n;j+=i){
            int now=1;
            for(int k=j;k<j+l;k++){
                int tt=1ll*now*f[k+l]%mod;
                f[k+l]=(f[k]-tt)%mod;f[k]=(f[k]+tt)%mod;
                now=1ll*now*x%mod;
            }
        }
    }
    if(op==0){
        int invn=inv(n);rep(i,0,n-1) f[i]=1ll*f[i]*invn%mod;
    }
}
inline void Solve(int l,int r){
    if(l==r){if(l==0)f[0]=1;return;}
    int mid=(l+r)>>1,len=r-l+1;
    Solve(l,mid);rep(i,0,len-1) F[i]=G[i]=0;
    rep(i,0,mid-l) F[i]=f[i+l];
    rep(i,0,len-1) G[i]=g[i];
    // printf("l=%d r=%d\n",l,r);
    // rep(i,0,len-1) printf("F[%d]=%d\n",i,F[i]);
    // rep(i,0,len-1) printf("G[%d]=%d\n",i,G[i]);
    int nl=1;while(nl<len*2){nl<<=1;}Gettr(nl);
    rep(i,mid-l+1,nl-1) F[i]=0;rep(i,len,nl-1) G[i]=0;
    NTT(F,nl,1);NTT(G,nl,1);
    // rep(i,0,len-1) printf("F[%d]=%d\n",i,F[i]);
    // rep(i,0,len-1) printf("G[%d]=%d\n",i,G[i]);
    rep(i,0,nl-1) F[i]=1ll*F[i]*G[i]%mod;
    NTT(F,nl,0);rep(i,mid+1-l,r-l) (f[i+l]+=F[i])%=mod;
    // rep(i,mid+1,r-l) printf("F[%d]=%d\n",i,F[i]);
    Solve(mid+1,r);
}

int main(){
    // freopen("my.in","r",stdin);
    // freopen("my.out","w",stdout);
    read(n);rep(i,1,n-1) read(g[i]);
    Solve(0,n-1);
    rep(i,0,n-1) printf("%d ",(f[i]+mod)%mod);
    return 0;
}

求逆:

int tr[N];
inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1) res=1ll*a*res%mod;a=1ll*a*a%mod;b>>=1;}return res;}

inline void Gettr(int n){
    for(int i=0;i<n;i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0);
}
inline void NTT(int *f,int len,int flag){
    for(int i=0;i<len;i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
    for(int p=2;p<=len;p<<=1){
        int md=ksm(g,(mod-1)/p,mod),l=p>>1;
        if(flag==-1) md=ksm(md,mod-2,mod);
        for(int i=0;i<len;i+=p){
            int buf=1;
            for(int j=i;j<i+l;j++){
                int tt=1ll*f[j+l]*buf%mod;
                f[j+l]=((f[j]-tt)%mod+mod)%mod;
                f[j]=(f[j]+tt)%mod;buf=1ll*buf*md%mod;
            }
        }
    }
}

int n,m,c[N],a[N],b[N];

inline void GetInv(int len,int *a,int *b){
    if(len==1){b[0]=ksm(a[0],mod-2,mod);return;}
    GetInv((len+1)>>1,a,b);m=1;while(m<(len<<1)) m<<=1;
    Gettr(m);for(int i=0;i<len;i++) c[i]=a[i];
    for(int i=len;i<m;i++) c[i]=0;
    // printf("m=%d\n",m);
    // for(int i=0;i<m;i++) printf("%d ",c[i]);
    NTT(c,m,1);NTT(b,m,1);
    for(int i=0;i<m;i++) b[i]=1ll*(2-1ll*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
    NTT(b,m,-1);int inv=ksm(m,mod-2,mod);for(int i=0;i<m;i++) b[i]=1ll*b[i]*inv%mod;
    for(int i=len;i<m;i++) b[i]=0;
    // printf("len=%d\n",len);
    // for(int i=0;i<n;i++) printf("%d ",b[i]);puts("");
}

int main(){
    // freopen("my.in","r",stdin);
    // freopen("my.out","w",stdout);
    read(n);for(int i=1;i<n;i++) read(a[i]);
    rep(i,0,n-1) a[i]=-a[i];a[0]++;
    GetInv(n,a,b);
    rep(i,0,n-1) printf("%d ",b[i]);
    return 0;
}