在任意代数结构上的多项式乘法 学习笔记

发布时间 2023-12-26 13:47:37作者: 383494

前言

Stop learning useless algorithms, go and solve some problems, learn how to use binary search.

以下内容大多是作者看完《如何在任意代数结构上做多项式乘法》[1] 后口胡的,所以可能和原文章不太一样。如果错了或者有更好的做法请告诉我。

分圆多项式

定义为 \(\Phi_n(x) = \prod_{1 \le k \lt n,\gcd(k, n)=1}(x - \omega_n^k)\).

也可以感性理解为 \(x=\omega_n\)\(x^n-1=0\),约掉一些“显然”不为 \(0\) 的因式后剩下的素多项式。

分圆多项式都是整系数素多项式,且 \(\Phi_n(x)\) 最高次数为 \(\varphi(n)\)

结论:多项式 \(f(x)\) 代入 \(x=\omega_n\) 后做运算得到的结果(用 \(\omega_n\) 表示,最后再把 \(\omega_n\) 换成 \(x\))等于先做运算再 \(\bmod \Phi_n(x)\) 得到的结果。感性理解就是 \(\Phi_n(\omega_n)=0\)

\(n=p^m\)\(p\) 为素数,\(m \ge 1\))时,\(\Phi_n(x)=\sum\limits_{i=0}^{\varphi(n)}x^i\)

算法原理

要求:三个群 \((A,+_A),(B,+_B),(C,+_C)\),乘法运算 \(\cdot:A\times B \rightarrow C\) 具有分配律 \((a_1 +_A a_2) \cdot (b_1 +_B b_2) = a_1 \cdot b_1 +_C a_1 \cdot b_2 +_C a_2 \cdot b_1 +_C a_2 \cdot b_2\)

此时必有 \(\forall b \in B, e_A \times b = e_c\)\(\forall a \in A,a \times e_b = e_c\),其中 \(e_A,e_B,e_C\) 分别为 \(A,B,C\) 中的单位元。于是将 \(A\)\(B\) 的高位填对应的单位元即可。

证明:

\(a \cdot b = (a+_Ae_A) \cdot b = a\cdot b +_C e_A \cdot b\),两边加上 \(a \cdot b\)\(C\) 中的逆元即可。

\(e_B\) 是类似的。

\(C\) 最好还能有较快(\(O(1)\))的自然数乘,定义为多个相同的元素加在一起。

Part 1. 解决除法

IDFT 最后要除以长度,而 \(C\) 中没有定义自然数乘的逆。

一个解决方法是,分别做长为 \(2\) 的幂的 DFT 和长为 \(3\) 的幂的 DFT,这样每个元素的 \(2^{c_2}\) 倍和 \(3^{c_3}\) 倍都已知(\(c_2\)\(c_3\) 取决于长度),类似辗转相除做即可。

Part 2. 解决单位根

这是一个很神仙的做法。

考虑把一部分 \(x\) 代入 \(\omega_m\) 满足 \(\varphi(m) \gt \deg A(x)+\deg B(x)\),然后将 \(m\) 拆成 \(m=pq\)。取 \(p=q=\sqrt m\) 能保证最优复杂度,读者自(wo)证(bu)不(hui)难(zheng)。具体实现可以参考代码。

具体地,

\[\begin{aligned} A(x) & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}x^{iq})x^j \\ & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}\omega_{pq}^{iq})x^j \\ & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}\omega_{p}^{i})x^j \end{aligned} \]

然后将内层带 \(\omega_p\) 的东西看成系数对外层做 DFT。实现时可以做成指针套数组的形式。这个部分可能不太好理解,可以看代码。

做完 DFT 要进行内层元素相乘,可以递归。

最后对分圆多项式取模即可。实现时可以暴力将高位减到低位。

应用

好像没啥用...

目前想到的就是做 \(c_k = \prod_{i+j=k}a_i^{b_j}\) 之类的卷积?

实现

给出一份大常数的实现。期待有大佬能优化。

题目是 lgP3803.

