FFT——快速处理卷积

发布时间 2023-05-30 20:50:21作者: Chy12321

前置知识

卷积

符号为 \(*\)

设多项式 \(A(x) = a_0 + a_1x + a_2x^2 + \cdots + a_nx^n, B(x) = b_0 + b_1x_1 + b_2x^2 + \cdots + b_nx^n\),则有

\[(A * B)[n] = \sum_{i = 0}^n A(i) \times B(n - i) \]

\((A * B)[n]\) 的意义是将两个多项式相乘后 \(n\) 次项的系数。

单位复根

定义

我们把满足 \(\omega^n = 1\) 的复数 \(\omega\) 称为 \(n\) 次单位复根,不难得到 \(n\) 次单位复根有 \(n\) 个。

由欧拉公式 \(e^{ix} = \cos x + i \sin x\) 可推知 \(e^{2\pi i} = \cos 2\pi + i\sin 2\pi = 1\),进而可以得到主 \(n\) 次单位根的表达式:

\[\omega_n = e^{\frac{2\pi i}n} \]

任意 \(n\) 次单位复根都是主 \(n\) 次单位复根的整次幂,记作 \(\omega_n^k (0 \le k \le n - 1)\)

不难得出:

\[\omega_n^k = e^\frac{2k\pi i}{n} \]

性质

  • 特殊值:\(\omega_{2n}^n = -1, \omega_n^0 = \omega_n^n = 1(n \in \N^*)\)

  • 消去引理:

    \[\omega_{dn}^{dk} = \omega_n^k (n, k, d \in \N^*) \]

  • 折半引理(前提条件:\(n\) 为偶数):

    \[\omega_n^{k + \frac n 2} = \omega_n^k \times \omega_n^{\frac n 2} = -\omega_n^k \]

  • 求和引理:

    \[\sum_{j = 0}^{n - 1}(\omega_n^k)^j = \begin{cases} 0, k \ne 0 \\ n, k = 0 \end{cases} \]

多项式的表示方法

系数表示法

\(A(x) = a_0 + a_1x + a_2x^2 + \cdots a_{n - 1}x^{n - 1}\) 描述的多项式。

点值表示法

\(A\)\(n\) 次多项式,则以 \(y = A(x)\) 的图像上任意不同的 \((n + 1)\) 个点可将其唯一确定。

也即 \(A\) 可用点值表示法表示为 \(\{(x_i, y_i)~|~ 0 \le i \le n\}\)

离散傅立叶变换(FFT)

作用

快速地将一个以系数表示法描述的多项式转化为以点值表示法描述的形式。

流程

\(n = 2^k, k \in \N^*\)

现有一多项式 \(A(x) = \sum\limits_{i = 0}^{n - 1} (a_ix^i)\),考虑将其化为两个次数为 \((\dfrac n 2 - 1)\) 的多项式,即:

\[A_0(x) = a_0 + a_2x + a_4x^2 + \cdots + a_{n - 1}x^{\frac n 2 - 1} \\ A_1(x) = a_1 + a_3x + a_5x^2 + \cdots + a_{n - 2}x^{\frac n 2 - 1} \]

则有:

\[A(x) = A_0(x^2) + xA_1(x^2) \]

\(x = \omega_n^0, \omega_n^1, \cdots, \omega_{n}^{n - 1}\) 依次代入求得对应值。此时再套上 消去引理折半引理,可以发现些有趣的性质。这里以 \(\omega_n^k(0 \le k < \frac n 2)\) 为例:

\[A(\omega_n^k) = A_0[(\omega_n^k)^2] + \omega_n A_1[(\omega_n^k)^2] = A_0(\omega_\frac n 2^k) + \omega_n^k A_1(\omega_\frac n 2^k) \\ A(\omega_n^{k + \frac n 2}) = A(-\omega_n^k) = A_0(\omega_\frac n 2^k) - \omega_n^k A_1(\omega_\frac n 2^k) \]

  • 若通过递归的方式求解 \(A(\omega_n^k)\)\(n\) 每次都会减小 \(\dfrac 1 2\),时间复杂度为 \(O(\log n)\)
  • \(A(\omega_n^k)\)\(A(\omega_n^{k + \frac n 2})\) 的递归式只有一项常数不同,在 \(O(\log n)\) 求解 \(A(\omega_n^k)\) 时可以 \(O(1)\) 求出 \(A(\omega_n^{k + \frac n 2})\)

优化

在递归版 FFT 的执行过程中,底层会反复进行出入栈操作,导致常数巨大,由此引出了迭代版 FFT。

