P9669 [ICPC2022 Jinan R] DFS Order 2

发布时间 2023-10-24 20:25:48作者: Thunder_S

Description

P 有一棵树,根节点是 \(1\),总共有 \(n\) 个节点,从 \(1\)\(n\) 编号。

他想从根节点开始进行深度优先搜索。他想知道对于每个节点 \(v\),在深度优先搜索中,它出现在第 \(j\) 个位置的方式有多少种。深度优先搜索的顺序是在搜索过程中访问节点的顺序。节点出现在第 \(j(1 \leq j \leq n)\) 个位置表示它在访问了 \(j-1\) 个其他节点之后才被访问。因为节点的子节点可以以任意顺序访问,所以有多种可能的深度优先搜索顺序。

P 想知道对于每个节点 \(v\),有多少种不同的深度优先搜索顺序,使得 \(v\) 出现在第 \(j\) 个位置。对于每个 \(v\)\(j(i \leq v, j \leq n)\),计算答案。答案可能很大,所以输出时要取模 \(998244353\)。以下是深度优先搜索的伪代码,用于处理树。在调用 main() 函数后,dfs_order 将会包含深度优先搜索的顺序。

https://z1.ax1x.com/2023/10/24/piEhAWq.png

Solution

先考虑怎么求总方案数。设 \(t_x\) 表示以 \(x\) 为根节点的子树的总方案数。有 \(t_x=son_x!\times \Pi_{y\in son} t_{y}\)\(son_x\) 表示 \(x\) 的儿子数)。

\(ans_{x,i}\) 表示点 \(i\)\(\text{dfs}\) 序中的位置为 \(i\) 的答案(不考虑 \(x\) 子树内部的顺序)。因为 \(\text{dfs}\) 是从 \(x\) 走到 \(x\) 的儿子 \(y\),所以转移也考虑从 \(x\) 转移到 \(y\)。枚举 \(y\)\(x\)\(\text{dfs}\) 序差几,有 \(ans_{y,j}=\sum ans_{x,i}\times g_{j-i}\),其中 \(g_{j-i}\) 表示在 \(\text{dfs}\) 序中 \(y\)\(x\)\(j-i\) 的方案数。

考虑 \(g\) 的转移。枚举当前节点 \(y\) 前有多少个点 \(i\),它们的大小的和是 \(j\)。设 \(f_{i,j}\) 表示在 \(x\) 的儿子中选了 \(i\) 个,大小为 \(j\) 的方案数。\(f\) 可以用背包来维护。那么 \(g_{j+1}=\sum f_{i,j}\times j!\times (son_x-1-j)!\times \frac{t_x}{t_y\times son_x!}\)。其中,\(j!\)\((son_x-1-j)!\) 分别表示 \(y\) 前和 \(y\) 后的兄弟节点的排列顺序,\(\frac{t_x}{t_y\times son_x!}\) 表示 \(x\) 的儿子中,除了 \(y\) 的子树内部方案数的积。因为 \(ans\) 定义的时候就要求不能考虑 \(y\) 子树内部的方案数,因此要除去 \(y\) 的子树带来的贡献。

那么最后答案为 \(ans_{x,i}\times t_x\)

但是,求 \(f\) 的过程是 \(\mathcal O(n^4)\) 的。因此考虑写回滚背包,就是先求出所有儿子带来的贡献,再减去目标儿子节点带来的贡献即可。优化至 \(\mathcal O(n^3)\)

Code

#include<cstdio>
#include<cstring>
#define N 505
#define mod 998244353
#define ll long long
using namespace std;
int n,tot,son[N],siz[N];
ll jc[N],ans[N][N],f[N][N],g[N],t[N];
struct node {int to,next,head;}a[N<<1];
void add(int x,int y)
{
    a[++tot].to=y;a[tot].next=a[x].head;a[x].head=tot;
    a[++tot].to=x;a[tot].next=a[y].head;a[y].head=tot;
}
ll ksm(ll x,ll y)
{
    ll res=1;
    while (y)
    {
        if (y&1) res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void gett(int x,int fa)
{
    ll res=1;
    siz[x]=1;t[x]=1;
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        if (y==fa) continue;
        gett(y,x);
        son[x]++;siz[x]+=siz[y];
        t[x]=t[x]*t[y]%mod;
    }
    t[x]=t[x]*jc[son[x]]%mod;
}
void dfs(int x,int fa)
{
    memset(f,0,sizeof(f));
    f[0][0]=1;
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        if (y==fa) continue;
        for (int j=son[x];j;--j)
            for (int k=siz[y];k<=siz[x];++k)
                f[j][k]=(f[j][k]+f[j-1][k-siz[y]])%mod;
    }
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        ll base=t[x]*ksm(t[y],mod-2)%mod*ksm(jc[son[x]],mod-2)%mod;
        if (y==fa) continue;
        for (int j=1;j<=son[x];++j)
            for (int k=siz[y];k<=siz[x];++k)
                f[j][k]=(f[j][k]-f[j-1][k-siz[y]]+mod)%mod;
        memset(g,0,sizeof(g));
        for (int j=0;j<son[x];++j)
            for (int k=0;k<siz[x];++k)
                g[k+1]=(g[k+1]+f[j][k]*jc[j]%mod*jc[son[x]-1-j]%mod*base%mod)%mod;
        for (int j=1;j<=n;++j)
            for (int k=j+1;k<=n;++k)
                ans[y][k]=(ans[y][k]+ans[x][j]*g[k-j]%mod)%mod;
        for (int j=son[x];j;--j)
            for (int k=siz[y];k<=siz[x];++k)
                f[j][k]=(f[j][k]+f[j-1][k-siz[y]])%mod;
    }
    for (int i=a[x].head;i;i=a[i].next)
    {
        int y=a[i].to;
        if (y==fa) continue;
        dfs(y,x);
    }
}
int main()
{
    scanf("%d",&n);
    for (int i=1;i<n;++i)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
    }
    jc[0]=1;
    for (int i=1;i<=n;++i)
        jc[i]=jc[i-1]*(ll)i%mod;
    gett(1,0);
    ans[1][1]=1;
    dfs(1,0);
    for (int i=1;i<=n;++i)
    {
        for (int j=1;j<=n;++j)
            printf("%lld ",ans[i][j]*t[i]%mod);
        printf("\n");
    }
    return 0;
}