多项式(Ⅱ):进阶工业

发布时间 2023-06-26 20:51:02作者: Bloodstalk

书接上回 多项式(Ⅰ):基础工业。这部分主要写一下进阶的一些模板。

多项式求逆

多项式乘法逆

给定一个多项式 \(F(x)\),求出一个多项式 \(G(x)\),满足 \(F(x) * G(x) \equiv 1(\bmod \ x^n)\)。系数对 \(998244353\) 取模。

我们先讨论比较简单的模数是 \(998244353\) 这样类型的。

方法1:分治 FFT

目前还没学,咕掉先。时间复杂度是 \(O(n\log^2 n)\) 的。

方法2:倍增法

参考资料:多项式求逆-litble

如果多项式 \(F(x)\) 只有一项,那么显然 \(G_0\) 这个常数项就是 \(F_0\) 的逆元。

若有 \(n\) 项,考虑递归求解,像类似于归纳法一样往后推。

假设我们已经知道

\[F(x)H(x) \equiv 1(\bmod \ x^{\lceil \frac{n}{2}\rceil}) \]

因为 \(F(x) * G(x) \equiv 1(\bmod \ x^n)\),这就说明对于第 \(1 \sim n-1\) 次方项,两个多项式乘起来后系数为 \(0\),这就说明

\[F(x)G(x) \equiv 1(\bmod \ x^{\lceil \frac{n}{2}\rceil}) \]

两个式子相减

\[F(x)(G(x) - H(x)) \equiv 0(\bmod \ x^{\lceil \frac{n}{2}\rceil}) \]

\[G(x) - H(x) \equiv 0(\bmod \ x^{\lceil \frac{n}{2}\rceil}) \]

我们对上面这个式子进行平方。由于 \(G(x) - H(x)\) 在模 \(x^{\lceil \frac{n}{2}\rceil}\)
意义下为 \(0\),说明从 \(0\sim \lceil \frac{n}{2}\rceil-1\) 次项都为 \(0\)。设平方后的这个多项式为 \(P\),则

\[P_i = \sum_{j=0}^i(G(x)-H(x))_j (G(x)-H(x))_{i-j} \]

那么对于 \(i \leq n\) 的项,\(j\)\(i-j\) 至少有一项的次数 \(< \lceil \frac{n}{2 }\rceil\),那么就是 \(0\),所以

\[G(x)^2 - 2G(x)H(x) + H(x)^2 \equiv 0(\bmod \ x^n) \]

乘上一个 \(F(x)\),根据 \(F(x)G(x) \equiv 1(\bmod x^n)\),再移一下项,就得到

\[G(x) \equiv 2H(x) - F(x)H(x)^2(\bmod \ x^n) \]

转化到了这个地方便可以用 NTT 求解了。

复杂度用主定理分析一下:

\[T(n) = T(n/2) + O(n\log n) = O(n \log n) \]

此外还有一些常用的递归复杂度:

\[T(n) = T(n/2) + O(n) = O(n) \]

\[T(n) = 2T(n/2) + O(n) = O(n\log n) \]

\[T(n) = 2T(n/2) + O(n\log n) = O(n\log^2 n) \]

可以用主定理分析一下,感觉基本的这么多就可以了。

代码实现

有一些小细节需要注意:

  • 代码中 (n+1)>>1 就是向上取整,至于为什么是向上取整,这样就能保证 \(\lceil \frac{n}{2} \rceil\) 的平方一定是能包含到 \(n\) 的。

  • 为什么 lim < (n<<1) , 因为这样相乘以后就能保证两个 \(n\) 项的多项式相乘后的多项式能存的开。

  • 为什么要把 \(n \sim \text{lim}-1\) 设成 \(0\),因为模 \(x^n\) 意义下这些地方是同余与 \(0\) 的,为了避免干扰计算,就把系数也设成 \(0\) 了。后面的那个 \(b_i\)\(0\) 也是同理的。

il void Polyinv(int n,ll a[],ll b[])
	{
		if(n == 1) { b[0] = inv(a[0]); return ; }
		Polyinv((n+1)>>1,a,b);
		lim = 1;
		while(lim < (n<<1)) lim <<= 1;
		for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0);
		for(re int i=0;i<n;i++) c[i] = a[i];
		for(re int i=n;i<lim;i++) c[i] = 0;
		invg = inv(g) , invn = inv(lim);
		NTT(c,lim,1) , NTT(b,lim,1);
		for(re int i=0;i<lim;i++) b[i] = 1ll * ((2 * b[i] % mod) - (c[i] * b[i] % mod * b[i] % mod) + mod) % mod;
		NTT(b,lim,-1);
		for(re int i=n;i<lim;i++) b[i] = 0;
		return ;
	}

