【学习笔记】任意模数多项式乘法

发布时间 2023-07-16 21:06:13作者: SoyTony

三模数 NTT

由于多数 NTT 的操作对应值域 \(10^9\),规模 \(10^5\),所以选取三个常用 NTT 模数 \(p_1=998244353\)\(p_2=1004535809\)\(p_3=469702049\) 做三次乘法也就是九次 NTT。

三个模数的乘积大于结果的理论最大值,所以可以 CRT 合并得到原数再取模。使用 EXCRT 可以不开 __int128

EXCRT 具体过程是先把前两个结果 \(h_{1,i}\)\(h_{2,i}\)\(p_1\)\(p_2\) 下合并,解得一个 \(k\),使得 \(h_i\equiv kp_1+h_{1,i}\pmod {p_1p_2}\),之后拿这个值去和 \(h_{3,i}\)\(p_1p_2\)\(p_3\) 下合并,解得一个 \(k'\) 使得 \(h_i\equiv k'p_1p_2+kp_1+h_{1,i} \pmod {p_1p_2p_3}\),这个数对给定的 \(p\) 取模即可。

常数极大。

点击查看代码
inline int q_pow(int A,int B,int P){
    int res=1;
    while(B){
        if(B&1) res=1ll*res*A%P;
        A=1ll*A*A%P;
        B>>=1;
    }
    return res;
}
inline ll exgcd(ll A,ll B,ll &X,ll &Y){
    if(!B){
        X=1,Y=0;
        return A;
    }
    ll D=exgcd(B,A%B,Y,X);
    Y-=A/B*X;
    return D;
}

int rev[maxn];
int base,w[maxn];
struct Poly{
    const static int g=3;
    int deg;
    vector<ull> f;
    ull& operator[](const int &i){return f[i];}
    ull operator[](const int &i)const{return f[i];}
    inline void set(int L){deg=L;f.resize(L);}
    inline void clear(int L,int R){for(int i=L;i<=R;++i)f[i]=0;}
    inline void output(int L){for(int i=0;i<L;++i)printf("%llu ",f[i]);printf("\n");}
    inline void NTT(int L,bool type,int P){
        set(L);
        int inv_g=q_pow(g,P-2,P);
        for(int i=1;i<L;++i){
            rev[i]=(rev[i>>1]>>1)+(i&1?L>>1:0);
            if(i<rev[i]) swap(f[i],f[rev[i]]);
        }
        for(int d=1;d<L;d<<=1){
            base=q_pow(type?g:inv_g,(P-1)/(2*d),P);
            w[0]=1;
            for(int i=1;i<d;++i) w[i]=1ll*w[i-1]*base%P;
            for(int i=0;i<L;i+=d<<1){
                for(int j=0;j<d;++j){
                    ull x=f[i+j],y=f[i+d+j]*w[j]%P;
                    f[i+j]=x+y,f[i+d+j]=x-y+P;
                }
            }
        }
        for(int i=0;i<L;++i) f[i]%=P;
        if(!type){
            int inv_L=q_pow(L,P-2,P);
            for(int i=0;i<L;++i) f[i]=f[i]*inv_L%P;
        }
    }
}F,G,H[3];

int n,m,p;
int a[maxn],b[maxn],c[maxn];
ll mod[3]={998244353,1004535809,469762049};

inline int solve(ll A,ll B,ll C){
    ll X1,Y1,X2,Y2;
    exgcd(mod[0],mod[1],X1,Y1);
    X1=(X1%mod[1]+mod[1])%mod[1];
    X1=((B-A)%mod[1]+mod[1])%mod[1]*X1%mod[1];
    exgcd(mod[0]*mod[1],mod[2],X2,Y2);
    X2=(X2%mod[2]+mod[2])%mod[2];
    X2=((C-(X1*mod[0]+A)%(mod[0]*mod[1]))%mod[2]+mod[2])%mod[2]*X2%mod[2];
    return (X2%p*mod[0]%p*mod[1]%p+X1%p*mod[0]%p+A%p)%p;

}

int main(){
    n=read(),m=read(),p=read();
    for(int i=0;i<=n;++i) a[i]=read();
    for(int i=0;i<=m;++i) b[i]=read();
    int L=1;
    while(L<n+m+1) L<<=1;
    F.set(L),G.set(L);
    for(int i=0;i<=2;++i){
        H[i].set(L);
        F.clear(0,L-1),G.clear(0,L-1);
        for(int j=0;j<=n;++j) F[j]=a[j];
        for(int j=0;j<=m;++j) G[j]=b[j];
        F.NTT(L,1,mod[i]),G.NTT(L,1,mod[i]);
        for(int j=0;j<L;++j) H[i][j]=F[j]*G[j]%mod[i];
        H[i].NTT(L,0,mod[i]);
    }
    for(int i=0;i<=n+m;++i) c[i]=solve((ll)H[0][i],(ll)H[1][i],(ll)H[2][i]);
    for(int i=0;i<=n+m;++i) printf("%d ",c[i]);
    printf("\n");
    return 0;
}

参考资料