书接上回 多项式(Ⅰ):基础工业。这部分主要写一下进阶的一些模板。
多项式求逆
多项式乘法逆
给定一个多项式 \(F(x)\),求出一个多项式 \(G(x)\),满足 \(F(x) * G(x) \equiv 1(\bmod \ x^n)\)。系数对 \(998244353\) 取模。
我们先讨论比较简单的模数是 \(998244353\) 这样类型的。
方法1:分治 FFT
目前还没学,咕掉先。时间复杂度是 \(O(n\log^2 n)\) 的。
方法2:倍增法
参考资料:多项式求逆-litble
如果多项式 \(F(x)\) 只有一项,那么显然 \(G_0\) 这个常数项就是 \(F_0\) 的逆元。
若有 \(n\) 项,考虑递归求解,像类似于归纳法一样往后推。
假设我们已经知道
因为 \(F(x) * G(x) \equiv 1(\bmod \ x^n)\),这就说明对于第 \(1 \sim n-1\) 次方项,两个多项式乘起来后系数为 \(0\),这就说明
两个式子相减
即
我们对上面这个式子进行平方。由于 \(G(x) - H(x)\) 在模 \(x^{\lceil \frac{n}{2}\rceil}\)
意义下为 \(0\),说明从 \(0\sim \lceil \frac{n}{2}\rceil-1\) 次项都为 \(0\)。设平方后的这个多项式为 \(P\),则
那么对于 \(i \leq n\) 的项,\(j\) 和 \(i-j\) 至少有一项的次数 \(< \lceil \frac{n}{2 }\rceil\),那么就是 \(0\),所以
乘上一个 \(F(x)\),根据 \(F(x)G(x) \equiv 1(\bmod x^n)\),再移一下项,就得到
转化到了这个地方便可以用 NTT 求解了。
复杂度用主定理分析一下:
此外还有一些常用的递归复杂度:
可以用主定理分析一下,感觉基本的这么多就可以了。
代码实现
有一些小细节需要注意:
-
代码中
(n+1)>>1
就是向上取整,至于为什么是向上取整,这样就能保证 \(\lceil \frac{n}{2} \rceil\) 的平方一定是能包含到 \(n\) 的。 -
为什么
lim < (n<<1)
, 因为这样相乘以后就能保证两个 \(n\) 项的多项式相乘后的多项式能存的开。 -
为什么要把 \(n \sim \text{lim}-1\) 设成 \(0\),因为模 \(x^n\) 意义下这些地方是同余与 \(0\) 的,为了避免干扰计算,就把系数也设成 \(0\) 了。后面的那个 \(b_i\) 设 \(0\) 也是同理的。
il void Polyinv(int n,ll a[],ll b[])
{
if(n == 1) { b[0] = inv(a[0]); return ; }
Polyinv((n+1)>>1,a,b);
lim = 1;
while(lim < (n<<1)) lim <<= 1;
for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0);
for(re int i=0;i<n;i++) c[i] = a[i];
for(re int i=n;i<lim;i++) c[i] = 0;
invg = inv(g) , invn = inv(lim);
NTT(c,lim,1) , NTT(b,lim,1);
for(re int i=0;i<lim;i++) b[i] = 1ll * ((2 * b[i] % mod) - (c[i] * b[i] % mod * b[i] % mod) + mod) % mod;
NTT(b,lim,-1);
for(re int i=n;i<lim;i++) b[i] = 0;
return ;
}
任意模数多项式乘法逆
把 NTT 换成 MTT ,然后再注意亿点细节即可。
感谢万能的 UOJ 群友解答。
三模 NTT 版本
三模 NTT 版本常数巨大无比,由于相乘会爆 long long
,所以在一次递归中,三模 NTT 需要调用 \(18\) 次 DFT。常数可想而知,又加上题解区三模 NTT 版本的少之又少,导致我调了半上午 + 一下午才有了一个具体的理解。下面进入正文。
首先,因为 CRT 的合并会导致数很大,所以我们分成两次计算。
首先,我们计算出三个模数下的 \(F(x) \times H(x)\),然后我们直接把它用 exCRT 合并起来,目的是直接把这个转化成 \(\bmod \ 10^9+7\) 的意义下(模板题给定的模数是 \(10^9+7\))进行,这样就不会出锅了。
之后,我们把它取反,变成 \(-F(x)H(x)\),然后再将其变成 \(2 - F(x)H(x)\),也就是在这个多项式的 \(x^0\) 项加上一个 \(2\)。这么简单的问题我竟然当时没绕过弯来,其实就是这个 \(2\) 其实是 \(x^0\) 次方的系数,你不能在所有的 \(x^k\) 的系数前都加上 \(2\),这样显然不对。
然后我们在让这个多项式与 \(H(x)\) 相乘,这样便得到了 \(2H(x) - F(x)H(x)^2\),也就是我们的 \(G(x)\)。
还是,注意取模
代码
只截取了主要部分,其中 inv
函数求逆元,是用快速幂实现的。
#include<bits/stdc++.h>
#define ll long long
#define next nxt
#define re register
#define il inline
const int N = 4e5 + 5;
const ll mod1 = 469762049;
const ll mod2 = 998244353;
const ll mod3 = 1004535809;
const ll Mod = 1e9 + 7;
const ll g = 3;
using namespace std;
int n,m,lim;
ll invg,invn,inv1,inv12;
ll a[N],b1[N],b2[N],b3[N],R[N],B[N],aa[N],bb[N];
namespace Poly
{
il void NTT(ll A[],int n,int type,ll mod)
{
//do something
}
il ll crt(ll a,ll b,ll c)//采用 exCRT 的合并方法就不会爆 long long了
{
ll k = ((b-a) % mod2 + mod2) % mod2 * inv1 % mod2;
ll x = k * mod1 + a;
k = ((c-x) % mod3 + mod3) % mod3 * inv12 % mod3;
x = (x + k * mod1 % Mod * mod2 % Mod) % Mod;
return x;
}
il void Mul(ll *A,ll *B,ll *C,int n,ll modx)
{
for(re int i=0;i<n;i++) aa[i] = A[i] , bb[i] = B[i];
for(re int i=n;i<lim;i++) aa[i] = bb[i] = 0;
NTT(aa,lim,1,modx) , NTT(bb,lim,1,modx);
for(re int i=0;i<lim;i++) C[i] = aa[i] * bb[i] % modx;
NTT(C,lim,-1,modx);
for(re int i=n;i<lim;i++) C[i] = 0;
return ;
}
il void Polyinv(int n)
{
if(n == 1) { B[0] = inv(a[0],Mod); return ; }
Polyinv((n+1)>>1);
lim = 1;
while(lim < (n<<1)) lim <<= 1;
for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0);
Mul(a,B,b1,n,mod1) , Mul(a,B,b2,n,mod2) , Mul(a,B,b3,n,mod3);//求 F(x)H(x)
for(re int i=0;i<n;i++) b1[i] = b2[i] = b3[i] = (Mod - crt(b1[i],b2[i],b3[i])) % Mod;//取反
b1[0] += 2 , b2[0] += 2 , b3[0] += 2;//x^0项加2
Mul(b1,B,b1,n,mod1) , Mul(b2,B,b2,n,mod2) , Mul(b3,B,b3,n,mod3);//求 2H(x) - F(x)H(x)^2
for(re int i=0;i<n;i++) B[i] = crt(b1[i],b2[i],b3[i]);//合并之后就是答案
return ;
}
}
using namespace Poly;
signed main()
{
n = read();
for(re int i=0;i<n;i++) a[i] = read() % Mod;
inv1 = inv(mod1 % mod2,mod2) , inv12 = inv(mod1 * mod2 % mod3,mod3);
Poly::Polyinv(n);
for(re int i=0;i<n;i++) cout << B[i] << " ";
return 0;
}
拆系数 FFT 版本
发现拆系数 FFT 还是比三模 NTT 好写+好理解,相比于正常的拆系数 FFT 没有什么大变动。
仍然只保留主要部分。我采用的是 \(4\) 次 FFT 版本,相比于三模 NTT 来说,它每轮只需要 \(8\) 次调用 DFT,所以最后的速度比 NTT 快了一倍左右,FFT 是 10s,NTT 接近 20s。
#include<bits/stdc++.h>
#define int long long
#define ll long long
#define complex comple
#define double long double
#define re register
#define il inline
const int N = 4e5 + 5;
const int mod = 1e9 + 7;
const double Pi = acos(-1.0);
const ll bas = (1<<15) , bas2 = bas * bas;
using namespace std;
int max(int x,int y){return x > y ? x : y;}
int min(int x,int y){return x < y ? x : y;}
int n,m,lim;
ll a[N],b[N],c[N],R[N];
ll x,AC,AD,BC,BD;
namespace Poly
{
struct complex
{
double x,y;
complex (double xx = 0 , double yy = 0) { x = xx; y = yy; }
}f[N],g[N],h[N];
//运算符重载
il ll ksm(ll a,ll b)
{
//do something
} ll inv(ll x) { return ksm(x,mod-2); }
il void FFT(complex A[],int n,int type)
{
// do something
}
il void Mul(ll a[],ll b[],ll c[],int n)
{
for(re int i=0;i<lim;i++) f[i] = g[i] = h[i] = {0,0};
for(re int i=0;i<n;i++)
{
a[i] %= mod , b[i] %= mod;
f[i].x = (a[i]>>15) , f[i].y = a[i] & 32767;
h[i].x = (b[i]>>15) , h[i].y = b[i] & 32767;
}//优化 4 次 FFT
FFT(f,lim,1) , FFT(h,lim,1);
g[0] = {f[0].x,-f[0].y};
for(re int i=1;i<lim;i++) g[i] = {f[lim-i].x,-f[lim-i].y};
for(re int i=0;i<lim;i++) h[i].x /= lim , h[i].y /= lim , f[i] = f[i] * h[i] , g[i] = g[i] * h[i];
FFT(f,lim,-1) , FFT(g,lim,-1);
for(re int i=0;i<n;i++)
{
AC = ((ll)((f[i].x+g[i].x)/2+0.5)) % mod;
AD = ((ll)((f[i].y+g[i].y)/2+0.5)) % mod;
BD = ((ll)(g[i].x-AC+0.5) % mod + mod) % mod;
BC = ((ll)(f[i].y-AD+0.5) % mod + mod) % mod;
AC = bas2 % mod * AC % mod , AD = bas % mod * (AD+BC) % mod;
c[i] = (AC + AD + BD) % mod;
}
}
il void Polyinv(int n)
{
if(n == 1) { b[0] = inv(a[0]); return ; }
Polyinv((n+1)>>1);
lim = 1;
while(lim < (n<<1)) lim <<= 1;
for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0) , c[i] = 0;
Mul(a,b,c,n);//算第一次
for(re int i=0;i<n;i++) c[i] = (mod - c[i]) % mod;
c[0] = (c[0] + 2) % mod;//取反求第二次
Mul(b,c,b,n);
for(re int i=n;i<lim;i++) b[i] = 0;
return ;
}
}
using namespace Poly;
signed main()
{
n = read();
for(re int i=0;i<n;i++) a[i] = read();
Poly::Polyinv(n);
for(re int i=0;i<n;i++) cout << b[i] << " ";
return 0;
}