我们定义在已知 \(A_0(\omega_\frac n2^k)\)\(\omega_n^k A_1(\omega_\frac n2^k)\) 的情况下,\(O(2)\) 求出 \(A(\omega_n^k)\)\(A(\omega_n^{k + \frac n2})\) 的操作为一次 蝴蝶操作

假设我们现在知道 FFT 迭代树中叶子的顺序。那么只需要模拟回溯的合并过程,可以就着代码理解迭代过程:

for (int i = 1; i < len; i <<= 1) { // 枚举单段区间长度
    Complex wn = {cos(PI / i), flag * sin(PI / i)}; // 求出主 n 次单位根
    for (int j = 0; j < len; j += (i << 1)) { // 两段两段区间地枚举(便于合并)
        Complex w = {1, 0};
        for (int k = j; k < j + i; k++) { // 枚举区间内值并进行蝴蝶操作
            Complex t = w * A[k + i];
            A[k + i] = A[k] - t;
            A[k] = A[k] + t;
            w = w * wn;
        }
    }
}

那么应如何求出叶子的顺序呢?

这就引出了另一个重要的定理—— 蝴蝶定理

首先,我们可以画出迭代树,它大概是长这样:

然后把叶子序列单拎出来和原序列对照着看,也就是:

0 1 2 3 4 5 6 7
0 4 2 6 1 5 3 7

再都换成二进制:

000 001 010 011 100 101 110 111
000 100 010 110 001 101 011 111

不难发现后者的每一项是前者对应项的反序,那么叶子序列也就可求了(求解叶子序列的位运算的式子特别难推,背板即可)。

设序列中最大值的二进制位数为 \(bits\),叶子序列为 \(rev\),则可以通过如下代码 \(O(n)\) 求出叶子序列:

for (int i = 0; i < len; i++) if (rev[i] > i) swap(A[i], A[rev[i]]);

\(len\) 为满足 \(2^k \ge n + m\) 的最小的 \(2^k (k \in \N^*)\)

于是有了 FFT 的代码:

void FFT(Complex A[]) {
    for (int i = 0; i < len; i++) if (rev[i] > i) swap(A[i], A[rev[i]]); // if 保证只换一次
    for (int i = 1; i < len; i <<= 1) {
        Complex wn = {cos(PI / i), sin(PI / i)};
        for (int j = 0; j < len; j += (i << 1)) {
            Complex w = {1, 0};
            for (int k = j; k < j + i; k++) {
                Complex t = w * A[k + i];
                A[k + i] = A[k] - t;
                A[k] = A[k] + t;
                w = w * wn;
            }
        }
    }
}

时间复杂度

FFT 实际上就是一种类似于线段树的二分分治做法,时间复杂度为 \(O(n \log n)\)

快速傅立叶逆变换(IFFT)

推导过程

设我们上面求得的点值表示法下 \(A(x)\) 可表示为 \(\{(\omega_n^k, y_k)~|~ 0 \le k < n\}\)

那么把 FFT 的过程写成矩阵乘法的形式就是:

\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n & \omega_n^2 & \cdots & \omega_n^{n - 1} \\ 1 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2(n - 1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n - 1} & \omega_n^{2(n - 1)} & \cdots & \omega_n^{(n - 1)(n - 1)} \end{bmatrix} \times \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ \vdots \\ a_{n - 1} \end{bmatrix} = \begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ \vdots \\ y_{n - 1} \end{bmatrix} \]

要想求得 \(a_0, a_1, a_2, \cdots, a_{n - 1}\),只要等式两边同时乘上第一个大矩阵(范德蒙德矩阵)的逆矩阵即可。

考虑到范德蒙德矩阵和其逆矩阵 \(T\) 相乘后应为单位阵,即:

\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n & \omega_n^2 & \cdots & \omega_n^{n - 1} \\ 1 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2(n - 1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n - 1} & \omega_n^{2(n - 1)} & \cdots & \omega_n^{(n - 1)(n - 1)} \end{bmatrix} \times T = \begin{bmatrix} 1 & 0 & 0 & \cdots & 0 \\ 0 & 1 & 0 & \cdots & 0 \\ 0 & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 1 \end{bmatrix} \]

矩阵乘法后的结果只有 \(0, 1\) 两种取值,由此想到上面提到的 求和引理

将范德蒙德矩阵每一项取倒数后与其相乘,则有:

