任意模数多项式乘法(MTT)学习笔记

发布时间 2023-06-13 20:21:51作者: _bzw

三模数 NTT

常数大、速度慢、精度高是它的特点。

在考虑三模数 NTT 之前先考虑一下中国剩余定理吧。

已知

\[\begin{cases} x\equiv x_1(\bmod m_1)\\ x\equiv x_2(\bmod m_2)\\ x\equiv x_3(\bmod m_3)\\ \end{cases} \]

\(x\bmod m_1m_2m_3\)

\[\begin{aligned} &k_1m_1+x_1=k_2m_2+x_2\\ &k_1\equiv \frac{x_2-x_1}{m_1}(\bmod m_2)\\ &x\equiv k_1m_1+x_1(\bmod m_1m_2)\\ &x_4=(k_1m_1+x_1)\bmod m_1m_2\\ &k_4m_1m_2+x_4=k_3m_3+x_3\\ &k_4\equiv \frac{x_3-x_4}{m_1m_2}(\bmod m_1m_2m_3)\\ &x\equiv k_4m_4+x_4(\bmod m_1m_2m_3)\\ \end{aligned} \]

一点疑惑的解答(自言自语):

因为 \(k_1\equiv \frac{x_2-x_1}{m_1}(\bmod m_2)\),所以 \(k_1=\frac{x_2-x_1}{m_1}+km_2\)。又因为 \(k_1m_1\le m_1m_2\),所以 \(k_1\le m_2\)。所以 \(k\ge 0\),所以 \(k_1\) 最小为 \(\frac{x_2-x_1}{m_1}\),即 \(x\equiv k_1m_1+x_1(\bmod m_1m_2)\\\)

进入正题:

所谓的三模数 NTT 指的是 以 \(998244353,1004535809,469762049\) 为模数(经典 NTT 模数,原根均为 \(3\))分别进行 NTT,最后用上文的计算方式计算即可。

因为以上三个模数的乘积为很大,一般答案即使不取模也不会大于该数,所以上式的 \(k_4m_4+x_4\) 就是原答案,直接对题目给出的模数取模即可。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define Big __int128
const int N=3e5+1;
const ll mo1=998244353,mo2=1004535809,mo3=469762049,G=3;
inline Big Ksm(Big x,Big y,ll mo){
    Big res=1;
    for(;y;y>>=1,x=x*x%mo)
        if(y&1)res=res*x%mo;
    return res;
}
ll MOD;
const ll inv1=Ksm(mo1,mo2-2,mo2),inv2=Ksm(mo1*mo2%mo3,mo3-2,mo3);
struct Int{
    ll a,b,c;
    Int(ll _x=0){a=b=c=_x;}
    Int(ll _a,ll _b,ll _c){a=_a,b=_b,c=_c;}
    inline Int operator + (const Int &x){return Int((ll)(a+x.a)%mo1,(ll)(b+x.b)%mo2,(ll)(c+x.c)%mo3);}
    inline Int operator - (const Int &x){return Int((ll)(a-x.a+mo1)%mo1,(ll)(b-x.b+mo2)%mo2,(ll)(c-x.c+mo3)%mo3);}
    inline Int operator * (const Int &x){return Int((ll)a*x.a%mo1,(ll)b*x.b%mo2,(ll)c*x.c%mo3);}
    inline Int operator * (ll x){return Int((ll)a*x%mo1,(ll)b*x%mo2,(ll)c*x%mo3);}
    void mulinv(ll x){
        a=a*Ksm(x,mo1-2,mo1)%mo1;
        b=b*Ksm(x,mo2-2,mo2)%mo2;
        c=c*Ksm(x,mo3-2,mo3)%mo3;
    }
    void inv(){
        a=Ksm(a,mo1-2,mo1)%mo1;
        b=Ksm(b,mo2-2,mo2)%mo2;
        c=Ksm(c,mo3-2,mo3)%mo3;
    }
    ll gettrue(){
        Big x=(Big)(b-a+mo2)%mo2*inv1%mo2*(Big)mo1+(Big)a;
        return (((Big)(c-x%mo3+mo3)%mo3*inv2%mo3*(mo1%MOD*mo2%MOD)%MOD+x%MOD)%MOD+MOD)%MOD;
    }
}; // mtt
int rev[N];
Int w[N];
void NTT(Int *a,int Len,bool type){
    for(int i=0;i<Len;i++){
        rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0);
        if(rev[i]>i)swap(a[rev[i]],a[i]);
    }
    for(int d=1;d<Len;d<<=1){
        Int W=Int(Ksm(G,(mo1-1)/(d*2),mo1),Ksm(G,(mo2-1)/(d*2),mo2),Ksm(G,(mo3-1)/(d*2),mo3));
        if(type)W.inv();
        w[0]=Int(1); for(int i=1;i<d;i++)w[i]=w[i-1]*W;
        for(int fir=0;fir<Len;fir+=d<<1){
            int sec=fir+d;
            for(int i=0;i<d;i++){
                Int a0=a[fir+i],a1=w[i]*a[sec+i];
                a[fir+i]=a0+a1,a[sec+i]=a0-a1;
            }
        }
    }
    if(type){for(int i=0;i<Len;i++)a[i].mulinv(Len);}
}
int n,m;
Int f[N],g[N];
int main(){
    cin>>n>>m>>MOD;
    for(int i=0,x;i<=n;i++)cin>>x,x%=MOD,f[i]=Int(x);
    for(int i=0,x;i<=m;i++)cin>>x,x%=MOD,g[i]=Int(x);
    int Len=1;
    while(Len<=(n+m+4))Len<<=1;
    NTT(f,Len,0),NTT(g,Len,0);
    for(int i=0;i<Len;i++)f[i]=f[i]*g[i];
    NTT(f,Len,1);
    for(int i=0;i<=n+m;i++)cout<<f[i].gettrue()<<' ';
    cout<<'\n';
    return 0;
}

