[ABC259Ex] Yet Another Path Counting

发布时间 2023-11-17 14:25:26作者: MrcFrst

\(\text{Links}\)

[ABC259Ex] Yet Another Path Counting

Luogu Blog


题外话

  • 淀粉质题单做不动了怎么办?来做一道根号题振奋一下精神吧/se!

  • 我要饿死了,我要吃饭,以后在学校还是不要不吃早饭了/kk


题意

给一个 \(n\times n\) 的网格图,每个格子上有一个颜色。

每一步只能往右或者往下走,问有多少条路径的起点和终点的颜色相同,对 \(998244353\) 取模。

\(n\le 400\)\(2.00s\)


题解

不同颜色的统计互不干扰,所以按颜色分开来统计。

考虑有用的信息只有起点和终点的颜色,所以枚举点对,组合数计算贡献,即 \(y2-y1+x2-x1\choose x2-x1\)。复杂度为 \(O(siz^2)\),其中 \(siz\) 为这种颜色的点数。

发现如果同种颜色的点数过大的话这个做法会 G。并且很难维护合并组合数的计算来降低复杂度。

但是点数有一个限制,即所有颜色的点数加起来为 \(n^2\),于是可以考虑根号分治了!

设置阈值 \(T\),当 \(siz\le T\) 时,直接用上面的暴力做法,此部分总时间复杂度为 \(O(\frac{n^2}{T}\times T^2)\),即 \(O(n^2T)\)

\(siz\gt T\) 时,这样的颜色最多只有 \(\frac{n}{T}\) 种,那么对于每个颜色再搞个暴力做法。

考虑,这个暴力做法时间复杂度的正确性应该是不依赖于 \(siz\) 的,不然我们根分有什么用呢?全部用这个做法不就好了吗。

所以考虑 \(O(n^2)\)\(dp\),钦定我们当前 \(solve\) 的颜色为 \(col\)。设 \(dp_{i,j}\) 表示从颜色为 \(col\) 的格子走到位置 \((i,j)\) 的方案数。

转移很简单:\(dp_{i,j}=dp_{i,j-1}+dp_{i-1,j}+[a_{i,j}=col]\)

于是每个位置 \((i,j)\)\(ans\) 的贡献应该是 \([a_{i,j}=col]\times f_{i,j}\)。此部分总时间复杂度为 \(O(\frac{n^2}{T}\times n^2)\),即 \(O(\frac{n^4}{T})\)

然后就做完了。取 \(T=n\) 的时候达到平衡,总时间复杂度为 \(O(n^3)\)

代码非常简单。(由于不怎么习惯用大量的 \(pair\),所以这篇码风可能比较诡异)


\(\text{Code}\)

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define il inline
#define re register
const int N=405,T=400,mod=998244353;
int n,a[N][N],ans,fac[N<<1],inv[N<<1],invfac[N<<1],f[N][N];
#define pii pair<int,int>
#define mp make_pair
vector<pii >v[N*N];
il int read(){
    re int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}
il void Add(int &x,int y){
    x=(x+y)%mod;
}
il int C(int n,int m){
    if(n<0||m<0||n<m)return 0;
    return fac[n]*invfac[m]%mod*invfac[n-m]%mod;
}
il bool cmp(pii x,pii y){
    return y.first>=x.first&&y.second>=x.second;
}
il int disx(pii x,pii y){
    return y.first-x.first;
}
il int disy(pii x,pii y){
    return y.second-x.second;
}
#define nowi v[col][i]
#define nowj v[col][j]
il void solve1(int col){
    int siz=(int)v[col].size();
    for(re int i=0;i<siz;i++)
    for(re int j=i;j<siz;j++)
        if(cmp(nowi,nowj))Add(ans,C(disx(nowi,nowj)+disy(nowi,nowj),disx(nowi,nowj)));
}
il void solve2(int col){
    for(re int i=1;i<=n;i++)
        for(re int j=1;j<=n;j++){
            f[i][j]=(f[i-1][j]+f[i][j-1]+(a[i][j]==col))%mod;
            if(a[i][j]==col)Add(ans,f[i][j]);
        }
}
il void GetInv(){
    inv[1]=fac[1]=invfac[1]=fac[0]=invfac[0]=1;
    for(re int i=2;i<=(n<<1);i++){
        inv[i]=inv[mod%i]*(mod-mod/i)%mod;
        fac[i]=fac[i-1]*i%mod;
        invfac[i]=invfac[i-1]*inv[i]%mod;
    }
}
signed main(){
    n=read();
    GetInv();
    for(re int i=1;i<=n;i++)
        for(re int j=1;j<=n;j++)
            a[i][j]=read(),v[a[i][j]].push_back(mp(i,j));
    for(re int col=1;col<=n*n;col++){
        if(v[col].empty())continue;
        if((int)v[col].size()<=T)solve1(col);
        else solve2(col);
    }
    cout<<ans;
    return 0;
}