任意模数多项式乘法MTT(可拆系数FFT、三模数NTT)笔记

发布时间 2023-07-04 07:39:38作者: 一棵油菜花

任意模数多项式乘法

前言:
在教练讲的时候脑子并不清醒,所以没听懂。后来自己看博客学会了,但目前只学了一种方法:可拆系数FFT。为了方便日后复习,决定先写下这个的笔记,关于三模数NTT下次再补。

建议:准备好演算纸和笔,本篇含有大量推算部分。

注:本篇文章是本蒟写的,dalao随便看看就好,不必争论对错(但是欢迎指出文章存在的小错误)。

为什么不直接使用NTT/FFT

此处的模数是任意的,所以我们使用NTT时,有局限性。只有当模数满足下列情况时才可使用NTT。

模数为\(P\)\(P=r\times 2^k + 1\),其中\(k\)足够大,满足\(2^k \geq N\),其中\(N\)为多项式扩充后的项数,多项式乘法都需要将项数扩充到\(2\)的幂。

如果不满足上述条件,就不能直接使用NTT。

那为什么不使用FFT带模运算?
首先不考虑模数的情况下,多项式结果的系数不能超过\(double\)能表示的精度(一般在\(10^{16}\))。超过后,\(double\)所表示的结果将不再精确。

那怎么办?目前可以给出两种解决方法

  1. 可拆系数FFT
  2. 三模数NTT(可是我还没学懂QAQ)

所以,向可拆系数FFT进发!

可拆系数FFT

怎么用可拆系数FFT

首先我们还是先假设一个情景:

求:\(c_1\)\(c_2\)的卷积

我们可以把一个数拆成\(a\times 2^{15}+b\)的形式(不一定是\(2^{15}\),大概在\(\sqrt{值域}\)内)。

\[c_1=a_1\times 2^{15}+a_2 \\ c_2=b_1\times 2^{15}+b_2 \]

那这俩的积为

\[\begin{aligned} c_1\times c_2&=(a_1\times 2^{15}+a_2)\times(b_1\times 2^{15}+b_2)\\&=a_1b_1\times 2^{30}+(a_1b_2+a_2b_1)\times 2^{15}+a_2b_2 \end{aligned} \]

好耶!那我们可以直接做4次FFT(\(a_1b_1\)\(a_1b_2\)\(a_2b_1\)\(a_2b_2\))!
然后你发现正逆变换总共做了8次常数爆炸然后炸了

所以,我们需要优化!

优化可拆系数FFT

注意:我们这里的优化会用到复数,你可能会害怕得逃走,但是你无需害怕,因为(我也是这样过来的)我会简单地讲一讲。

小资料(可能不太学术规范):
啥是复数?

其实复数,是一种含实部和虚部的一种数。我们知道\(i=\sqrt{-1}\)\(i\)就是虚数。那我们以实部建立\(x\)轴、虚部建立\(y\)轴。

那我们假设在这个平面直角坐标系上有一个点\(A(2,3)\),那这个点的复数表示为\(2+3i\)

那你就基本知道什么是复数了,让我们学一下基本运算吧。(这个自己记一记好了)

我们可以利用一下复数的乘法运算:

\[(a+bi)(c+di)=(ac-bd)+(ad+bc)i\\ (a-bi)(c+di)=(ac+bd)+(ad-bc)i \]

现在令\(P_1=a_1+a_2i\)\(P_2=a_1-a_2i\)\(Q=b_1+b_2i\)
计算:

\[\begin{aligned} P_1Q+P_2Q&=(a_1b_1-a_2b_2)+(a_1b_2+a_2b_1)i+(a_1b_1+a_2b_2)+(a_1b_2-a_2b_1)i\\ &=2(a_1b_1+a_1b_2i) \end{aligned} \]

如果我们将上式除以\(2\),那我们可以得到\(a_1b_1,a_1b_2\)(分别通过“实部相加/2”、“虚部相加/2”可得)。

我们把得到的\(a_1b_1\)代入\(P_2Q\)的实部,可得\(a_2b_2\)
类似地,将\(a_1b_2\)代入\(P_1Q\)的虚部,可得\(a_2b_1\)

