[ABC309Ex] Simple Path Counting Problem

发布时间 2023-09-22 19:41:58作者: 灰鲭鲨

Problem Statement

We have a grid with $N$ rows and $M$ columns. We denote by $(i,j)$ the cell in the $i$-th row from the top and $j$-th column from the left.

You are given integer sequences $A=(A_1,A_2,\dots,A_K)$ and $B=(B_1,B_2,\dots,B_L)$ of lengths $K$ and $L$, respectively.

Find the sum, modulo $998244353$, of the answers to the following question over all integer pairs $(i,j)$ such that $1 \le i \le K$ and $1 \le j \le L$.

  • A piece is initially placed at $(1,A_i)$. How many paths are there to take it to $(N,B_j)$ by repeating the following move $(N-1)$ times?
    • Let $(p,q)$ be the piece's current cell. Move it to $(p+1,q-1),(p+1,q)$, or $(p+1,q+1)$, without moving it outside the grid.

Constraints

  • $1 \le N \le 10^9$
  • $1 \le M,K,L \le 10^5$
  • $1 \le A_i,B_j \le M$
先考虑如何计算某一个点到另一个点的方案数。

定义 \(dp_{i,j}\) 为到达 \((i,j)\) 的方案数,那么 \(dp_{i,j}=dp_{i-1,j-1}+dp_{i-1,j}+dp_{i-1,j+1}\)

看似可以多项式快速幂优化,但是会发现有边界问题。

如何解决边界问题?对称一下,使 \(dp_{n+1-i}=-dp_i\),然后卷积的时候就可以把 \(j\le N\) 的限制给去掉。

\(x>0\) 的限制怎么去? 用 \(2n+2\) 的循环卷积即可。

然后循环卷积快速幂就可以了。

原题也一样,循环卷积跑出来 \((1+x+x^{-1})^N\) 的系数卷上 \(A\) 数组就可以了。

#include<bits/stdc++.h>
using namespace std;
const int N=524288,P=998244353;
int n,m,k,a,b,rr[N],ans,f[N],g[N],h[N],p,l;
int read()
{
    int s=0;
    char ch=getchar();
    while(ch<'0'||ch>'9')
        ch=getchar();
    while(ch>='0'&&ch<='9')
        s=s*10+ch-48,ch=getchar();
    return s;
}
int pown(int x,int y)
{
    if(!y)
        return 1;
    int t=pown(x,y>>1);
    if(y&1)
        return 1LL*t*t%P*x%P;
    return 1LL*t*t%P;
}
void ntt(int a[],int op)
{
    for(int i=1;i<N;i++)
        if(rr[i]<i)
            swap(a[i],a[rr[i]]);
    for(int md=1;md<N;md<<=1)
    {
        int g=pown(op? 3:332748118,(P-1)/(md<<1));
        for(int i=0;i<N;i+=md<<1)
        {
            int pw=1;
            for(int j=0;j<md;j++,pw=1LL*g*pw%P)
            {
                int x=a[i+j+md]*1LL*pw%P;
                a[i+j+md]=(a[i+j]+P-x)%P;
                (a[i+j]+=x)%=P;
            }
        }
    }
    if(!op)
    {
        int ivN=pown(N,P-2);
        for(int i=0;i<N;i++)
            a[i]=1LL*a[i]*ivN%P;
    }
}
void mul(int a[],int b[])
{
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<N;i++)
        a[i]=1LL*a[i]*b[i]%P;
    ntt(a,0);
    for(int i=p;i<N;i++)
        (a[i%p]+=a[i])%=P,a[i]=0;
}
void solve(int x)
{
    if(x==1)
        return;
    if(x&1)
    {
        solve(x-1);
        memcpy(h,f,sizeof(h));
        for(int i=0;i<p;i++)
            f[i]=(1LL*h[i]+h[(i+p-1)%p]+h[(i+1)%p])%P;
    }
    else
    {
        solve(x>>1);
        memcpy(h,f,sizeof(h));
        mul(f,h);
    }
}
int main()
{
    n=read(),m=read(),k=read(),l=read();
    p=2*m+2;
    for(int i=0;i<N;i++)
        rr[i]=rr[i>>1]>>1|(i&1)*(N/2);
    for(int i=1;i<=k;i++)
        a=read(),g[a]++,(g[2*m-a+2]+=P-1)%=P;
    if(n^1)
    {
        f[0]=f[p-1]=f[1]=1;
        solve(n-1);
        mul(g,f);
    }
    for(int i=1;i<=l;i++)
        (ans+=g[read()])%=P;
    printf("%d",ans);
}