题意
给定一棵 \(n\) 个结点的树,你从点 \(x\) 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 \(Q\) 次询问,每次询问给定一个集合 \(S\),求如果从 \(x\) 出发一直随机游走,直到点集 \(S\) 中所有点都至少经过一次的话,期望游走几步。
特别地,点 \(x\)(即起点)视为一开始就被经过了一次。
答案对 \(998244353\) 取模。
思路
不难发现,直接求经过点集 \(S\) 中所有点是无法做到的,于是可以用 min-max 容斥转化一下,令 \(E(\min(T))\) 为第一次经过 \(T\) 中节点的期望步数。那么就有:
\[E(\max(S))=\sum_{T \subseteq S} (-1)^{|T|+1} E(\min(T))
\]
考虑如何求出 \(E(\min(T))\)。可以令 \(f_{S,i}\) 表示以 \(i\) 为起点的时候,第一次经过 \(S\) 节点的期望步数。当 \(i \in S\) 时,\(f_{S,i}=0\),而当 \(i \notin S\) 时,可以得到转移方程:
\[f_{S,u}=\frac{f_{S,fa_u}+\sum_{v \in son_u} f_{S,v}}{deg_u}+1
\]
如果这是在一般图上的转移方程,就要用到高斯消元了。而由于本题是在一棵树上游走,于是可以用到待定系数法。即设 \(f_{S,u}=A_u \times f_{S,fa_u}+B_u\),于是就可以对转移方程进行变形:
\[f_{S,u}=\frac{f_{S,fa_u}+\sum_{v \in son_u} (A_v \times f_{S,u}+B_v)}{deg_u}+1
\]
\[=\frac{f_{S,fa_u}+f_{S,u} \times \sum_{v \in son_u} A_v+\sum_{v \in son_u} B_v}{deg_u}+1
\]
\[f_{S,u} \times deg_u={f_{S,fa_u}+f_{S,u} \times \sum_{v \in son_u} A_v+\sum_{v \in son_u} B_v}+deg_u
\]
\[f_{S,u} \times (deg_u-\sum_{v \in son_u} A_v)={f_{S,fa_u}+\sum_{v \in son_u} B_v}+deg_u
\]
\[f_{S,u}=\frac{1}{deg_u-\sum_{v \in son_u} A_v} \times f_{S,fa_u}+\frac{\sum_{v \in son_u} B_v+deg_u}{deg_u-\sum_{v \in son_u} A_v}
\]
于是就可以得到:
\[A_u=\frac{1}{deg_u-\sum_{v \in son_u} A_v},B_u=\frac{\sum_{v \in son_u} B_v+deg_u}{deg_u-\sum_{v \in son_u} A_v}
\]
那么 \(A_u\) 和 \(B_u\) 就可以直接求了。而对于根节点来说,其不能从父亲转移过来,于是就有 \(f_{S,rt}=B_{rt}\)
按照套路,接下来只需要用 FWT 求高维前缀和即可。
code:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=20,M=(1<<20)+10,mod=998244353;
int d[N],siz[M],n,q,rt,h[N],idx,A[N],B[N],f[M];
struct edge{int v,nex;}e[N<<1];
void add(int u,int v){e[++idx]=edge{v,h[u]};h[u]=idx;d[v]++;}
void Add(int &a,int b){a+=b;a-=a>=mod?mod:0;}
int mul(int a,int b){int res=1;while(b) ((b&1)&&(res=1ll*res*a%mod)),a=1ll*a*a%mod,b>>=1;return res;}
void dfs(int u,int fa,int S)
{
int Sa=0,Sb=0;if(S&(1<<u-1)) return;
for(int i=h[u];i;i=e[i].nex)
{
int v=e[i].v;if(v==fa) continue;dfs(v,u,S);
Add(Sa,A[v]);Add(Sb,B[v]);
}
A[u]=mul((d[u]-Sa+mod)%mod,mod-2);
B[u]=1ll*A[u]*(Sb+d[u])%mod;
}
int main()
{
scanf("%d%d%d",&n,&q,&rt);for(int u,v,i=1;i<n;i++) scanf("%d%d",&u,&v),add(u,v),add(v,u);
for(int i=1;i<(1<<n);i++) siz[i]=siz[i-(i&-i)]+1;
for(int T=1;T<(1<<n);T++)
{
for(int i=1;i<=n;i++) A[i]=B[i]=0;dfs(rt,0,T);
if(siz[T]&1) f[T]=B[rt];else f[T]=mod-B[rt];
}
for(int i=0;i<n;i++) for(int j=0;j<(1<<n);j++) if((j>>i)&1) Add(f[j],f[j^(1<<i)]);
while(q--)
{
int S=0,x,m;scanf("%d",&m);
while(m--) scanf("%d",&x),S|=1<<x-1;
printf("%d\n",f[S]);
}
return 0;
}