任意模数多项式乘法逆

把 NTT 换成 MTT ,然后再注意亿点细节即可。

感谢万能的 UOJ 群友解答。

三模 NTT 版本

三模 NTT 版本常数巨大无比,由于相乘会爆 long long,所以在一次递归中,三模 NTT 需要调用 \(18\) 次 DFT。常数可想而知,又加上题解区三模 NTT 版本的少之又少,导致我调了半上午 + 一下午才有了一个具体的理解。下面进入正文。

首先,因为 CRT 的合并会导致数很大,所以我们分成两次计算。

首先,我们计算出三个模数下的 \(F(x) \times H(x)\),然后我们直接把它用 exCRT 合并起来,目的是直接把这个转化成 \(\bmod \ 10^9+7\) 的意义下(模板题给定的模数是 \(10^9+7\))进行,这样就不会出锅了。

之后,我们把它取反,变成 \(-F(x)H(x)\),然后再将其变成 \(2 - F(x)H(x)\),也就是在这个多项式的 \(x^0\) 项加上一个 \(2\)。这么简单的问题我竟然当时没绕过弯来,其实就是这个 \(2\) 其实是 \(x^0\) 次方的系数,你不能在所有的 \(x^k\) 的系数前都加上 \(2\),这样显然不对。

然后我们在让这个多项式与 \(H(x)\) 相乘,这样便得到了 \(2H(x) - F(x)H(x)^2\),也就是我们的 \(G(x)\)

还是,注意取模

代码

只截取了主要部分,其中 inv 函数求逆元,是用快速幂实现的。

#include<bits/stdc++.h>
#define ll long long
#define next nxt
#define re register
#define il inline
const int N = 4e5 + 5;
const ll mod1 = 469762049;
const ll mod2 = 998244353;
const ll mod3 = 1004535809;
const ll Mod = 1e9 + 7;
const ll g = 3;
using namespace std;

int n,m,lim;
ll invg,invn,inv1,inv12;
ll a[N],b1[N],b2[N],b3[N],R[N],B[N],aa[N],bb[N];

namespace Poly
{
	il void NTT(ll A[],int n,int type,ll mod)
	{
		//do something
	}
	il ll crt(ll a,ll b,ll c)//采用 exCRT 的合并方法就不会爆 long long了
	{
		ll k = ((b-a) % mod2 + mod2) % mod2 * inv1 % mod2;
		ll x = k * mod1 + a;
		k = ((c-x) % mod3 + mod3) % mod3 * inv12 % mod3;
		x = (x + k * mod1 % Mod * mod2 % Mod) % Mod;
		return x;
	}
	il void Mul(ll *A,ll *B,ll *C,int n,ll modx)
	{
		for(re int i=0;i<n;i++) aa[i] = A[i] , bb[i] = B[i];
		for(re int i=n;i<lim;i++) aa[i] = bb[i] = 0;
		NTT(aa,lim,1,modx) , NTT(bb,lim,1,modx);
		for(re int i=0;i<lim;i++) C[i] = aa[i] * bb[i] % modx;
		NTT(C,lim,-1,modx);
		for(re int i=n;i<lim;i++) C[i] = 0;
		return ;
	}
	il void Polyinv(int n)
	{
		if(n == 1) { B[0] = inv(a[0],Mod); return ; }
		Polyinv((n+1)>>1);
		lim = 1;
		while(lim < (n<<1)) lim <<= 1;
		for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0);
		Mul(a,B,b1,n,mod1) , Mul(a,B,b2,n,mod2) , Mul(a,B,b3,n,mod3);//求 F(x)H(x)	
		for(re int i=0;i<n;i++) b1[i] = b2[i] = b3[i] = (Mod - crt(b1[i],b2[i],b3[i])) % Mod;//取反
		b1[0] += 2 , b2[0] += 2 , b3[0] += 2;//x^0项加2
		Mul(b1,B,b1,n,mod1) , Mul(b2,B,b2,n,mod2) , Mul(b3,B,b3,n,mod3);//求 2H(x) - F(x)H(x)^2
		for(re int i=0;i<n;i++) B[i] = crt(b1[i],b2[i],b3[i]);//合并之后就是答案
		return ;
	}
}
using namespace Poly;

