UOJ312 【UNR #2】梦中的题面

发布时间 2023-08-02 22:17:11作者: _kkio

好题。

容斥后插板,要计算的形如 \(\binom{Sum}{m}\) 的样子。这个 \(Sum\) 可能会很大,不能直接设进状态,但是我们 \(dp\) 需要 \(Sum\) 计算组合数。解决方法是用范德蒙德卷积

\[\sum_{i=0}^{k}{\binom{n}{i}\binom{m}{k-i}} = \binom{n+m}{k} \]

\(dp_i\) 表示当前所有 \(\binom{Sum}{i}\) 的总和,如果要给 \(Sum\) 加上一个什么数的话,直接用如上的卷积就行了。

但是这样导致了一个严重的问题:我们使用范德蒙德卷积,将组合数拓展到了实数域上(也就是上标为负也可能不是 \(0\) 一类的),会计算出错。我们要保证 \(dp\) 过程中 \(Sum\) 始终不会变成负的,怎么办?

题目告诉我们,给的 \(a_i\) 以几次方几次方的形式出现,而且开始的总和不超过他们的总和,暗示我们数位dp。那我们相当于一些数可以减,一些数不能减,可以看成 \(01\) 序列,然后就是经典数位dp了。 \(f_i\) 表示没解除限制的,\(g_i\) 表示解除了的。相互转移即可。

卷积可以用多项式优化。

加强!!!