(注:这里的运算请自己用演算纸推一下)

我们就可以带回\(c_1\times c_2\)了。

那我们的任务就完成啦!!

代码实现

注:此处的FFT部分与FFT版题的部分是一模一样的,可参照本人之前所写的FFT笔记

板子:P4245 【模板】任意模数多项式乘法

code:

#include<bits/stdc++.h>
using namespace std;

#define ll long long
#define rp(i,o,p) for(ll i=o;i<=p;++i)
#define pr(i,o,p) for(ll i=o;i>=p;--i)
#define cp complex<long double>

const ll MAXN=1e5+5,P=1e9+7;
const ll lim=32000; // lim = sqrt(值域) <- 1e9
const long double PI=acos(-1.0);

cp p1[MAXN<<2],p2[MAXN<<2],q[MAXN<<2];
ll n,m,limit;
cp p[MAXN<<2],omg[MAXN<<2],inv[MAXN<<2];

void init() {
    for (ll i = 0; i < limit; ++i) {
        omg[i] = cp(cos(2 * PI * i / limit), sin(2 * PI * i / limit));
        inv[i] = conj(omg[i]);
    }
}

void fft(cp *a, cp *omg) {
    for (ll i = 0, j = 0; i < limit; ++i) {
        if (i > j)
            swap(a[i], a[j]);
        for (ll l = limit >> 1; (j ^= l) < l; l >>= 1)
            ;
    }
    for (ll l = 2; l <= limit; l <<= 1) {
        ll m = l >> 1;
        for (cp *p = a; p != a + limit; p += l) {
            rp(i, 0, m - 1) {
                cp t = omg[limit / l * i] * p[i + m];
                p[i + m] = p[i] - t;
                p[i] += t;
            }
        }
    }
}

int main()
{
    scanf("%lld%lld",&n,&m);
    rp(i,0,n)
    {
        ll x;
        scanf("%lld",&x);
        p1[i]=cp(x/lim,x%lim);
        p2[i]=cp(x/lim,-x%lim);
    }
    rp(i,0,m)
    {
        ll x;
        scanf("%lld",&x);
        q[i]=cp(x/lim,x%lim);
    }

    limit=1;
    while(limit<=n+m) limit<<=1;

    init();
    fft(p1,omg),fft(p2,omg),fft(q,omg);
    
    rp(i,0,limit-1)
    {
    	long double xx,xy;
    	
        xx=p1[i].real()*q[i].real(),yy=p1[i].imag()*q[i].imag();
        xy=p1[i].real()*q[i].imag(),yx=p1[i].imag()*q[i].real();
        p1[i]=cp(xx/limit-yy/limit,xy/limit+yx/limit);
        
        xx=p2[i].real()*q[i].real(),yy=p2[i].imag()*q[i].imag();
        xy=p2[i].real()*q[i].imag(),yx=p2[i].imag()*q[i].real();
        p2[i]=cp(xx/limit-yy/limit,xy/limit+yx/limit);
    }
    /*
    	上面的循环等价于这两行循环,但是因为c++中complex模板为const类型,
    	单独的real()或imag()值不可以直接修改,只可以两个同时赋值来修改,
    	所以上面的循坏还用到了复数的除法运算,自己看看吧
      	rp(i,0,limit-1) q[i].real()/=limit,q[i].imag()/=limit;
      	rp(i,0,limit-1) p1[i]*=q[i],p2[i]*=q[i];
    */
    fft(p1,inv),fft(p2,inv);

    rp(i,0,n+m)
    {
        ll a1b1=(ll)((p1[i].real()+p2[i].real())/2+0.5)%P;
        ll a1b2=(ll)((p1[i].imag()+p2[i].imag())/2+0.5)%P;
        ll a2b1=(ll)((p1[i].imag()+0.5)-a1b2)%P;
        ll a2b2=(ll)((p2[i].real()+0.5)-a1b1)%P;
        ll ans=(a1b1*lim*lim+(a1b2+a2b1)*lim+a2b2)%P;

        ans=(ans%P+P)%P;
        printf("%lld ",ans);
    }

    return 0;
}

后记

三模数NTT的内容会尽快补上的。