多项式全家桶

发布时间 2023-08-01 20:10:33作者: A_Big_Jiong

前言

多项式乱七八糟的公式和做法实在是太多了,有点遭不住,写一个学习笔记,记录一下多项式的各种奇奇怪怪的模板。

多项式乘法

系数表示法

即用这个多项式的每一项系数来表示这个多项式。

对于一个 \(n−1\)\(n\) 项多项式:

\[f(x)=\sum_{i=0}^{n−1}a_ix_i \]

用系数表示法为:\(f(x)=\{a_0,a_1 \dots a_i \dots a_{n−1}\}\)

点值表示法

所以,对于一个 \(n−1\) 次多项式,我们可以用 \(n\) 个点值来表示这个多项式。

即, \(f(x)=\{(x_0,f(x_0)),(x_1,f(x_1)) \dots (x_i,f(x_i)) \dots (x_{n−1},f(x_{n−1}))\}\)

点值表示法的重要优势在于,我们可以通过简单的算术运算来实现多项式的运算,举个例子:

\[f(x) · g(x)=\{(x_0,f(x_0)·g(x_0)),(x_1,f(x_1)·g(x_1)) \dots (x_i,f(x_i)·g(x_i)) \dots (x_{n−1},f(x_{n−1})·g(x_{n-1}))\} \]

单位根

将单位圆上的点所代表的复数叫作单位根,用 \(\omega^k_n\) 表示,\(n\) 表示圆分成的份数,且通常是 \(2\) 的次幂,\(k\) 表示逆时针数第 \(k\) 个点。

单位根有如下几个性质:

\[\omega^k_n = \cos{\frac {k} {n}} + \sin{\frac{k}{n}}\mathrm{i} \]

\[(\omega^1_n)^k = \omega^k_n \]

\[\omega^{2k}_{2n} = \omega^{k}_{n} \]

\[\omega^{k + \frac{2}{n}}_{n} = - \omega^{k}_{n} \]

FFT

\[\begin {alignedat}{3} A(x) & = \sum_{i=1}^{n-1}a_ix^i \\ & = (a_0 + a_2x^2 \dots a_{n-2}x^{n-2}) + (a_1x + a_3x^3 \dots a_{n-1}x^{n-1}) \\ & = (a_0 + a_2x^2 \dots a_{n-2}x^{n-2}) + x(a_1 + a_3x^2 \dots a_{n-1}x^{n-2}) \\ & = A_1(x^2) + xA_2(x^2) \\ \end {alignedat}\]

\(k < \frac{n}{2}\)然后将 \(\omega^k_n\) 带入上式, 则有,

\[\begin {alignedat}{3} A_1(x^2) + xA_2(x^2) & = A_1((\omega^{k}_{n})^2) + xA_2((\omega^{k}_{n})^2)\\ & = A_1(\omega^{2k}_{n}) + \omega^{k}_{n}A_2(\omega^{2k}_{n}) \\ \end {alignedat}\]

然后我们再将 \(\omega^{k + \frac{n}{2}}_n\) 带入上式, 则有,

\[\begin {alignedat}{3} A_1(x^2) + xA_2(x^2) & = A_1((\omega^{k + \frac{n}{2}}_n)^2) + xA_2((\omega^{k + \frac{n}{2}}_n)^2)\\ & = A_1(\omega^{2k + n}_n) + \omega^{k + \frac{n}{2}}_nA_2(\omega^{2k + n}_n) \\ & = A_1(\omega^{2k}_n) - \omega^{k}_nA_2(\omega^{2k}_n) \\ \end {alignedat}\]

因为有以上的性质存在,我们只需要求出 \(A_1(\omega^{k}_{\frac{n}{2}})\)\(A_2(\omega^{k}_{\frac{n}{2}})\) 的值,就可以求出 \(A (\omega^{k}_n)\)\(A (\omega^{k + \frac{n}{2}}_n)\) 的值。

因此可以倍增的将系数表达式转化成点值表达式,所需要的时间复杂度为 \(O(n\log{n})\)

IFFT

有这样一个结论,证明略。

一个多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除以 \(n\) 即为原多项式的每一项系数。

因此我们只需要在 FFT 中 \(\omega\) 的虚部乘上一个复数, 然后最终结果除以 \(n\) 即可

迭代优化

倍增过程中迭代的效率太差了,可以发现,分治后的位置,是其下标二进制翻转以后的位置,我们可以先预处理好它最后到达的位置,在 FFT 之前交换即可。

Code

FFT实现代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long lld;

const int N = 1e6 + 50;
const double PI = acos (-1.0);

struct Complex {
	double x, y;
	Complex (register double xx = 0, register double yy = 0) {
		x = xx, y = yy;
	}
	inline Complex operator + (const Complex &a) const {
		return Complex (a.x + x, a.y + y);
	}
	inline Complex operator - (const Complex &a) const {
		return Complex (x - a.x, y - a.y);
	}
	inline Complex operator * (const Complex &a) const {
		return Complex (x * a.x - y * a.y, x * a.y + y * a.x);
	}
} a[N << 2], b[N << 2];