#include <bits/stdc++.h>
using namespace std;
namespace ZPoly
{

using LL=long long;
constexpr int MOD=998244353,G=114514,MAXN=1<<21;
inline int qpow(LL a,LL b) { int r=1;for(;b;(b&1)?r=r*a%MOD:0,a=a*a%MOD,b>>=1);return r; }
inline int madd(int x) { return x; }
inline int mmul(int x) { return x; }
inline int msub(int x,int y) { return (x-=y)<0?x+=MOD:x; }
inline int mdiv(int x,int y) { return (LL)x*qpow(y,MOD-2)%MOD; }
template<typename ...Args>inline int madd(int x,Args ...y) { return (x+=madd(y...))>=MOD?x-=MOD:x; }
template<typename ...Args>inline int mmul(int x,Args ...y) { return (LL)x*mmul(y...)%MOD; }

class Polynomial
{
 private:
	static constexpr int NTT_LIM=180;
	static int g[MAXN+5],c1[MAXN+5],c2[MAXN+5];
	int deg;
	vector<int> c;
 public:
	static void init()
	{
		for(int i=2,gn;i<=MAXN;i<<=1)
		{
			g[i>>1]=1,gn=qpow(G,(MOD-1)/i);
			for(int j=(i>>1)+1;j<i;j++) g[j]=mmul(g[j-1],gn);
		}
	}
	static void DIT(int *a,int len)
	{
		for(int i=len>>1;i;i>>=1)
			for(int j=0;j<len;j+=i<<1)
				for(int k=0,x,y;k<i;k++)
					x=a[j+k],y=a[i+j+k],a[j+k]=madd(x,y),a[i+j+k]=mmul(g[i+k],msub(x,y));
	}
	static void DIF(int *a,int len)
	{
		for(int i=1;i<len;i<<=1)
			for(int j=0;j<len;j+=i<<1)
				for(int k=0,x,y;k<i;k++)
					x=a[j+k],y=mmul(g[i+k],a[i+j+k]),a[j+k]=madd(x,y),a[i+j+k]=msub(x,y);
		int x=qpow(len,MOD-2);
		for(int i=0;i<len;i++) a[i]=mmul(a[i],x);
		reverse(a+1,a+len);
	}
 private:
	static void __polyinv(const int *a,int *b,int len)
	{
		if(len==1) return b[0]=qpow(a[0],MOD-2),void();
		__polyinv(a,b,(len+1)>>1);
		int nn=1<<(__lg((len<<1)-1)+1);
		memcpy(c1,a,len<<2);
		memset(b+len,0,(nn-len)<<2);
		memset(c1+len,0,(nn-len)<<2);
		DIT(b,nn),DIT(c1,nn);
		for(int i=0;i<nn;i++) b[i]=mmul(b[i],msub(2,mmul(b[i],c1[i])));
		DIF(b,nn),memset(b+len,0,(nn-len)<<2);
	}
	static void __polyln(const int *a,int *b,int len)
	{
		__polyinv(a,b,len);
		for(int i=1;i<len;i++) c1[i-1]=mmul(i,a[i]);
		int nn=1<<(__lg((len<<1)-1)+1);
		memset(b+len,0,(nn-len)<<2);
		memset(c1+len,0,(nn-len)<<2);
		DIT(b,nn),DIT(c1,nn);
		for(int i=0;i<nn;i++) b[i]=mmul(b[i],c1[i]);
		DIF(b,nn),memset(b+len,0,(nn-len)<<2);
		for(int i=len-1;i>0;i--) b[i]=mdiv(b[i-1],i);
		b[0]=0;
	}
	static void __polyexp(const int *a,int *b,int l,int r)
	{
		if(l==r-1) return b[l]=(l?mdiv(b[l],l):1),void();
		int len=r-l,mid=(l+r)>>1;
		__polyexp(a,b,l,mid);
		for(int i=0;i<len;i++) c1[i]=a[i];
		memcpy(c2,b+l,(mid-l)<<2);
		memset(c2+mid-l,0,(r-mid)<<2);
		if(len<=NTT_LIM) for(int i=len-1;i>=0;i--)
		{
			c1[i]=mmul(c1[i],c2[0]);
			for(int j=0;j<i;j++) c1[i]=madd(c1[i],mmul(c1[j],c2[i-j]));
		}
		else
		{
			DIT(c1,len),DIT(c2,len);
			for(int i=0;i<len;i++) c1[i]=mmul(c1[i],c2[i]);
			DIF(c1,len);
		}
		for(int i=mid;i<r;i++) b[i]=madd(b[i],c1[i-l]);
		__polyexp(a,b,mid,r);
	}
 public:
	Polynomial(): deg(1),c(1){}
	Polynomial(const Polynomial &p): deg(p.deg),c(p.c){}
	Polynomial(Polynomial &&p): deg(p.deg),c(move(p.c)){}
	explicit Polynomial(int d): deg(d),c(d){}
	explicit Polynomial(const vector<int> &v): deg(v.size()),c(v){}
	explicit Polynomial(const initializer_list<int> &l): deg(l.size()),c(l){}
	inline int &operator [](int i) { return c[i]; }
	inline int operator [](int i)const { return c[i]; }
	inline int degree()const { return deg; }
	inline void resize(int d) { c.resize(deg=d); }
	inline Polynomial &operator +=(const Polynomial &p)
	{
		if(deg<p.deg) resize(p.deg);
		for(int i=0;i<deg;i++) c[i]=madd(c[i],p[i]);
		return *this;
	}
	inline Polynomial &operator -=(const Polynomial &p)
	{
		if(deg<p.deg) resize(p.deg);
		for(int i=0;i<deg;i++) c[i]=msub(c[i],p[i]);
		return *this;
	}
	inline Polynomial &operator *=(const Polynomial &p)
	{
		int n=deg,m=p.deg;resize(n+m-1);
		if(n+m<NTT_LIM)
		{
			memcpy(c1,c.data(),n<<2);
			memset(c2,0,(n+m-1)<<2);
			for(int i=0;i<n;i++)
				for(int j=0;j<m;j++)
					c2[i+j]=madd(c2[i+j],mmul(c1[i],p[j]));
			memcpy(c.data(),c2,(n+m-1)<<2);
		}
		else
		{
			int nn=1<<(__lg(n+m-1)+1);
			memcpy(c1,c.data(),n<<2),memcpy(c2,p.c.data(),m<<2);
			memset(c1+n,0,(nn-n)<<2),memset(c2+m,0,(nn-m)<<2);
			DIT(c1,nn),DIT(c2,nn);
			for(int i=0;i<nn;i++) c1[i]=mmul(c1[i],c2[i]);
			DIF(c1,nn),memcpy(c.data(),c1,deg<<2);
		}
		return *this;
	}
	friend inline Polynomial derivative(const Polynomial &p)
	{
		Polynomial q(p.deg-1);
		for(int i=1;i<p.deg;i++) q[i-1]=mmul(p[i],i);
		return q;
	}
	friend inline Polynomial integral(const Polynomial &p)
	{
		Polynomial q(p.deg+1);
		for(int i=1;i<p.deg;i++) q[i+1]=mdiv(p[i],i+1);
		return q;
	}
	inline Polynomial inv()const
	{
		if(c[0]==0) cerr<<"[x^0]f(x)=0, f(x)^-1 doesn't exist.\n",abort();
		int nn=1<<(__lg((deg<<1)-1)+1);
		Polynomial q(nn);
		__polyinv(c.data(),q.c.data(),deg);
		return q.resize(deg),q;
	}
	friend inline Polynomial ln(const Polynomial &p)
	{
		if(p[0]!=1) cerr<<"[x^0]f(x)!=1, ln(f(x)) doesn't exist.\n",abort();
		int nn=1<<(__lg((p.deg<<1)-1)+1);
		Polynomial q(nn);
		__polyln(p.c.data(),q.c.data(),p.deg);
		return q.resize(p.deg),q;
	}
	friend inline Polynomial exp(const Polynomial &p)
	{
		if(p[0]!=0) cerr<<"[x^0]f(x)!=0, exp(f(x)) doesn't exist.\n",abort();
		static int c[MAXN];
		int nn=1<<(__lg(p.deg-1)+1);
		for(int i=0;i<p.deg;i++) c[i]=mmul(i,p[i]);
		Polynomial q(nn);
		__polyexp(c,q.c.data(),0,nn);
		return q.resize(p.deg),q;
	}
	friend inline pair<Polynomial,Polynomial> div(const Polynomial &f,const Polynomial &g)
	{
		if(f.deg<g.deg) return make_pair(Polynomial{0},f);
		int n=f.deg-1,m=g.deg-1;
		Polynomial fr(n+1),gr(m+1);
		for(int i=0;i<=n;i++) fr[i]=f[n-i];
		for(int i=0;i<=m;i++) gr[i]=g[m-i];
		fr.resize(n-m+1),gr.resize(n-m+1),fr*=gr.inv();
		fr.resize(n-m+1),reverse(fr.c.begin(),fr.c.end());
		gr=f-fr*g,gr.resize(m);
		return make_pair(fr,gr);
	}
	inline Polynomial &operator =(const Polynomial &p)
		{ return deg=p.deg,c=p.c,*this; }
	inline Polynomial &operator =(Polynomial &&p)
		{ return deg=p.deg,c=move(p.c),*this; }
	inline Polynomial &operator *=(int k)
		{ for(auto &i: c) i=mmul(i,k);return *this; }
	inline Polynomial &operator /=(const Polynomial &rhs)
		{ return (*this)*=rhs.inv(); }
	inline Polynomial &operator %=(const Polynomial &rhs)
		{ return (*this)=div(*this,rhs).second; }
	inline Polynomial operator +(const Polynomial &rhs)const
		{ return Polynomial(*this)+=rhs; }
	inline Polynomial operator -(const Polynomial &rhs)const
		{ return Polynomial(*this)-=rhs; }
	inline Polynomial operator *(const Polynomial &rhs)const
		{ return Polynomial(*this)*=rhs; }
	inline Polynomial operator /(const Polynomial &rhs)const
		{ return Polynomial(*this)/=rhs; }
	inline Polynomial operator %(const Polynomial &rhs)const
		{ return div(*this,rhs).second; }
	friend inline Polynomial operator *(const Polynomial &p,int k)
		{ return Polynomial(p)*=k; }
	friend inline Polynomial operator *(int k,const Polynomial &p)
		{ return Polynomial(p)*=k; } 
};
int Polynomial::g[]={},Polynomial::c1[]={},Polynomial::c2[]={};
};
using namespace ZPoly;
int m,b,c,pw[405];
struct bignum{
    int a[405],len;
    void trim(){
        int r=0;
        for(int i=0;i<=len;i++)
        {a[i]+=r;r=floor(1.0*a[i]/b);a[i]-=r*b;}
        while(r!=0)
        {++len;a[len]=r;r=floor(1.0*a[len]/b);a[len]-=r*b;}
        while(len>=0&&a[len]==0)len--;
    }
    int getval()
    {
        int ret=0;
        for(int i=0;i<=len;i++)
            ret=(ret+1ll*pw[i]*a[i]%MOD)%MOD;
        return ret;
    }
    void sub(bignum &o)
    {
        for(int i=0;i<=o.len;i++)
            a[i]-=o.a[i];
        trim();
    }
}val[405],N;
bool comp(bignum a,bignum b)
{
    if(a.len!=b.len)return a.len<b.len;
    for(int i=a.len;i>=0;i--)if(a.a[i]!=b.a[i])return a.a[i]<b.a[i];
    return 0;
}
char s[100000];
int len;
int w[100000],tmp[100000],t[405],inv[405];
Polynomial f[2],g[2],h,ini;
int main()
{
    Polynomial::init();
    scanf("%d%d%d",&m,&b,&c);
    pw[0]=1;for(int i=1;i<=400;i++)pw[i]=1ll*pw[i-1]*b%MOD;
    inv[1]=1;for(int i=2;i<=400;i++)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
    for(int i=1;i<=m;i++)
        val[i].len=i,val[i].a[i]=1,val[i].a[0]-=c-1,val[i].trim();
    scanf("%s",s);
    len=strlen(s);reverse(s,s+len);len--;
    for(int i=0;i<=len;i++)w[i]=s[i]-'0';
    N.len=-1;
    while(len>=0)
    {
        int r=0;
        for(int i=len;i>=0;i--)
        {tmp[i]=(w[i]+1ll*r*10)/b;r=(1ll*r*10+w[i])%b;}
        for(int i=0;i<=len;i++)w[i]=tmp[i],tmp[i]=0;
        while(len>=0&&w[len]==0)len--;
        N.a[++N.len]=r;
    }   
    if(N.len==-1)
    {
        puts("0");
        return 0;
    }
    N.a[0]+=m-1;N.trim();
    int n=N.getval();
    f[0].resize(m+5),f[1].resize(m+5);
    g[0].resize(m+5),g[1].resize(m+5);
    ini.resize(m+5);
    h.resize(m+5);
    f[0][0]=1;
    for(int val=1,i=1;i<=m;i++)
    {
        val=1ll*val*(n-i+1)%MOD;
        val=1ll*val*inv[i]%MOD;
        f[0][i]=val;
    }
    for(int i=m;i>=1;i--)if(comp(val[i],N))t[i]=1,N.sub(val[i]);
    for(int i=m;i>=1;i--)
    {
		f[0].resize(m+5);f[1].resize(m+5);
        int nowa=(MOD-val[i].getval())%MOD;
        h[0]=1;
        for(int val=1,i=1;i<=m;i++)
        {
            val=1ll*val*(nowa-i+1)%MOD;
            val=1ll*val*inv[i]%MOD;
            h[i]=val;
        }
        if(t[i])
        {
            g[0]=ini-f[0]*h;
            g[1]=f[1]-f[1]*h+f[0];
        }
        else
        {
            g[0]=f[0];
            g[1]=f[1]-f[1]*h;
        }
        f[0]=g[0],f[1]=g[1];
    }
    int ans=(f[0][m]+f[1][m])%MOD;
    printf("%d\n",ans);
    return 0;
}