\[\begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n & \omega_n^2 & \cdots & \omega_n^{n - 1} \\ 1 & \omega_n^2 & \omega_n^4 & \cdots & \omega_n^{2(n - 1)} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n - 1} & \omega_n^{2(n - 1)} & \cdots & \omega_n^{(n - 1)(n - 1)} \end{bmatrix} \times \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \dfrac 1{\omega_n} & \dfrac 1{\omega_n^2} & \cdots & \dfrac 1{\omega_n^{n - 1}} \\ 1 & \dfrac 1{\omega_n^2} & \dfrac 1{\omega_n^4} & \cdots & \dfrac 1{\omega_n^{2(n - 1)}} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \dfrac 1{\omega_n^{n - 1}} & \dfrac 1{\omega_n^{2(n - 1)}} & \cdots & \dfrac 1{\omega_n^{(n - 1)(n - 1)}} \end{bmatrix} = \begin{bmatrix} n & 0 & 0 & \cdots & 0 \\ 0 & n & 0 & \cdots & 0 \\ 0 & 0 & n & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & n \end{bmatrix} \]

于是有:

\[\begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ \vdots \\ y_{n - 1} \end{bmatrix} \times \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 \\ 1 & \dfrac 1{\omega_n} & \dfrac 1{\omega_n^2} & \cdots & \dfrac 1{\omega_n^{n - 1}} \\ 1 & \dfrac 1{\omega_n^2} & \dfrac 1{\omega_n^4} & \cdots & \dfrac 1{\omega_n^{2(n - 1)}} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \dfrac 1{\omega_n^{n - 1}} & \dfrac 1{\omega_n^{2(n - 1)}} & \cdots & \dfrac 1{\omega_n^{(n - 1)(n - 1)}} \end{bmatrix} = \begin{bmatrix} na_0 \\ na_1 \\ na_2 \\ \vdots \\ na_{n - 1} \end{bmatrix} \]

\(\theta = \dfrac{2k\pi i}n\),则 \(\dfrac 1{\omega_n^k} = \omega_n^{-k} = e^{i\theta} = \cos(-\theta) + i\sin(-\theta) = \cos \theta - i\sin\theta\),所以可以在 FFT 的代码中加入一个表示 \(i \sin\theta\) 符号的参数 \(flag\)\(flag = 1\) 时为 FFT,\(flag = -1\) 时为 IFFT。

代码:

void FFT(Complex A[], int flag) {
    for (int i = 0; i < len; i++) if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int i = 1; i < len; i <<= 1) {
        Complex wn = {cos(PI / i), flag * sin(PI / i)}; // 整个函数唯一一次用到 flag
        for (int j = 0; j < len; j += (i << 1)) {
            Complex w = {1, 0};
            for (int k = j; k < j + i; k++) {
                Complex t = w * A[k + i];
                A[k + i] = A[k] - t;
                A[k] = A[k] + t;
                w = w * wn;
            }
        }
    }
}

时间复杂度

显然与 FFT 相同,为 \(O(n \log n)\)

模板

洛谷 P3803 【模板】多项式乘法(FFT)

#include <bits/stdc++.h>

#define MAXN 2100000

using namespace std;

const double PI = acos(-1);

int n, m;
int len = 1, rev[MAXN];

struct Complex {
    double r, i;

    Complex operator+(const Complex &rhs) const {
        return {r + rhs.r, i + rhs.i};
    }

    Complex operator-(const Complex &rhs) const {
        return {r - rhs.r, i - rhs.i};
    }

    Complex operator*(const Complex &rhs) const {
        return {(r * rhs.r) - (i * rhs.i), (r * rhs.i) + (rhs.r * i)};
    }
} a[MAXN], b[MAXN];

void FFT(Complex A[], int flag) {
    for (int i = 0; i < len; i++) if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int i = 1; i < len; i <<= 1) {
        Complex wn = {cos(PI / i), flag * sin(PI / i)};
        for (int j = 0; j < len; j += (i << 1)) {
            Complex w = {1, 0};
            for (int k = j; k < j + i; k++) {
                Complex t = w * A[k + i];
                A[k + i] = A[k] - t;
                A[k] = A[k] + t;
                w = w * wn;
            }
        }
    }
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr), cout.tie(nullptr);

    cin >> n >> m;
    for (int i = 0; i <= n; i++) cin >> a[i].r;
    for (int i = 0; i <= m; i++) cin >> b[i].r;
    int bits = 0;
    while (len <= n + m) len <<= 1, bits++;
    for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    FFT(a, 1), FFT(b, 1);
    for (int i = 0; i < len; i++) a[i] = a[i] * b[i];
    FFT(a, -1);
    for (int i = 0; i <= n + m; i++) cout << int(a[i].r / len + 0.5) << ' ';
    return 0;
}