inline int read () {
	register int x = 0, w = 1;
	register char ch = getchar ();
	for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
	for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
	return x * w;
}

int n, m;

int limit = 1;
int l, r[N << 2];
inline void FFT (register Complex * A, register int type) {
	for (register int i = 0; i < limit; i ++)
		if (i < r[i])  swap (A[i], A[r[i]]);
	for (register int mid = 1; mid < limit; mid <<= 1) {
		register Complex Wn (cos (PI / mid), type * sin (PI / mid));
		for (register int R = mid << 1, j = 0; j < limit; j += R) {
			register Complex w (1, 0);
			for (register int k = 0; k < mid; k ++, w = w * Wn) {
				register Complex x = A[j + k], y = w * A[j + mid + k];
				A[j + k] = x + y;
				A[mid + j + k] = x - y;
			}
		}
	}
}

int main () {
	n = read(), m = read();
	for (register int i = 0; i <= n; i ++)  a[i].x = read();
	for (register int i = 0; i <= m; i ++)  b[i].x = read();
	while (limit <= n + m)  limit <<= 1, l ++;
	for (register int i = 0; i < limit; i ++)  r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	FFT (a, 1);
	FFT (b, 1);
	for (register int i = 0; i <= limit; i ++)  a[i] = a[i] * b[i];
	FFT (a, -1);
	for (register int i = 0; i <= n + m; i ++)  printf ("%d%c", (int)(a[i].x / limit + 0.5), i == limit - 1 ? '\n' : ' ');;
	return 0;	
}

NTT

因为 FFT 涉及到了 double, 精度堪忧, 所以考虑寻找一种不需要使用 double 数据类型可以解决的多项式转化点值表达式方法,那么需要找到一个合适的东西来代替单位根,下面的只涉及到定义,而跟具体内容无关。

\(a,p\)为 整数,且 \(\gcd{(a,p)} = 1\) ,
使 \(a^n \equiv 1 \pmod{p}\) 成立的最小正整数 \(n\) 叫做 \(a\)\(p\) 的阶,记作 \(\delta_p(a) = n\)

\(p\) 为正整数, \(a\) 为整数, 若 \(a\)\(p\) 的阶等于 \(\varphi(p)\) ,即 \(\delta_p(a) = \varphi(p)\), 则称 \(a\)\(m\) 的一个原根。

对于一个模数 \(p\) ,设他的原根为 \(g\) ,则有,

\[g^{\varphi(p)} \equiv 1 \pmod{p} \]

\(p\) 为质数, 则有 \(g^{p - 1} \equiv 1 \pmod{p}\), 根据单位根的性质, 有

\[(\omega^1_n)^n \equiv g^{p - 1} \equiv 1 \pmod{p} \]

\[\omega^1_n \equiv g^{\frac{p - 1}{n} } \pmod{p} \]

\[\omega^{-1}_n \equiv g^{-\frac{p - 1}{n} } \pmod{p} \]

因此,可以使用原根来代替单位根,从而在取模意义下避过了 double 的精度限制, 从而具有更好的精度表现

NTT实现代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long lld;

const int N = 1e6 + 50;
const int mod = 998244353;

int a[N << 2], b[N << 2];

inline int read () {
	register int x = 0, w = 1;
	register char ch = getchar ();
	for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
	for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
	return x * w;
}

int n, m;

int limit = 1;
int l, r[N << 2];

inline int qpow (register int a, register int b) {
	register int base = 1;
	while (b) {
		if (b & 1)  base = 1ll * base * a % mod;
		a = 1ll * a * a % mod;
		b >>= 1;
	}
	return base;
}

inline void FFT (register int * A, register int type) {
	for (register int i = 0; i < limit; i ++)
		if (i < r[i])  swap (A[i], A[r[i]]);
	for (register int mid = 1; mid < limit; mid <<= 1) {
		register int Gn = qpow (type, (mod - 1) / (mid << 1));
		for (register int R = mid << 1, j = 0; j < limit; j += R) {
			register int g = 1;
			for (register int k = 0; k < mid; k ++, g = 1ll * g * Gn % mod) {
				register int x = A[j + k], y = 1ll * g * A[j + mid + k] % mod;
				A[j + k] = (x + y) % mod;
				A[mid + j + k] = (x - y + mod) % mod;
			}
		}
	}
}

int main () {
	n = read(), m = read();
	for (register int i = 0; i <= n; i ++)  a[i] = read();
	for (register int i = 0; i <= m; i ++)  b[i] = read();
	while (limit <= n + m)  limit <<= 1, l ++;
	for (register int i = 0; i < limit; i ++)  r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	FFT (a, 3);
	FFT (b, 3);
	for (register int i = 0; i <= limit; i ++)  a[i] = 1ll * a[i] * b[i] % mod;
	FFT (a, qpow (3, mod - 2));
	for (register int i = 0; i <= n + m; i ++)  printf ("%lld%c", 1ll * a[i] * qpow (limit, mod - 2) % mod, i == limit - 1 ? '\n' : ' ');;
	return 0;	
}