CF1254D Tree Queries

发布时间 2023-12-30 21:36:35作者: Hanx16Msgr

Tree Queries

Luogu CF1254D

题面翻译

给定一棵 \(N\) 个节点的树,有 \(Q\) 次操作。

  • \(1\ v\ d\) 给定一个点 \(v\) 和一个权值 \(d\),等概率地选择一个点 \(r\),对每一个点 \(u\),若 \(v\)\(u\)\(r\) 的路径上,则 \(u\) 的权值加上 \(d\)(权值一开始为 \(0\))。
  • \(2\ v\) 查询 \(v\) 的权值期望,对 \(998244353\) 取模。

\(1\leq N,Q \leq 150000\)

Solution

操作很怪异,所以考虑一下每次操作会对所有点产生什么影响。先选中一个点 \(v\) 作为操作点,然后枚举点 \(r\)。容易发现最后产生的影响只与 \(v\) 作为树根时 \(r\)\(v\) 的哪一个子树有关,那么这一次操作如果选择 \(r\) 将会对 \(v\)\(r\) 所在的子树外的所有子树产生 \(\dfrac d n\) 的贡献,也就是说这一颗子树内所有可能的 \(r\) 点将会对其余所有的子树产生 \(\dfrac {ds}{n}\) 的贡献(其中 \(s\) 为当前子树的大小)。

那么可以有两种不同思路的暴力,一种修改慢查询快,一种修改快查询慢:

  • 考虑对每一个节点 \(x\) 维护一个值 \(val_x\) 表示当前节点的答案。那么每次修改就是暴力枚举修改节点的每一个子树,然后暴力的加到所有节点的 \(v_x\) 上。
  • 每次修改的时候暴力枚举所有修改过的点,然后根据修改的这个点在询问点的哪一个子树来计算答案。

注意到我们每一次的修改都与当前节点有的子树个数相关,因此对度数根号分治,设一个阈值 \(S\),所有度数大于 \(S\) 的节点称为关键点。显然关键点的数量是 \(\mathcal O(\dfrac n S)\) 的。

对于修改操作,如果当前点是关键点,那么就只把当前点标记上;否则直接暴力枚举所有子树然后处理贡献。注意到处理贡献的时候可以看作是子树加,所以使用树状数组来维护。复杂度是 \(\mathcal O(S\log n)\)

对于询问操作,先在树状数组中取出非关键点的贡献,然后再暴力枚举所有关键点并计算其贡献。时间复杂度不精细实现是 \(\mathcal O(\dfrac n S\log n)\)

\(S\)\(\sqrt n\) 时最优,总时间复杂度 \(\mathcal O(q\sqrt n\log n)\),跑得挺慢,不过能过。

Code
int N, Q, S;
vector<int> e[_N];
bool crit[_N];
int deg[_N];
vector<int> crl;
int fa[20][_N], dep[_N], siz[_N], dfn[_N], tot;
void dfs(int x, int F) {
    fa[0][x] = F, siz[x] = 1, dfn[x] = ++tot, dep[x] = dep[F] + 1;
    For(i, 1, 18) fa[i][x] = fa[i-1][fa[i-1][x]];
    for (int v: e[x]) if (v ^ F) dfs(v, x), siz[x] += siz[v];
}
int jump(int x, int y) {
    for (int tmp = dep[x] - dep[y] - 1, i = 0; tmp; tmp >>= 1, ++i)
        if (tmp & 1) x = fa[i][x];
    return x;
}
struct Bit {
    mint val[_N];
    inline int lowbit(int x) {return x & -x;}
    void update(int x, mint v) {for (; x <= N; x += lowbit(x)) val[x] += v;}
    inline void update(int l, int r, mint v) {update(l, v), update(r + 1, -v);}
    mint ask(int x) {mint res = 0; for (; x; x -= lowbit(x)) res += val[x]; return res;}
} bit;
inline bool chk(int x, int y) {return dfn[x] >= dfn[y] && dfn[x] < dfn[y] + siz[y];}
mint inv, val[_N];
void _() {
    cin >> N >> Q;
    inv = mint(1) / N;
    For(i, 2, N) {
        int x, y; cin >> x >> y;
        e[x].epb(y), e[y].epb(x);
        ++deg[x], ++deg[y];
    }
    S = sqrt(N);
    For(i, 1, N) if (deg[i] >= S)
        crl.epb(i), crit[i] = 1;
    Debug("crit:", crl);
    dfs(1, 0);
    while (Q--) {
        int opt, x;
        cin >> opt >> x;
        if (opt == 1) {
            mint d; cin >> d;
            if (crit[x]) val[x] += d;
            else {
                for (int v: e[x]) {
                    if (fa[0][x] == v) {
                        mint res = (N - siz[x]) * d * inv;
                        bit.update(dfn[x], dfn[x] + siz[x] - 1, res);
                    } else {
                        mint res = siz[v] * d * inv;
                        bit.update(1, N, res);
                        bit.update(dfn[v], dfn[v] + siz[v] - 1, -res);
                    }
                }
                bit.update(1, N, d * inv);
            }
        } else {
            mint base = bit.ask(dfn[x]);
            for (int c: crl) {
                if (x == c) base += val[c];
                else if (chk(x, c)) {
                    int u = jump(x, c);
                    base += val[c] * (N - siz[u]) * inv;
                } else base += val[c] * siz[c] * inv;
            }
            fmtout("{}\n", base);
        }
    }
}