CF1823F Random Walk 题解

发布时间 2023-08-20 20:06:07作者: User-Unauthorized

题意

给定一棵由 \(n\) 个节点组成的树,从节点 \(s\) 出发进行随机游走:每一步等概率地移动到当前节点的某个相邻节点上,第一次到达 \(t\) 时立即停止。询问整个过程中每个点的期望经过次数(答案对 \(998244353\) 取模)。

\(1 \le n \le 2 \times 10^5\)。

题解

定义 \(f_i\) 为节点 \(i\) 的期望经过次数,\(fa_u\) 为节点 \(u\) 的父亲节点,\(\operatorname{deg}_u\) 表示节点 \(u\) 的度数,\(\operatorname{son}_u\) 表示节点 \(u\) 的子节点集合。

我们记路径 \(s \rightarrow t\) 上的点为 \(k_0, k_1, k_2, \cdots, k_m\),其中 \(k_0 = s, k_m = t\)。删去路径上的所有边后,每个 \(k_i\) 都会成为一棵以自己为根的有根树的根,记这棵树为 \(\operatorname{subtree}_{k_i}\)。通过观察可以发现,对于这类子树中不在路径上的叶子节点 \(v\)(其唯一的相邻节点是 \(fa_v\)),有

\[f_v = \dfrac{1}{\operatorname{deg}_{fa_v}} f_{fa_v} \]

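考虑推广这一结论:对于子树中任意不在路径上的节点 \(u\),有 \(f_u = \dfrac{\operatorname{deg}_u}{\operatorname{deg}_{fa_u}} f_{fa_u}\)。叶子节点满足 \(\operatorname{deg}_v = 1\),正是上式的特例。下面自底向上使用数学归纳法证明:假设结论对 \(u\) 的所有子节点 \(v\) 已经成立,则有如下推导,其中最后一步为对 \(f_u\) 移项求解。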

\[\begin{aligned} f_u &= \sum\limits_{\left(u, v\right) \in E} \dfrac{1}{\operatorname{deg}_v}f_v \\ &= \sum\limits_{v \in \operatorname{son}_u} \dfrac{1}{\operatorname{deg}_v}f_v + \dfrac{1}{\operatorname{deg}_{fa_u}}f_{fa_u} \\ &= \sum\limits_{v \in \operatorname{son}_u} \dfrac{1}{\operatorname{deg}_v} \dfrac{\operatorname{deg}_v}{\operatorname{deg}_u} f_u + \dfrac{1}{\operatorname{deg}_{fa_u}} f_{fa_u} \\ &= \sum\limits_{v \in \operatorname{son}_u} \dfrac{1}{\operatorname{deg}_u} f_u + \dfrac{1}{\operatorname{deg}_{fa_u}} f_{fa_u} \\ &= \dfrac{\operatorname{deg}_u - 1}{\operatorname{deg}_u} f_u + \dfrac{1}{\operatorname{deg}_{fa_u}} f_{fa_u} \\ &= \dfrac{\operatorname{deg}_u}{\operatorname{deg}_{fa_u}} f_{fa_u} \end{aligned}\]

推广该结论,设 \(v \in \operatorname{son}_u, pa = fa_u\),逐层代入可得

\[f_v = \dfrac{\operatorname{deg}_v}{\operatorname{deg}_u} f_u = \dfrac{\operatorname{deg}_v}{\operatorname{deg}_u} \dfrac{\operatorname{deg}_u}{\operatorname{deg}_{pa}} f_{pa} = \dfrac{\operatorname{deg}_v}{\operatorname{deg}_{pa}} f_{pa} \]

即中间节点的度数会相互约去。一直代入到子树的根 \(k_i\),可知对于 \(\forall v \in \operatorname{subtree}_{k_i}\),有

\[f_v = \dfrac{\operatorname{deg}_v}{\operatorname{deg}_{k_i}} f_{k_i} \]


现在考虑路径 \(s \rightarrow t\) 上的点

\[\begin{aligned} f_{k_0} &= 1 + \sum\limits_{\left(k_0, v\right) \in E} \dfrac{1}{\operatorname{deg}_v}f_v \\ &= 1 + \sum\limits_{\left(k_0, v\right) \in E \land v \neq k_1}\dfrac{1}{\operatorname{deg}_v}f_v + \dfrac{1}{\operatorname{deg}_{k_1}} f_{k_1} \\ &= 1 + \sum\limits_{\left(k_0, v\right) \in E \land v \neq k_1}\dfrac{1}{\operatorname{deg}_v}\dfrac{\operatorname{deg}_v}{\operatorname{deg}_{k_0}} f_{k_0} + \dfrac{1}{\operatorname{deg}_{k_1}} f_{k_1} \\ &= 1 + \dfrac{\operatorname{deg}_{k_0} - 1}{\operatorname{deg}_{k_0}} f_{k_0} + \dfrac{1}{\operatorname{deg}_{k_1}} f_{k_1} \\ &= \operatorname{deg}_{k_0} \left(1 + \dfrac{1}{\operatorname{deg}_{k_1}} f_{k_1}\right) \end{aligned}\]

\[\begin{aligned} f_{k_1} &= \sum\limits_{\left(k_1, v\right) \in E} \dfrac{1}{\operatorname{deg}_v} f_v \\ &= \sum\limits_{\left(k_1, v\right) \in E \land v \neq k_0 \land v \neq k_2} \dfrac{1}{\operatorname{deg}_v}f_v + \dfrac{1}{\operatorname{deg}_{k_0}}f_{k_0} + \dfrac{1}{\operatorname{deg}_{k_2}} f_{k_2} \\ &= \sum\limits_{\left(k_1, v\right) \in E \land v \neq k_0 \land v \neq k_2} \dfrac{1}{\operatorname{deg}_v}\dfrac{\operatorname{deg}_v}{\operatorname{deg}_{k_1}}f_{k_1} + \dfrac{1}{\operatorname{deg}_{k_0}}f_{k_0} + \dfrac{1}{\operatorname{deg}_{k_2}} f_{k_2} \\ &= \dfrac{\operatorname{deg}_{k_1} - 2}{\operatorname{deg}_{k_1}} f_{k_1} + \left(1 + \dfrac{1}{\operatorname{deg}_{k_1}} f_{k_1}\right) + \dfrac{1}{\operatorname{deg}_{k_2}} f_{k_2} \\ &= 1 + \dfrac{\operatorname{deg}_{k_1} - 1}{\operatorname{deg}_{k_1}} f_{k_1} + \dfrac{1}{\operatorname{deg}_{k_2}} f_{k_2} \\ &= \operatorname{deg}_{k_1} \left(1 + \dfrac{1}{\operatorname{deg}_{k_2}} f_{k_2}\right) \end{aligned}\]

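同理,对于 \(0 \le i \le m - 2\),都有 \(f_{k_i} = \operatorname{deg}_{k_i} \left(1 + \dfrac{1}{\operatorname{deg}_{k_{i + 1}}} f_{k_{i + 1}}\right)\)。对于 \(k_{m - 1}\),情况略有不同。由于到达 \(t\) 后游走立即停止,\(t\) 不会再转移到任何相邻节点,因此 \(k_{m - 1}\) 的方程中来自 \(k_m = t\) 的贡献为 \(0\),而不是 \(\dfrac{1}{\operatorname{deg}_{k_m}} f_{k_m}\)。于是

\[f_{k_{m - 1}} = \operatorname{deg}_{k_{m - 1}} \left(1 + 0\right) = \operatorname{deg}_{k_{m - 1}} \]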

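接下来从 \(k_{m - 1}\) 开始,沿路径向 \(s\) 的方向逐个代入递推式 \(f_{k_i} = \operatorname{deg}_{k_i} \left(1 + \dfrac{1}{\operatorname{deg}_{k_{i + 1}}} f_{k_{i + 1}}\right)\):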

\[f_{k_{m - 2}} = \operatorname{deg}_{k_{m - 2}} \left(1 + \dfrac{1}{\operatorname{deg}_{k_{m - 1}}} f_{k_{m - 1}}\right) = \operatorname{deg}_{k_{m - 2}} \left(1 + 1\right) = 2 \cdot \operatorname{deg}_{k_{m - 2}} \]

\[f_{k_{m - 3}} = \operatorname{deg}_{k_{m - 3}} \left(1 + \dfrac{1}{\operatorname{deg}_{k_{m - 2}}} f_{k_{m - 2}}\right) = \operatorname{deg}_{k_{m - 3}} \left(1 + 2\right) = 3 \cdot \operatorname{deg}_{k_{m - 3}} \]

以此类推,由归纳可得

\[f_{k_{m - i}} = i \cdot \operatorname{deg}_{k_{m - i}} \quad (1 \le i \le m) \]


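综合以上的结论可以发现,对于路径上的各个点 \(k_{m - i}\ (1 \le i \le m)\),有

\[\forall v \in \operatorname{subtree}_{k_{m - i}}, f_v = \dfrac{\operatorname{deg}_v}{\operatorname{deg}_{k_{m - i}}} \cdot i \cdot \operatorname{deg}_{k_{m - i}} = i \cdot \operatorname{deg}_v \]

而对于 \(t = k_m\) 本身,游走在第一次到达时即停止,故 \(f_t = 1\)。\(\operatorname{subtree}_t\) 中的其他节点永远不会被经过,期望经过次数为 \(0\)。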

因此只需求出路径上每个点到 \(t\) 的距离 \(i\),再遍历挂在路径上的各棵子树,计算 \(i \cdot \operatorname{deg}_v\),即可在 \(\mathcal{O}(n)\) 的时间内解决本题。

Code

//Codeforces - 1823F
#include <bits/stdc++.h>

// Integer type used for every value and vertex index in this solution.
using valueType = long long;
using ValueVector = std::vector<valueType>;
using ValueMatrix = std::vector<ValueVector>;
using bitset = std::vector<bool>;

// Modulus (the prime 998244353) used as the default by the modular helpers.
constexpr valueType MOD = 998244353;

// When true, the modular helpers first normalise their operands into [0, mod);
// when false they assume the operands are already in that range.
bool ModOperSafeModOption = false;

/// In-place modular addition: a <- (a + b) mod `mod`.
/// The fast path assumes a and b already lie in [0, mod); when
/// ModOperSafeModOption is set, both are normalised into that range first.
template<typename T1, typename T2, typename T3 = valueType>
void Inc(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        if (a < 0)
            a += mod;

        b %= mod;
        if (b < 0)
            b += mod;
    }

    a += b;

    // For reduced inputs the sum is below 2 * mod, so one subtraction suffices.
    if (a >= mod)
        a -= mod;
}

/// In-place modular subtraction: a <- (a - b) mod `mod`, kept in [0, mod).
/// The fast path assumes a and b already lie in [0, mod); when
/// ModOperSafeModOption is set, both are normalised into that range first.
template<typename T1, typename T2, typename T3 = valueType>
void Dec(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        if (a < 0)
            a += mod;

        b %= mod;
        if (b < 0)
            b += mod;
    }

    a -= b;

    // For reduced inputs the difference is above -mod, so one addition suffices.
    if (a < 0)
        a += mod;
}

/// Returns (a + b) mod `mod`, converted to T1.
/// Assumes a and b are already in [0, mod) unless ModOperSafeModOption is set,
/// in which case they are normalised first.
template<typename T1, typename T2, typename T3 = valueType>
T1 sum(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        if (a < 0)
            a += mod;

        b %= mod;
        if (b < 0)
            b += mod;
    }

    const auto total = a + b;

    return total >= mod ? total - mod : total;
}

/// Returns (a - b) mod `mod` in [0, mod), converted to T1.
/// Assumes a and b are already in [0, mod) unless ModOperSafeModOption is set,
/// in which case they are normalised first.
template<typename T1, typename T2, typename T3 = valueType>
T1 sub(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        if (a < 0)
            a += mod;

        b %= mod;
        if (b < 0)
            b += mod;
    }

    const auto difference = a - b;

    return difference < 0 ? difference + mod : difference;
}

/// Returns (a * b) mod `mod`, converted to T1.
///
/// The product is reduced by the caller-supplied `mod`, defaulting to MOD, just
/// like Mul, Inc, Dec, sum and sub. The multiplication is done in long long, so
/// it cannot overflow as long as both operands are already reduced modulo a
/// modulus below about 3e9; with ModOperSafeModOption set they are normalised
/// into [0, mod) first.
template<typename T1, typename T2, typename T3 = valueType>
T1 mul(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    // Reduce by the `mod` argument, not the global MOD, so a custom modulus is honoured.
    return static_cast<long long>(a) * b % mod;
}

/// In-place modular multiplication: a <- (a * b) mod `mod`.
/// Assumes a and b are already in [0, mod) unless ModOperSafeModOption is set,
/// in which case they are normalised first.
template<typename T1, typename T2, typename T3 = valueType>
void Mul(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        if (a < 0)
            a += mod;

        b %= mod;
        if (b < 0)
            b += mod;
    }

    // Widen to long long before multiplying so two reduced operands cannot overflow.
    const auto product = static_cast<long long>(a) * b;
    a = product % mod;
}

/// Binary exponentiation: returns a^b mod `mod`. A non-positive b yields 1.
///
/// With ModOperSafeModOption set, a is reduced into [0, mod) and the exponent is
/// reduced modulo (mod - 1). That exponent reduction relies on Fermat's little
/// theorem, so it is only valid when `mod` is prime and a is not a multiple of it.
template<typename T1, typename T2, typename T3 = valueType>
T1 pow(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        const auto exponentMod = mod - 1;

        a %= mod;
        if (a < 0)
            a += mod;

        b %= exponentMod;
        if (b < 0)
            b += exponentMod;
    }

    T1 result = 1;

    // Square-and-multiply over the bits of b, least significant bit first.
    for (; b > 0; b >>= 1) {
        if (b & 1)
            Mul(result, a, mod);

        Mul(a, a, mod);
    }

    return result;
}

// Vertex count, start vertex of the walk, and the vertex where the walk stops.
valueType N;
valueType S;
valueType T;
// ans[v]: printed answer for v. distance[v]: edges from v to T for vertices on the S-T path, -1 otherwise.
ValueVector ans;
ValueVector distance;
// Adjacency lists of the tree, indexed by vertex.
ValueMatrix G;

/**
 * Marks the tree path from x to the target vertex T.
 *
 * Searches the part of the tree reachable from x without stepping onto `from`.
 * main passes 0, which is not a vertex. If T is reached, every vertex v on the
 * x -> T path gets distance[v] = number of edges from v to T, so distance[T] = 0.
 * All other entries of distance are left untouched.
 *
 * The search uses an explicit stack instead of recursion. On a path-shaped tree
 * with n up to 2e5 vertices, a recursive walk would be that many frames deep and
 * can overflow a default-sized call stack.
 */
void dfs(valueType x, valueType from) {
    // parent[v]: the vertex from which v was discovered (only valid for discovered v).
    std::vector<valueType> parent(G.size(), 0);
    parent[x] = from;

    std::vector<valueType> pending(1, x);
    bool reached = false;

    while (!pending.empty()) {
        const valueType u = pending.back();
        pending.pop_back();

        // Stop as soon as T is found; its own subtree is never explored.
        if (u == T) {
            reached = true;
            break;
        }

        for (auto const &v : G[u]) {
            if (v == parent[u])
                continue;

            parent[v] = u;
            pending.push_back(v);
        }
    }

    if (!reached)
        return;

    // Walk back from T to x, numbering the path vertices 0, 1, 2, ...
    valueType steps = 0;
    for (valueType v = T;; v = parent[v]) {
        distance[v] = steps++;

        if (v == x)
            break;
    }
}

/**
 * Fills ans[] for one subtree hanging off the S -> T path.
 *
 * Every vertex v reachable from x without stepping onto `from` receives
 * ans[v] = mul(k, deg(v)), where deg(v) = G[v].size(). main calls this with
 * `from` set to the path vertex the subtree hangs from, and k set to that
 * vertex's distance to T.
 *
 * An explicit stack replaces recursion, so a subtree that is up to ~n levels
 * deep cannot overflow the call stack.
 */
void calc(valueType x, valueType from, valueType k) {
    // Pending work: (vertex, neighbour it was reached from).
    std::vector<std::pair<valueType, valueType>> pending;
    pending.emplace_back(x, from);

    while (!pending.empty()) {
        const valueType u = pending.back().first;
        const valueType parent = pending.back().second;
        pending.pop_back();

        ans[u] = mul(k, G[u].size());

        for (auto const &v : G[u])
            if (v != parent)
                pending.emplace_back(v, u);
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    std::cin >> N >> S >> T;

    ans.resize(N + 1, 0);
    distance.resize(N + 1, -1);
    G.resize(N + 1);

    for (valueType i = 1; i < N; ++i) {
        valueType u, v;

        std::cin >> u >> v;

        G[u].push_back(v);
        G[v].push_back(u);
    }

    dfs(S, 0);

    for (valueType i = 1; i <= N; ++i) {
        if (distance[i] != -1) {
            ans[i] = mul(distance[i], G[i].size());

            for (auto const &iter: G[i])
                if (distance[iter] == -1)
                    calc(iter, i, distance[i]);
        }
    }

    ans[T] = 1;

    for (valueType i = 1; i <= N; ++i)
        std::cout << ans[i] << ' ';

    std::cout << std::endl;

    return 0;
}