signed main()
{
	n = read();
	for(re int i=0;i<n;i++) a[i] = read() % Mod;
	inv1 = inv(mod1 % mod2,mod2) , inv12 = inv(mod1 * mod2 % mod3,mod3);
	Poly::Polyinv(n);
	for(re int i=0;i<n;i++) cout << B[i] << " ";
	return 0;
}

拆系数 FFT 版本

发现拆系数 FFT 还是比三模 NTT 好写+好理解,相比于正常的拆系数 FFT 没有什么大变动。

仍然只保留主要部分。我采用的是 \(4\) 次 FFT 版本,相比于三模 NTT 来说,它每轮只需要 \(8\) 次调用 DFT,所以最后的速度比 NTT 快了一倍左右,FFT 是 10s,NTT 接近 20s。

#include<bits/stdc++.h>
#define int long long
#define ll long long
#define complex comple
#define double long double
#define re register
#define il inline
const int N = 4e5 + 5;
const int mod = 1e9 + 7;
const double Pi = acos(-1.0);
const ll bas = (1<<15) , bas2 = bas * bas;
using namespace std;
int max(int x,int y){return x > y ? x : y;}
int min(int x,int y){return x < y ? x : y;}

int n,m,lim;
ll a[N],b[N],c[N],R[N];
ll x,AC,AD,BC,BD;

namespace Poly
{
	struct complex
	{
		double x,y;
		complex (double xx = 0 , double yy = 0) { x = xx; y = yy; }
	}f[N],g[N],h[N];
    //运算符重载
    
	il ll ksm(ll a,ll b)
	{
		//do something
	} ll inv(ll x) { return ksm(x,mod-2); }
	il void FFT(complex A[],int n,int type)
	{
		// do something
	}
	il void Mul(ll a[],ll b[],ll c[],int n)
	{
		for(re int i=0;i<lim;i++) f[i] = g[i] = h[i] = {0,0};
		for(re int i=0;i<n;i++)
		{
			a[i] %= mod , b[i] %= mod;
			f[i].x = (a[i]>>15) , f[i].y = a[i] & 32767;
			h[i].x = (b[i]>>15) , h[i].y = b[i] & 32767;
		}//优化 4 次 FFT
		FFT(f,lim,1) , FFT(h,lim,1);
		g[0] = {f[0].x,-f[0].y};
		for(re int i=1;i<lim;i++) g[i] = {f[lim-i].x,-f[lim-i].y};
		for(re int i=0;i<lim;i++) h[i].x /= lim , h[i].y /= lim , f[i] = f[i] * h[i] , g[i] = g[i] * h[i];
		FFT(f,lim,-1) , FFT(g,lim,-1);
		for(re int i=0;i<n;i++)
		{
			AC = ((ll)((f[i].x+g[i].x)/2+0.5)) % mod;
			AD = ((ll)((f[i].y+g[i].y)/2+0.5)) % mod;
			BD = ((ll)(g[i].x-AC+0.5) % mod + mod) % mod;
			BC = ((ll)(f[i].y-AD+0.5) % mod + mod) % mod;
			AC = bas2 % mod * AC % mod , AD = bas % mod * (AD+BC) % mod;
			c[i] = (AC + AD + BD) % mod;
		}
	}
	il void Polyinv(int n)
	{
		if(n == 1) { b[0] = inv(a[0]); return ; }
		Polyinv((n+1)>>1);
		lim = 1;
		while(lim < (n<<1)) lim <<= 1;
		for(re int i=0;i<lim;i++) R[i] = (R[i>>1]>>1) + ((i&1) ? lim/2 : 0) , c[i] = 0;
		Mul(a,b,c,n);//算第一次
		for(re int i=0;i<n;i++) c[i] = (mod - c[i]) % mod;
		c[0] = (c[0] + 2) % mod;//取反求第二次
		Mul(b,c,b,n);
		for(re int i=n;i<lim;i++) b[i] = 0;
		return ;
	}
}
using namespace Poly;

signed main()
{
	n = read();
	for(re int i=0;i<n;i++) a[i] = read();
	Poly::Polyinv(n);
	for(re int i=0;i<n;i++) cout << b[i] << " ";
	return 0;
}