Subtree 题解

发布时间 2023-08-11 17:28:56作者: TKXZ133

Subtree

题目大意

给定一颗树,你可以选出一些节点,你需要对于每个点求出在强制选这个点的情况下所有选择的点联通的方案数,对给定模数取模。

思路分析

对于这种求树上每一个点方案数的题目,首先考虑换根 DP。

强制嵌定树根为 \(1\),设 \(f_i\) 表示在 \(i\) 的子树中选点,\(i\) 强制选,所有选择的点联通的方案数,\(g_i\) 表示在 \(i\) 的子树外选点,\(i\) 强制选,所有选择的点联通的方案数,那么显然点 \(s\) 的答案就是 \(f_s\times g_s\)

  • 考虑计算 \(f\)

对于叶节点 \(s\),显然 \(f_s=1\),对于非叶节点,容易得出状态转移方程:

\[f_{u}=\prod_{v\in \text{son}_{u}}(f_v+1) \]

解释一下,\(f_v+1\) 就是 \(u\) 的一个子节点的子树染色的方案数,而 \(u\) 的子树的染色方案数就是所有 \(f_v+1\) 的乘积。

  • 考虑计算 \(g\)

对于根节点 \(1\),显然 \(g_1=1\),对于非根节点,不难得出状态转移方程:

\[g_v=g_{u}\times\frac{f_{u}}{f_v+1},u=\text{fa}_{v} \]

解释一下,从 \(g_u\) 转移到 \(g_v\),新增的节点就是 \(u\) 的子树去掉 \(v\) 的子树中的点后的所有点,而这些点染色的方案数就是 \(\frac{f_{u}}{f_{v}+1}\),也可以理解为在 \(f_u\) 中去掉所有由 \(v\) 产生的贡献。

但是直接求肯定是没法求的,模数不一定是质数,不一定存在逆元,但是我们发现我们可以将除法改为乘法,也即:

\[g_{v}=g_{u}\times\prod_{p\not =v,p\in \text{son}_u} (f_p+1) \]

而这个可以通过预处理每个节点的子节点权值的前缀积和后缀积实现。

故我们只需要通过两遍 dfs 就可以在 \(O(n)\) 的时间空间内解决问题。

代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>

using namespace std;
const int N=200200;
#define int long long 

int n,mod,in1,in2,idx=1;
int to[N],nxt[N],head[N];
int f[N],g[N];

vector<int> pre[N],suf[N];

void add(int u,int v){
    idx++;to[idx]=v;nxt[idx]=head[u];head[u]=idx;
}

void dfs_1(int s,int fa){
    f[s]=1;
    for(int i=head[s];i;i=nxt[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs_1(v,s);
        f[s]=f[s]*(f[v]+1)%mod;
        pre[s].push_back(f[v]+1);
        suf[s].push_back(f[v]+1);
    }
    for(int i=1;i<pre[s].size();i++) 
        pre[s][i]=pre[s][i]*pre[s][i-1]%mod;//前缀积
    for(int i=suf[s].size()-2;i>=0;i--) 
        suf[s][i]=suf[s][i]*suf[s][i+1]%mod;//后缀积
}

void dfs_2(int s,int fa){
    int num=0,x=pre[s].size();
    for(int i=head[s];i;i=nxt[i]){
        int v=to[i];
        if(v==fa) continue;
        num++;
        if(x==1) g[v]=g[s]+1; //一些特判,可能不需要
        else if(num==1) g[v]=g[s]*suf[s][num]%mod+1;
        else if(num==x) g[v]=g[s]*pre[s][num-2]%mod+1;
        else g[v]=g[s]*(pre[s][num-2]*suf[s][num]%mod)%mod+1;
        dfs_2(v,s);
    }
}

signed main(){
    scanf("%lld%lld",&n,&mod);
    for(int i=1;i<n;i++){
        scanf("%lld%lld",&in1,&in2);
        add(in1,in2);add(in2,in1);
    }    
    dfs_1(1,0);
    g[1]=1;
    dfs_2(1,0);
    for(int i=1;i<=n;i++)
        cout<<(f[i]*g[i]%mod)<<'\n';
    return 0;
}