#define DEBUG 0
#include <iostream>
#include <algorithm>
#include <cmath>
#define UP(i,s,e) for(auto i=s; i<e; ++i)
#define DOWN(i,e,s) for(auto i=e; i-->s;)
using std::cin; using std::cout;
namespace Poly{ // }{{{
template<int BASE, typename T>
void change(T* arr, int len){
    int *rev = new int[len];
    rev[0] = 0;
    UP(i, 1, len){
        rev[i] = rev[i/BASE]/BASE;
        rev[i] += i%BASE*(len/BASE);
    }
    UP(i, 0, len) if(rev[i] > i) std::swap(arr[i], arr[rev[i]]);
    delete[] rev;
}
template<int BASE, class A>
void fft(A **a, int len, int siz, bool idft){ // siz == len(a[0])
    static A *tmp[BASE];
    UP(i, 0, BASE){
        tmp[i] = new A[siz];
        //UP(j, siz, siz*BASE){
        //    tmp[i][j].unit();
        //}
    }
    change<BASE>(a, len);
    int wn = siz/BASE;
    for(int h=BASE; h<=len; h*=BASE){
        for(int st=0; st<len; st+=h){
            int w=0;
            UP(i, st, st+h/BASE){
                UP(j, 0, BASE) std::swap(a[i+h/BASE*j], tmp[j]);
                UP(j, 0, BASE){
                    auto &now = a[i+h/BASE*j];
                    std::copy(tmp[0], tmp[0]+siz, now);
                    UP(k, 1, BASE){
                        UP(l, 0, siz){
                            int idx = l-(idft?-1:1)*(w+siz/BASE*j)*k;
                            idx %= siz; idx = idx < 0 ? idx + siz : idx;
                            now[l] += tmp[k][idx];
                        }
                    }
                }
                w += wn;
            }
        }
        wn /= BASE;
    }
    UP(i, 0, BASE) delete[] tmp[i];
    //delete[] tmp;
}
// mod Phi_len(x)
// len = BASE**n
template<int BASE, class A, class B, class C>
int polymul_base(A *a, B *b, C *ret, int len
#if DEBUG
        , int test=0
#endif
        ){
    UP(i, 0, len/BASE*(BASE-1)) ret[i].unit();
    if(len < 100
#if DEBUG
            && !test
#endif
            ){
        int phi_len = len / BASE * (BASE-1);
        UP(i, 0, len) UP(j, 0, len){
            if((i+j)%len >= phi_len) UP(k, 1, BASE){
                ret[(i+j)%len-len/BASE*k] += (a[i]*b[j]).inv();
            } else {
                ret[(i+j)%len] += a[i]*b[j];
            }
        }
        return 1;
    }
    int tim = std::round(std::log(len)/std::log(BASE));
    int p = std::round(std::pow(BASE, tim/2+1));
    int q = std::round(std::pow(BASE, (tim-1)/2));
    A **aa = new A*[BASE*q];
    B **bb = new B*[BASE*q];
    C **cc = new C*[BASE*q];
    UP(i, 0, BASE*q){
        aa[i] = new A[p];
        bb[i] = new B[p];
        cc[i] = new C[p];
    }
    UP(i, 0, q*BASE) UP(j, 0, p){
        aa[i][j].unit(); bb[i][j].unit();// cc[i][j].unit();
    }
    UP(i, 0, q*BASE) UP(j, p/BASE*(BASE-1), p){
        cc[i][j].unit();
    }
    UP(i, 0, q){
        UP(j, 0, p){
            if(j*q+i >= len){ 
                break;
                //aa[i][j].unit(); bb[i][j].unit();
            }
            else {
                aa[i][j] = a[j*q+i];
                bb[i][j] = b[j*q+i];
            }
        }
        //UP(j, p/BASE*(BASE-1), p){ aa[i][j].unit(); bb[i][j].unit(); }
    }
    //UP(i, q, BASE*q){
        //UP(j, 0, p){ aa[i][j].unit(); bb[i][j].unit(); }
    //}
    fft<BASE>(aa, BASE*q, p, false);
    fft<BASE>(bb, BASE*q, p, false);
    int scale;
    UP(i, 0, BASE*q){
        scale = polymul_base<BASE>(aa[i], bb[i], cc[i], p
#if DEBUG
                , test ? test-1 : 0
#endif
                );
    }
    UP(i, 0, BASE*q){
        delete[] aa[i];
        delete[] bb[i];
    }
    delete[] aa;
    delete[] bb;
    fft<BASE>(cc, BASE*q, p, true);
    int pq = p*q;
    int phi_pq = pq/BASE*(BASE-1);
    UP(i, 0, BASE*q) UP(j, 0, p){
        int pl = (i+j*q)%pq;
        if(pl >= phi_pq) UP(k, 1, BASE) ret[(pl-pq/BASE*k)%len] += cc[i][j].inv();
        else ret[pl%len] += cc[i][j];
    }
    UP(i, 0, BASE*q) delete[] cc[i];
    delete[] cc;
    return scale * BASE * q;
}
template<class A, class B, class C>
void polymul(A *a, B *b, C *ret, int len){
    bool swapped = false;
    C *tmp = new C[len*2];
    int l2 = std::round(std::pow(2, std::ceil(std::log(len*2) / std::log(2))));
    int l3 = std::round(std::pow(3, std::ceil(std::log(len*3/2) / std::log(3))));
    int tim2 = polymul_base<2>(a, b, tmp, l2);
    int tim3 = polymul_base<3>(a, b, ret, l3);
    while(tim3 != 1){
        if(tim2 > tim3){
            int scale = tim2 / tim3;
            UP(i, 0, len) tmp[i] += ret[i].inv() * scale;
            tim2 %= tim3;
        }
        std::swap(tim2, tim3);
        std::swap(ret, tmp);
        swapped ^= 1;
    }
    if(swapped){ std::swap(ret, tmp); std::copy(tmp, tmp+len, ret); }
    delete[] tmp;
}
} // {}}}
namespace m{ // }{{{
constexpr int N = 5e6+2;
struct u32{
    unsigned val;
    u32(){}
    u32(unsigned v):val(v){}
    void unit(){val = 0;}
    u32 inv(){ return -val; }
    u32 &operator+=(u32 b){ val += b.val; return *this; }
    u32 &operator*=(u32 b){ val *= b.val; return *this; }
    u32 operator*(u32 b){ return b *= *this;}
    u32 operator*(unsigned x){ return val*x; }
} ia[N], ib[N], ic[N];
int in, im;
void work(){
    cin >> in >> im;
    UP(i, 0, in+1){
        cin >> ia[i].val;
    }
    UP(i, 0, im+1){
        cin >> ib[i].val;
    }
    Poly::polymul(ia, ib, ic, in+im+1);
    UP(i, 0, in+im+1){
        cout << ic[i].val << ' ';
    }
}
} // {}}}
int main(){cin.tie(0)->sync_with_stdio(0); m::work(); return 0;}

  1. https://www.cnblogs.com/whx1003/p/16214952.html ↩︎