// 任意模数多项式模板 (arbitrary-modulus polynomial template) -- zhengjun
//
// 发布时间 (published): 2023-11-30 09:05:02   作者 (author): A_zjzj
using LL = __int128;  // 128-bit integer, used where a product may not fit in ll
int mod{998244353};   // working modulus; non-const so it can be changed at runtime
// Modular exponentiation by square-and-multiply: returns ans * x^y modulo mod.
// y defaults to mod-2, so qpow(x) is the Fermat inverse of x when mod is prime
// and x is not a multiple of mod.
// The base is reduced into [0,mod) first: a negative base would otherwise yield
// a negative result, and an oversized base would overflow the first x*x.
ll qpow(ll x,ll y=mod-2,ll ans=1){
	x%=mod;
	if(x<0)x+=mod;
	for(;y;(x*=x)%=mod,y>>=1)if(y&1)(ans*=x)%=mod;
	return ans;
}
// Pseudo-random engine seeded from the wall clock; sqr() draws its candidates from it.
mt19937 rnd{static_cast<mt19937::result_type>(time(nullptr))};
// Square root modulo the prime mod (Cipolla's algorithm).
// Returns r with 0 <= r <= mod/2 and r*r == n (mod mod); the other root is mod-r.
// Precondition: n is a quadratic residue modulo mod. For a non-residue the
// returned value is meaningless (this is not checked).
int sqr(int n){
	n%=mod;
	if(n<0)n+=mod;
	// 0 is its own (only) root. Without this early return the search loop below
	// could only stop once a==0 is drawn, since every a*a is a residue.
	// For mod==2 every element is its own root, and (mod-1)/2==0 breaks Euler's test.
	if(n==0||mod==2)return n;
	int a,val;
	// Euler's criterion: a nonzero x is a quadratic residue iff x^((mod-1)/2)==1.
	auto chk=[&](int x){
		return qpow(x,(mod-1)/2)==1;
	};
	// Draw a until w=a*a-n is not a nonzero residue; arithmetic then happens in F_mod[sqrt(w)].
	do{
		a=rnd()%mod;
		val=(1ll*a*a-n)%mod;
		if(val<0)val+=mod; // keep w in [0,mod) so the products in mul stay non-negative
	}while(chk(val));
	// Element x + y*sqrt(w).
	struct comp{
		int x,y;
		comp(int a=0,int b=0):x(a),y(b){}
	}x(a,1),ans(1,0);
	auto mul=[&](comp a,comp b){
		return comp((1ll*a.x*b.x+1ll*a.y*b.y%mod*val)%mod,(1ll*a.x*b.y+1ll*a.y*b.x)%mod);
	};
	// (a+sqrt(w))^((mod+1)/2) lies in F_mod and squares to n.
	// mod/2+1 equals (mod+1)/2 for odd mod without overflowing at mod==INT_MAX.
	int y=mod/2+1;
	for(;y;x=mul(x,x),y>>=1)if(y&1)ans=mul(ans,x);
	return min(ans.x,mod-ans.x);
}
namespace Poly{
	const long double pi=acos(-1);
	using comp=complex<long double>;
	const int N=1<<21;
	int lim,U=(1<<15)-1,rev[N];
	comp a0[N],a1[N],b0[N],b1[N],c0[N],c1[N],c2[N],pw[N];
	namespace Public{
		using poly=vector<int>;
		void init(int n){
			for(lim=1;lim<n;lim<<=1);
			for(int i=1;i<lim;i++)rev[i]=rev[i>>1]>>1|(i&1?lim>>1:0);
			for(int len=1;len<lim;len<<=1){
				for(int i=0;i<len;i++){
					pw[len|i]=comp(cos(pi/len*i),sin(pi/len*i));
				}
			}
		}
		void FFT(comp *a,int op){
			for(int i=0;i<lim;i++)if(rev[i]<i)swap(a[rev[i]],a[i]);
			for(int len=1;len<lim;len<<=1){
				for(int i=0;i<lim;i+=len<<1){
					for(int j=0;j<len;j++){
						comp x=a[i|j],y=a[i|j|len]*(op<0?conj(pw[len|j]):pw[len|j]);
						a[i|j]=x+y,a[i|j|len]=x-y;
					}
				}
			}
			if(op<0){
				for(int i=0;i<lim;i++)a[i]/=lim;
			}
		}
		poly operator * (const poly &a,const poly &b){
			int n=a.size(),m=b.size(),k=n+m-1;
			init(k);
			poly c(k);
			for(int i=0;i<n;i++)a0[i]=comp(a[i]&U,0),a1[i]=comp(a[i]>>15,0);
			for(int i=n;i<lim;i++)a0[i]=a1[i]=comp(0,0);
			for(int i=0;i<m;i++)b0[i]=comp(b[i]&U,0),b1[i]=comp(b[i]>>15,0);
			for(int i=m;i<lim;i++)b0[i]=b1[i]=comp(0,0);
			FFT(a0,1),FFT(a1,1),FFT(b0,1),FFT(b1,1);
			for(int i=0;i<lim;i++){
				c0[i]=a0[i]*b0[i];
				c1[i]=a0[i]*b1[i]+a1[i]*b0[i];
				c2[i]=a1[i]*b1[i];
			}
			FFT(c0,-1),FFT(c1,-1),FFT(c2,-1);
			for(int i=0;i<k;i++){
				ll s0=(ll)(real(c0[i])+0.5l);
				ll s1=(ll)(real(c1[i])+0.5l);
				ll s2=(ll)(real(c2[i])+0.5l);
				c[i]=((((s2%mod<<15)+s1)%mod<<15)+s0)%mod;
			}
			return c;
		}
		poly& operator *= (poly &a,const poly &b){
			return a=a*b;
		}
		poly& operator += (poly &a,const poly &b){
			int n=a.size(),m=b.size();
			if(n<m)a.resize(m);
			for(int i=0;i<m;i++)(a[i]+=b[i])%=mod;
			return a;
		}
		poly operator + (const poly &a,const poly &b){
			poly c(a);
			return c+=b;
		}
		poly& operator -= (poly &a,const poly &b){
			int n=a.size(),m=b.size();
			if(n<m)a.resize(m);
			for(int i=0;i<m;i++)(a[i]+=mod-b[i])%=mod;
			return a;
		}
		poly operator - (const poly &a,const poly &b){
			poly c(a);
			return c-=b;
		}
		poly operator - (const poly &a){
			return poly()-a;
		}
		poly& operator *= (poly &a,const int &b){
			for(int &x:a)x=1ll*x*b%mod;
			return a;
		}
		poly operator * (const poly &a,const int &b){
			poly c(a);
			return c*=b;
		}
		poly& operator /= (poly &a,const int &b){
			return a*=qpow(b);
		}
		poly operator / (const poly &a,const int &b){
			poly c(a);
			return c/=b;
		}
		poly& operator %= (poly &a,const int &b){
			if(a.size()>b)a.resize(b);
			return a;
		}
		poly operator % (const poly &a,const int &b){
			poly c(a);
			return c%=b;
		}
		poly inv(const poly &a,int k=-1){
			if(!~k)k=a.size();
			poly b{(int)qpow(a[0])};
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				(b*=poly({2})-a%x*b%x)%=x;
			}
			return b;
		}
		poly qiudao(const poly &a){
			int n=a.size();
			if(!n)return poly();
			poly b(n-1);
			for(int i=1;i<n;i++)b[i-1]=1ll*a[i]*i%mod;
			return b;
		}
		poly jifen(const poly &a){
			int n=a.size();
			if(!n)return poly(1);
			poly b(n+1);
			b[1]=1;
			for(int i=2;i<=n;i++)b[i]=1ll*b[mod%i]*(mod-mod/i)%mod;
			for(int i=1;i<=n;i++)b[i]=1ll*b[i]*a[i-1]%mod;
			return b;
		}
		poly ln(const poly &a,int k=-1){
			if(!~k)k=a.size();
			return jifen(qiudao(a)*inv(a,k)%k)%k;
		}
		poly exp(const poly &a,int k=-1){
			if(!~k)k=a.size();
			poly b(1);
			b[0]=1;
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				(b*=poly({1})-ln(b,x)+a%x)%=x;
			}
			return b;
		}
		poly sqrt(const poly &a,int k=-1){
			if(a.empty())return poly();
			if(!~k)k=a.size();
			poly b(1);
			b[0]=sqr(a[0]);
			for(int i=1,x;i<k;i<<=1){
				x=min(i*2,k);
				b=b/2+a%x*inv(b*2,x)%x;
			}
			return b;
		}
		poly operator << (const poly &a,const int &b){
			poly c(a.size()+b);
			copy(a.begin(),a.end(),c.begin()+b);
			return c;
		}
		poly operator <<= (poly &a,const int &b){
			return a=a<<b;
		}
		poly operator >> (const poly &a,const int &b){
			if(b>=a.size())return poly();
			return poly{a.begin()+b,a.end()};
		}
		poly operator >>= (poly &a,const int &b){
			return a=a>>b;
		}
		poly qpow(const poly &a,const ll &b,int k=-1){
			if(a.empty())return poly();
			if(!~k)k=a.size();
			int n=a.size(),st=0;
			for(;st<n&&!a[st];st++);
			if(st==n||(LL)st*b>=k)return poly(k);
			if(st)return qpow(a>>st,b,k-st*b)<<st*b;
			int x=a[0];
			return exp(ln(a/x,k)*(b%mod),k)*::qpow(x,b);
		}
		poly operator / (const poly &a,const poly &b){
			int n=a.size(),m=b.size(),k=n-m+1;
			if(k<=0)return poly();
			poly c(a),d(b);
			reverse(c.begin(),c.end());
			reverse(d.begin(),d.end());
			c=c%k*inv(d,k)%k;
			reverse(c.begin(),c.end());
			return c;
		}
		poly operator % (const poly &a,const poly &b){
			if(a.size()<b.size())return a;
			poly c=a/b;
			return a-b*c;
		}
		void div(const poly &a,const poly &b,poly &c,poly &d){
			if(a.size()<b.size()){
				c=poly(),d=a;
				return;
			}
			c=a/b,d=(a-b*c)%(b.size()-1);
		}
	}
}
using namespace Poly::Public;