P3412 仓鼠找sugar II 题解

发布时间 2023-11-18 21:32:29作者: 流星Meteor

P3412 仓鼠找sugar II 题解

大水题一个

题目大意

给定你一个树,设 \(f_{u, v}\) 表示在树上随机游走的情况下从 \(u\) 走到 \(v\) 的期望步数,求 \(\displaystyle \frac{\sum_{i = 1}^n \sum_{j = 1}^n f_{i, j}}{n^2}\)

题解

不难想到 dp,不过 \(1e5\) 的范围差点让我怀疑我 \(O(n)\) 的 dp。先设一下状态,设 \(f_u\) 表示 \(u\) 子树内的所有点全都走到点 \(u\) 的期望步数。答案就是以每个点为根时根的 \(f\) 值的和。

考虑怎么转移。

似乎不好直接转,于是想想我们转移时什么东西卡住了我们。假设现在 \(u\) 子树内的所有点都走到了 \(u\),那么我们现在想要让这些点再从 \(u\) 结点走向它的父亲结点,这个期望步数不好直接求。

于是我们再设 \(g_u\) 表示从 \(u\) 结点走到它的父亲结点的期望步数。先来考虑它的转移。\(deg_u\) 表示 \(u\) 结点的度,即与它相连的边数,\(son_u\) 表示 \(u\) 结点的儿子构成的集合。

\[\begin{aligned} g_u &= \frac{1}{deg_u} + \sum_{v \in son_u} \frac{1 + g_v + g_u}{deg_u}\\ deg_u \times g_u &= 1 + \sum_{v \in son_u} (1 + g_v + g_u)\\ &= 1 + (deg_u - 1) + (deg_u - 1) \times g_u + \sum_{v \in son_u} g_v\\ &= deg_u + (deg_u - 1) \times g_u + \sum_{v \in son_u} g_v\\ g_u &= deg_u + \sum_{v \in son_u} g_v \end{aligned}\]

\(g\) 的转移就没了,再来考虑 \(f\)

\[\begin{aligned} f_u &= \sum_{v \in son_u} f_v + size_v \times g_v \end{aligned}\]

这个非常好理解。

于是可以打 \(n^2\) 了。

换一下根,就 \(O(n)\) 了。

\(h(x)\) 为以 \(x\) 为根时 \(x\)\(f\) 值,那么有:

\[\begin{aligned} h_u &= f_u + (h_{fa} - f_u - size_u \times g_u) + (n - size_u) \times (g_{fa} - g_u) \end{aligned}\]

最终答案为 \(\displaystyle \frac{\sum_i^n h_i}{n^2}\)

然后就没了。

代码

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int M = 100005;
const int mod = 998244353;
int n, f[M], g[M], siz[M], out[M], inv, ans, h[M];
int from[M << 1], to[M << 1], head[M], nex[M << 1], tot;

inline void add_edge(int u, int v) {
    from[++ tot] = u;
    to[tot] = v;
    nex[tot] = head[u];
    head[u] = tot;
}

void dfs1(int u, int fa) {
    g[u] = out[u];
    siz[u] = 1;
    for(int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if(v == fa)
            continue;
        dfs1(v, u);
        siz[u] += siz[v];
        g[u] = (g[u] + g[v]) % mod;
        f[u] = (f[u] + f[v] + siz[v] * g[v] % mod) % mod;
    }
}

void dfs2(int u, int fa) {
    h[u] = (f[u] + (h[fa] - f[u] + mod - siz[u] * g[u] % mod + mod) % mod + (n - siz[u]) * ((g[fa] - g[u] + mod) % mod) % mod) % mod;
    g[u] = g[u] + (g[fa] - g[u]);
    ans = (ans + h[u]) % mod;
    for(int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if(v == fa)
            continue;
        dfs2(to[i], u);
    }
}

inline int quick_pow(int base, int ci, int pp) {
    int res = 1;
    while(ci) {
        if(ci & 1)
            res = res * base % pp;
        base = base * base % pp;
        ci >>= 1;
    }
    return res;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
    cin >> n;
    inv = quick_pow(n, mod - 2, mod);
    for(int i = 1; i < n; ++ i) {
        int u, v;
        cin >> u >> v;
        add_edge(u, v);
        add_edge(v, u);
        ++ out[u];
        ++ out[v];
    }
    dfs1(1, 0);
    ans = f[1];
    h[1] = f[1];
    for(int i = head[1]; i; i = nex[i]) 
        dfs2(to[i], 1);
    ans = ans * inv % mod * inv % mod;
    cout << ans;
}