拆系数 FFT

常数小,速度快,精度低(\(\operatorname{long double}\) 信仰跑)是它的特色。

如果直接对原数列进行 FFT 的话会炸精度的。考虑拆系数,即 \(A_i=J\times A'_i+A''_i\)\(A''_i< J\))。

那么:

\[\begin{aligned} F&=A\times B=(J\times A'+A'')\times(J\times B'+B'')\\ &=J^2A'B'+J(A'B''+A''B')+A''B''\\ \end{aligned} \]

如果直接计算的话需要四次 dft,三次 idft,和九次 ntt 的三模数 NTT 差距并不大。

考虑优化,然而 dft/idft 中有什么地方没有用到捏?虚部!考虑将 \(A'\)\(A''\)\(B'\)\(B''\) 合并在一起进行 dft。

设:

\[\begin{aligned} P_i=A'_i+A''_ii\\ Q_i=A'_i-A''_ii\\ E_i=B'_i+B''_ii\\ \end{aligned} \]

有:

\[\begin{aligned} &W_i=(P\times E)_i=(A'_iB'_i-A''_iB''_i)+(A'_iB''_i+A''_iB'_i)i\\ &R_i=(Q\times E)_i=(A'_iB'_i+A''_iB''_i)+(A'_iB''_i-A''_iB'_i)i\\ \end{aligned} \]

我们可以通过 \(W\)\(R\) 的加减得到我们想要的系数。

\[\begin{aligned} &W_i+R_i=2\times(A'_iB'_i+A'B''_ii)\\ &R_i-W_i=2\times(A''_iB''_i+A''_iB'_ii)\\ \end{aligned} \]

注意: 是先除以二再取整!!!(代码 \(\texttt{39}\) 行)。

#include <bits/stdc++.h>
#define poly vector<int>
using namespace std;
const int N=5e5+11;
int mo;
const int base=32768;
namespace Poly{
    using db = long double;
    const db pi=acos(-1);
    struct cp{
        db x,y;
        cp operator + (const cp &a){return {x+a.x,y+a.y};}
        cp operator - (const cp &a){return {x-a.x,y-a.y};}
        cp operator * (const cp &a){return {x*a.x-y*a.y,x*a.y+y*a.x};}
    };
    cp w[N]; int rev[N];
    void init_rev(int Len){
        for(int i=0;i<Len;i++)
            rev[i]=(rev[i>>1]>>1)+(i&1?Len>>1:0);
    }
    void FFT(cp *a,int Len,bool type){
        for(int i=0;i<Len;i++)if(rev[i]>i)swap(a[rev[i]],a[i]);
        for(int d=1;d<Len;d<<=1){
            cp W={cos(pi/d),sin(pi/d)};
            if(type)W.y=-W.y;
            w[0]={1,0};
            for(int i=1;i<d;i++)w[i]=w[i-1]*W;
            for(int fir=0;fir<Len;fir+=d<<1){
                int sec=fir+d;
                for(int i=0;i<d;i++){
                    cp a0=a[fir+i],a1=w[i]*a[sec+i];
                    a[fir+i]=a0+a1,a[sec+i]=a0-a1;
                }
            }
        }
        if(type)for(int i=0;i<Len;i++)a[i].x/=Len,a[i].y/=Len;
    }
    cp f[N],g[N],e[N];
    long long C(db x){return (long long)(x/2.+0.5)%mo;} // important!!!
    poly mul(poly x,poly y){
        int tot=x.size()+y.size()-1,Len=1;
        while(Len<=(tot+2))Len<<=1;
        init_rev(Len);
        for(int i=0;i<=Len;i++)f[i]=g[i]=e[i]={0,0};
        for(int i=0;i<x.size();i++){
            int a0=x[i]/base,a1=x[i]%base;
            f[i]={a0,a1},g[i]={a0,-a1};
        }
        for(int i=0;i<y.size();i++){
            int b0=y[i]/base,b1=y[i]%base;
            e[i]={b0,b1};
        }
        FFT(f,Len,0),FFT(g,Len,0),FFT(e,Len,0);
        for(int i=0;i<Len;i++)f[i]=f[i]*e[i],g[i]=g[i]*e[i];
        FFT(f,Len,1),FFT(g,Len,1);
        poly ret(tot,0);
        for(int i=0;i<tot;i++){
            ret[i]=1ll*base*base%mo*(C(f[i].x+g[i].x))%mo;
            ret[i]+=1ll*base*((C(f[i].y+g[i].y))+(C(f[i].y-g[i].y)))%mo;
            ret[i]%=mo;
            ret[i]+=(C(g[i].x-f[i].x))%mo;
            ret[i]%=mo;
        }
        return ret;
    }
} using Poly::mul;
int a[N],n,m;
poly solve(int l,int r){
    if(l==r)return {1,a[l]};
    int mid=l+r>>1;
    return mul(solve(l,mid),solve(mid+1,r));
}
int main(){
    cin>>n>>m>>mo;
    poly a(n+1,0),b(m+1,0);
    for(int i=0;i<=n;i++) cin>>a[i];
    for(int i=0;i<=m;i++) cin>>b[i];
    a=mul(a,b);
    for(int i:a)cout<<i<<' ';
    return 0;
}