多项式模板--zhengjun

发布时间 2023-11-30 07:53:49作者: A_zjzj

vector 实现。


using LL=__int128;
const int mod=998244353;
ll qpow(ll x,ll y=mod-2,ll ans=1){
	for(;y;(x*=x)%=mod,y>>=1)if(y&1)(ans*=x)%=mod;
	return ans;
}
namespace Poly{
	const int N=1<<21,g[2]={3,(mod+1)/3};
	int lim,rev[N],A[N],B[N],pw[2][N];
	namespace Public{
		using poly=vector<int>;
		void init(int n){
			for(lim=1;lim<n;lim<<=1);
			for(int i=1;i<lim;i++)rev[i]=rev[i>>1]>>1|(i&1?lim>>1:0);
			for(int t=0;t<2;t++){
				for(int len=1;len<lim;len<<=1){
					pw[t][len]=1;
					int w=qpow(g[t],(mod-1)/len/2);
					for(int i=1;i<len;i++){
						pw[t][len|i]=1ll*pw[t][len|i-1]*w%mod;
					}
				}
			}
		}
		void NTT(int *a,int op){
			for(int i=0;i<lim;i++)if(rev[i]<i)swap(a[rev[i]],a[i]);
			for(int len=1,x,y,z;len<lim;len<<=1){
				for(int i=0;i<lim;i+=len<<1){
					for(int j=0;j<len;j++){
						x=a[i|j],y=1ll*a[i|j|len]*pw[op<0][len|j]%mod;
						a[i|j]=(z=x+y)<mod?z:z-mod,a[i|j|len]=(z=x-y)<0?z+mod:z;
					}
				}
			}
			if(op<0){
				int x=qpow(lim);
				for(int i=0;i<lim;i++)a[i]=1ll*a[i]*x%mod;
			}
		}
		poly operator * (const poly &a,const poly &b){
			int n=a.size(),m=b.size(),k=n+m-1;
			init(k);
			copy(a.begin(),a.end(),A),fill(A+n,A+lim,0);
			copy(b.begin(),b.end(),B),fill(B+m,B+lim,0);
			poly c(k);
			NTT(A,1),NTT(B,1);
			for(int i=0;i<lim;i++)A[i]=1ll*A[i]*B[i]%mod;
			NTT(A,-1);
			for(int i=0;i<k;i++)c[i]=A[i];
			return c;
		}
		poly& operator *= (poly &a,const poly &b){
			return a=a*b;
		}
		poly& operator += (poly &a,const poly &b){
			int n=a.size(),m=b.size();
			if(n<m)a.resize(m);
			for(int i=0;i<m;i++)(a[i]+=b[i])%=mod;
			return a;
		}
		poly operator + (const poly &a,const poly &b){
			poly c(a);
			return c+=b;
		}
		poly& operator -= (poly &a,const poly &b){
			int n=a.size(),m=b.size();
			if(n<m)a.resize(m);
			for(int i=0;i<m;i++)(a[i]+=mod-b[i])%=mod;
			return a;
		}
		poly operator - (const poly &a,const poly &b){
			poly c(a);
			return c-=b;
		}
		poly operator - (const poly &a){
			return poly()-a;
		}
		poly& operator *= (poly &a,const int &b){
			for(int &x:a)x=1ll*x*b%mod;
			return a;
		}
		poly operator * (const poly &a,const int &b){
			poly c(a);
			return c*=b;
		}
		poly& operator /= (poly &a,const int &b){
			return a*=qpow(b);
		}
		poly operator / (const poly &a,const int &b){
			poly c(a);
			return c/=b;
		}
		poly& operator %= (poly &a,const int &b){
			if(a.size()>b)a.resize(b);
			return a;
		}
		poly operator % (const poly &a,const int &b){
			poly c(a);
			return c%=b;
		}
		poly inv(const poly &a,int k=-1){
			if(!~k)k=a.size();
			poly b{(int)qpow(a[0])};
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				(b*=poly({2})-a%x*b%x)%=x;
			}
			return b;
		}
		poly qiudao(const poly &a){
			int n=a.size();
			if(!n)return poly();
			poly b(n-1);
			for(int i=1;i<n;i++)b[i-1]=1ll*a[i]*i%mod;
			return b;
		}
		poly jifen(const poly &a){
			int n=a.size();
			if(!n)return poly(1);
			poly b(n+1);
			b[1]=1;
			for(int i=2;i<=n;i++)b[i]=1ll*b[mod%i]*(mod-mod/i)%mod;
			for(int i=1;i<=n;i++)b[i]=1ll*b[i]*a[i-1]%mod;
			return b;
		}
		poly ln(const poly &a,int k=-1){
			if(!~k)k=a.size();
			return jifen(qiudao(a)*inv(a,k)%k)%k;
		}
		poly exp(const poly &a,int k=-1){
			if(!~k)k=a.size();
			poly b(1);
			b[0]=1;
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				(b*=poly({1})-ln(b,x)+a%x)%=x;
			}
			return b;
		}
		poly sqrt(const poly &a,int k=-1){
			if(a.empty())return poly();
			if(!~k)k=a.size();
			poly b(1);
			b[0]=(int)::sqrt(a[0]);
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				b=b/2+a%x*inv(b*2,x)%x;
			}
			return b;
		}
		poly operator << (const poly &a,const int &b){
			poly c(a.size()+b);
			copy(a.begin(),a.end(),c.begin()+b);
			return c;
		}
		poly operator <<= (poly &a,const int &b){
			return a=a<<b;
		}
		poly operator >> (const poly &a,const int &b){
			if(b>=a.size())return poly();
			return poly{a.begin()+b,a.end()};
		}
		poly operator >>= (poly &a,const int &b){
			return a=a>>b;
		}
		poly qpow(const poly &a,const ll &b,int k=-1){
			if(a.empty())return poly();
			if(!~k)k=a.size();
			int n=a.size(),st=0;
			for(;st<n&&!a[st];st++);
			if(st==n||(LL)st*b>=k)return poly(k);
			if(st)return qpow(a>>st,b,k-st*b)<<st*b;
			int x=a[0];
			return exp(ln(a/x,k)*(b%mod),k)*::qpow(x,b);
		}
	}
}
using namespace Poly::Public;