多项式全家桶

发布时间：2023-07-06 11:41:14　作者：Think927

FFT

#include<bits/stdc++.h>
#define int long long
#define N 4000005
#define pb push_back
#define fi first
#define se second
#define pii pair<int,int>
#define db double
#define PI 3.14159265358
using namespace std;
// NOTE(review): s0/s1 are not referenced by the FFT code shown here; they only
// reserve 2*N slots of static storage. Kept in case other code relies on them.
int s0[N],s1[N];
// Minimal complex number: real part ai, imaginary part bi.
struct node{
	db ai,bi;
	node(db a_=0,db b_=0){ai=a_,bi=b_;}
	// Results are built with the constructor. The former `(node){...}` form is
	// a C99 compound literal, which ISO C++ does not allow (GNU extension only).
	friend node operator + (const node& x,const node& y){
		return node(x.ai+y.ai,x.bi+y.bi); }
	friend node operator - (const node& x,const node& y){
		return node(x.ai-y.ai,x.bi-y.bi); }
	// (a+bi)(c+di) = (ac-bd) + (bc+ad)i
	friend node operator * (const node& x,const node& y){
		return node(x.ai*y.ai-x.bi*y.bi,x.bi*y.ai+x.ai*y.bi); }
};
// L: log2 of the transform length; lim: transform length (a power of two);
// r[i]: bit-reversal of i over L bits (filled by operator* before each FFT).
int L,lim,r[N];
// Reverse the lowest L bits of x; bits at position L and above are ignored.
int rev(int x){
	int res=0;
	for(int bit=0;bit<L;bit++,x>>=1) res=(res<<1)|(x&1);
	return res;
}
struct poly{
	node a[N]; int ln;	
	poly(){memset(a,0,sizeof(a));}
	void FFT(int fl){
		for(int i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
		for(int md=1;md<lim;md<<=1){
			node wn=node(cos(PI/md),fl*sin(PI/md));
			for(int j=0;j<lim;j+=(md<<1)){
				node w=node(1,0);
				for(int k=0;k<md;k++,w=w*wn){
					node x=a[j+k],y=w*a[j+k+md];
					a[j+k]=x+y,a[j+k+md]=x-y;
				}}}
		if(fl==-1) for(int i=0;i<lim;i++) a[i].ai/=(1.0*lim),a[i].bi/=(1.0*lim);
	}void print(){
		for(int i=0;i<=ln;i++) printf("%lld ",(int)(a[i].ai+0.5));}
	void rd(int l_){ln=l_; for(int i=0;i<=l_;i++) scanf("%lf",&a[i].ai),a[i].bi=0;}
};
// Multiply two polynomials with FFT; the result has degree x.ln+y.ln.
// Each poly stores N complex coefficients (about 64 MB). Taking them by value
// placed two such copies on the stack per call, so the operands are now taken
// by const reference and the transforms run on static work buffers.
// NOTE: because of the static buffers this function is not reentrant.
poly operator * (const poly& x,const poly& y){
	static poly X,Y;
	X=x,Y=y;
	X.ln+=Y.ln,L=0,lim=1;
	// Smallest power of two strictly greater than the product degree.
	while(lim<=X.ln) lim<<=1,L++;
	for(int i=0;i<lim;i++) r[i]=rev(i);
	X.FFT(1),Y.FFT(1);
	for(int i=0;i<lim;i++) X.a[i]=X.a[i]*Y.a[i];
	X.FFT(-1);
	return X;
}
poly a,b; int n,m;
// Read degrees n and m, then the coefficients of both polynomials; print the product.
signed main(){
	scanf("%lld%lld",&n,&m);
	a.rd(n);
	b.rd(m);
	a=a*b;
	a.print();
	return 0;
}

目前尚未实现的部分：多项式开根、多点求值、插值、多项式三角函数以及 Bluestein 算法。现有代码常数过大，需要优化；重点是重构多项式求逆部分的代码，这一部分之前写得非常糟糕。
NTT

#include<bits/stdc++.h>
#define I 1ll
#define ll long long
#define N 530005
//depending on situations 5.5 * n (given in statements)
#define pb push_back
#define fi first
#define se second
#define pii pair<int,int>
#define M 998244353
#define M0 998244352
#define ppo pair<poly,poly>
#define MP make_pair
using namespace std;
// Binary exponentiation: returns x^y mod M (expects y >= 0; x is not pre-reduced).
int Fast(int x,int y){
	int res=1;
	for(;y;y>>=1){
		if(y&1) res=I*res*x%M;
		x=I*x*x%M;
	}
	return res;
}
// Modular inverse by Fermat's little theorem (M = 998244353 is prime): x^(M-2).
int inv(int x){return Fast(x,M-2);}
// NTT root G used for forward transforms, and its inverse Gi for inverse ones.
const int G=3;
const int Gi=inv(G);
// Transform state set by operator*: bit-reversal table R[], length Lim = 2^L.
int R[N],Lim,L;
// Tables filled by init(): fac[i] = i!, ifac[i] = 1/i!, iv[i] = 1/i (mod M).
int fac[N],ifac[N],iv[N];
// Precompute fac[i] = i! and iv[i] = 1/i (mod M) for 1 <= i <= N-5, and
// ifac[i] = 1/i! for 0 <= i <= N-5. itgr() reads iv[], so call this first.
void init(){
	fac[0]=1;
	iv[1]=1;
	for(int i=1;i<=N-5;i++) fac[i]=I*fac[i-1]*i%M;
	// Linear-time inverses: 1/i = -(M/i) * 1/(M mod i)  (mod M), with M mod i < i.
	for(int i=2;i<=N-5;i++) iv[i]=I*(M-M/i)*iv[M%i]%M;
	ifac[N-5]=Fast(fac[N-5],M-2);
	for(int i=N-5;i>0;i--) ifac[i-1]=I*ifac[i]*i%M;
}
// Reverse the lowest L bits of x. (operator* fills R[] with an incremental
// recurrence instead of calling this.)
int rev(int x){
	int res=0;
	for(int bit=0;bit<L;bit++,x>>=1) res=(res<<1)|(x&1);
	return res;
}
// (x + y) mod M for x, y in [0, M).
int add(int x,int y){int s=x+y; return s>=M?s-M:s;}
// (x - y) mod M for x, y in [0, M).
int dec(int x,int y){return x<y?x-y+M:x-y;}
struct poly{
	int ln=0,p[N]={0};
	void NTT(int fl){
		for(int i=0;i<Lim;i++) if(i<R[i]) swap(p[i],p[R[i]]);
		for(int md=1;md<Lim;md<<=1){
			int wn=Fast((fl==1)?G:Gi,(M-1)/(md<<1));
			for(int i=0,w=1;i<Lim;i+=(md<<1),w=1)
				for(int j=0;j<md;j++,w=I*w*wn%M){
					int x=p[i+j],y=I*w*p[i+j+md]%M;
					p[i+j]=add(x,y),p[i+j+md]=dec(x,y);	
				}}}
}NUL; 
// c = a + b coefficient-wise (mod M); c.ln = max(a.ln, b.ln).
poly operator + (poly a,poly b){
	poly c; c.ln=max(a.ln,b.ln);
	// Clear entries above each operand's degree before combining.
	for(int i=a.ln+1;i<=c.ln;i++) a.p[i]=0;
	for(int i=b.ln+1;i<=c.ln;i++) b.p[i]=0;
	for(int i=0;i<=c.ln;i++) c.p[i]=add(a.p[i],b.p[i]);
	return c;
}
// c = a - b coefficient-wise (mod M); c.ln = max(a.ln, b.ln).
poly operator - (poly a,poly b){
	poly c; c.ln=max(a.ln,b.ln);
	for(int i=a.ln+1;i<=c.ln;i++) a.p[i]=0;
	for(int i=b.ln+1;i<=c.ln;i++) b.p[i]=0;
	for(int i=0;i<=c.ln;i++) c.p[i]=dec(a.p[i],b.p[i]);
	return c;
}
// c = x * y via NTT; c.ln = x.ln + y.ln. Overwrites the globals Lim, L, R[].
poly operator * (poly x,poly y){
	poly c; Lim=1,L=0,c.ln=x.ln+y.ln;
	while(Lim<=c.ln) Lim<<=1,L++;
	assert(Lim<=N); // p[] and R[] hold only N entries
	// Bit-reversal table. When L==0 (Lim==1) the table is just R[0]=0; running
	// the recurrence at i==0 would shift by L-1 == -1, which is undefined behavior.
	R[0]=0;
	for(int i=1;i<Lim;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	for(int i=x.ln+1;i<Lim;i++) x.p[i]=0;
	for(int i=y.ln+1;i<Lim;i++) y.p[i]=0;
	x.NTT(1),y.NTT(1);
	for(int i=0;i<Lim;i++) c.p[i]=I*x.p[i]*y.p[i]%M;
	c.NTT(-1);
	int iv=Fast(Lim,M-2); // 1/Lim (shadows the global iv[] table here)
	for(int i=0;i<Lim;i++) c.p[i]=I*c.p[i]*iv%M;
	return c;
}
// Power-series inverse by Newton iteration b <- b*(2 - a*b). Requires a.p[0] != 0.
// Reads a.p[0..l0+1] and returns b with b.ln = l0+1 and a*b == 1 mod x^(l0+2).
// Coefficient j of the inverse depends only on a.p[0..j], so b.p[0..l0] are
// exact for a truncated to degree l0. (Original note: callers should make
// a.p[l0+1] = 0.)
poly inv(poly a,int l0){
	a.ln=l0+1; poly b; int n=a.ln,o=1;
	for(int i=n+1;i<=n+n;i++) a.p[i]=0;
	b.p[0]=Fast(a.p[0],M-2),b.ln=1;
	while(o<=n){
		// Invariant: b is correct mod x^o and has b.ln = o-1 (b.ln = 1 initially).
		poly a0; a0.ln=o+o;
		for(int i=0;i<=o+o;i++) a0.p[i]=a.p[i];
		poly c=a0*b;
		for(int i=0;i<=c.ln;i++) c.p[i]=dec((i==0?2:0),c.p[i]); // c = 2 - a0*b
		o<<=1,b=c*b;
		for(int i=o;i<=b.ln;i++) b.p[i]=0;
		b.ln=o-1;
	}
	for(int i=n+1;i<=b.ln;i++) b.p[i]=0;
	b.ln=n; return b;
}
// Polynomial division: returns (Q, R) with a = b*Q + R and deg R < deg b.
// Requires the leading coefficient b.p[b.ln] to be nonzero.
ppo modl(poly a,poly b){
	// deg a < deg b: the quotient is zero and the remainder is a itself.
	// (Previously this returned (a, 0), i.e. quotient and remainder swapped.)
	if(a.ln<b.ln) return MP(NUL,a);
	// Reversal trick: rev(Q) = rev(a) / rev(b) mod x^(deg a - deg b + 1).
	reverse(a.p,a.p+a.ln+1);
	reverse(b.p,b.p+b.ln+1);
	poly Q=a*inv(b,a.ln),R;
	Q.ln=a.ln-b.ln; reverse(Q.p,Q.p+Q.ln+1);
	reverse(a.p,a.p+a.ln+1);
	reverse(b.p,b.p+b.ln+1);
	// deg R <= deg b - 1, so keep coefficients 0..b.ln-1 (just the zero constant
	// term when b is a constant). The former b.ln-2 dropped R's top coefficient.
	R=a-b*Q,R.ln=max(b.ln-1,0);
	return MP(Q,R);
}
// Formal derivative; ln decreases by one (to -1 for a constant).
poly dirv(poly x){
	for(int i=0;i<x.ln;i++) x.p[i]=I*x.p[i+1]*(i+1)%M;
	x.p[x.ln--]=0;
	return x;
}
// Formal integral with zero constant term; ln increases by one.
// Reads iv[], so init() must have been called.
poly itgr(poly x){
	for(int i=x.ln;~i;i--) x.p[i+1]=I*x.p[i]*iv[i+1]%M;
	x.p[0]=0,x.ln++;
	return x;
}
// ln(x) = integral of x'/x for the polynomial x.p[0..x.ln]; requires
// x.p[0] == 1. The result has ln = l0 (callers here pass l0 == x.ln). Entries
// of p[] above l0 are left as scratch values, not zeros.
poly Ln(poly x,int l0){assert(x.p[0]==1);
	x=itgr(dirv(x)*inv(x,x.ln)),x.ln=l0;
	return x;
}
// exp(x) truncated to degree x.ln, by Newton iteration F0 <- F0*(1 + x - ln F0)
// (assumes x.p[0] == 0). The loop bound is n = 2*x.ln, so the last step length
// (the largest power of two <= n) exceeds x.ln. It reads x.p past the original
// degree; those entries only affect coefficients above x.ln, which are dropped.
poly Exp(poly x){ x.ln<<=1;
	poly x0,F0,F; F0.p[F0.ln=0]=1; int n=x.ln;
	for(int o=1;o<=n;o<<=1){
		x0.ln=o-1; for(int t=0;t<=x0.ln;t++) x0.p[t]=x.p[t];
		F=x0-Ln(F0,o-1),F.p[0]++,F.ln=o,F0=F*F0,F0.ln=2*o-1;
	}
	F0.ln=(n>>1); for(int i=F0.ln+1;i<=n;i++) F0.p[i]=0;
	return F0;
}
// x^E truncated to degree x.ln, where the exponent E is given as:
//   K  = E mod M (equal to E itself when fl == 0),
//   K0 = E mod (M-1) (used for the power of the leading coefficient),
//   fl = nonzero when E >= M.
// Writes x = c * t^k * u with u.p[0] == 1; then x^E = c^E * t^(kE) * exp(E*ln u).
poly Fast(poly x,int K,int K0,int fl){ int k=0,iv;
	while(k<=x.ln&&!x.p[k]) ++k;
	// Zero polynomial: returned unchanged.
	// NOTE(review): this includes E == 0, where 1 might be expected.
	if(k>x.ln) return x;
	poly y=x;
	// k >= 1 and E >= M > x.ln: the factor t^(kE) lies entirely above x.ln.
	if(k>0&&fl){
		for(int i=0;i<=x.ln;i++) x.p[i]=0;
		return x;
	}
	iv=inv(x.p[k]);
	// Shift out t^k and scale so the constant term becomes 1.
	for(int i=k;i<=x.ln;i++) x.p[i-k]=I*y.p[i]*iv%M;
	// Clear the k stale entries the shift leaves above the new degree
	// (inv() reads one coefficient past the degree).
	for(int i=x.ln-k+1;i<=x.ln;i++) x.p[i]=0;
	x.ln-=k;
	poly t=Ln(x,x.ln);
	for(int i=0;i<=x.ln;i++) t.p[i]=I*t.p[i]*K%M;
	t=Exp(t),iv=Fast(inv(iv),K0); // iv now holds c^E
	poly tp; tp.ln=y.ln;          // fresh poly: all coefficients already zero
	for(int i=0;i<=t.ln;i++) if(i+I*k*K<=y.ln) tp.p[i+I*k*K]=I*t.p[i]*iv%M;
	return tp;
}
poly a,b; int n,fl; char s[N];
signed main(){
	//remember to use function init.
	scanf("%d%s",&n,s+1),a.ln=n-1,init();
	ll s0=0,s1=0; int l0=strlen(s+1); for(int i=1;i<=l0;i++){
		s0=I*s0*10+I*(s[i]-'0'); if(s0>=M) fl=1; s0%=M; 
		s1=(I*s1*10+I*(s[i]-'0'))%M0;
	}for(int i=0;i<n;i++) scanf("%d",&a.p[i]);
	b=Fast(a,(int)s0,(int)s1,fl); for(int i=0;i<=n-1;i++) printf("%d ",b.p[i]);
	return 0;
}/*
3 2
2 2 3*/