NC20477 [ZJOI2008]树的统计COUNT

发布时间 2023-06-23 14:32:19作者: 空白菌

题目链接

题目

题目描述

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。

我们将以下面的形式来要求你对这棵树完成 一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I

II. QSUM u v: 询问从点u到点v的路径上的节点的权值和

注意:从点u到点v的路径上的节点包括u和v本身

输入描述

输入的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有 一条边相连。
接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1 ≤ n ≤ 30000,0 ≤ q ≤ 200000;中途操作中保证每个节点的权值w在-30000到30000之间。

输出描述

对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

示例1

输入

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

输出

4
1
2
2
10
6
5
6
5
16

题解

知识点:树链剖分,线段树。

这是一道树剖的板题,只需要最基本的查询修改即可。

通常树剖是对点维护,按链处理线段树。每次跳到下一条链的起点,维护整段链。

时间复杂度 \(O(n \log n + q\log ^2n)\)

空间复杂度 \(O(n)\)

代码

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

struct HLD {
    vector<int> siz, dep, fat, son, top, dfn, L, R;

    HLD() {}
    HLD(int rt, const vector<vector<int>> &g) { init(rt, g); }

    void init(int rt, const vector<vector<int>> &g) {
        assert(g.size());
        int n = g.size() - 1;
        siz.assign(n + 1, 0);
        dep.assign(n + 1, 0);
        fat.assign(n + 1, 0);
        son.assign(n + 1, 0);
        top.assign(n + 1, 0);
        dfn.assign(n + 1, 0);
        L.assign(n + 1, 0);
        R.assign(n + 1, 0);

        function<void(int, int)> dfsA = [&](int u, int fa) {
            siz[u] = 1;
            dep[u] = dep[fa] + 1;
            fat[u] = fa;
            for (auto &v : g[u]) {
                if (v == fa) continue;
                dfsA(v, u);
                siz[u] += siz[v];
                if (siz[v] > siz[son[u]]) son[u] = v;
            }
        };
        dfsA(rt, 0);

        int dfncnt = 0;
        function<void(int, int)> dfsB = [&](int u, int tp) {
            top[u] = tp;
            dfn[++dfncnt] = u;
            L[u] = dfncnt;
            if (son[u]) dfsB(son[u], tp);
            for (auto v : g[u]) {
                if (v == fat[u] || v == son[u]) continue;
                dfsB(v, v);
            }
            R[u] = dfncnt;
        };
        dfsB(rt, rt);
    }
};

template<class T, class F>
struct SegmentTree {
    int n;
    vector<T> node;

    void update(int rt, int l, int r, int x, F f) {
        if (r < x || x < l) return;
        if (l == r) return node[rt] = f(node[rt]), void();
        int mid = l + r >> 1;
        update(rt << 1, l, mid, x, f);
        update(rt << 1 | 1, mid + 1, r, x, f);
        node[rt] = node[rt << 1] + node[rt << 1 | 1];
    }

    T query(int rt, int l, int r, int x, int y) {
        if (r < x || y < l) return T::e();
        if (x <= l && r <= y) return node[rt];
        int mid = l + r >> 1;
        return query(rt << 1, l, mid, x, y) + query(rt << 1 | 1, mid + 1, r, x, y);
    }

public:
    SegmentTree(int _n = 0) { init(_n); }
    SegmentTree(const vector<T> &src) { init(src); }

    void init(int _n) {
        n = _n;
        node.assign(n << 2, T::e());
    }
    void init(const vector<T> &src) {
        assert(src.size() >= 2);
        init(src.size() - 1);
        function<void(int, int, int)> build = [&](int rt, int l, int r) {
            if (l == r) return node[rt] = src[l], void();
            int mid = l + r >> 1;
            build(rt << 1, l, mid);
            build(rt << 1 | 1, mid + 1, r);
            node[rt] = node[rt << 1] + node[rt << 1 | 1];
        };
        build(1, 1, n);
    }

    void update(int x, F f) { update(1, 1, n, x, f); }

    T query(int x, int y) { return query(1, 1, n, x, y); }
};

struct T {
    int sum;
    int mx;
    static T e() { return { 0,(int)-1e9 }; }
    friend T operator+(const T &a, const T &b) { return { a.sum + b.sum,max(a.mx,b.mx) }; }
};

struct F {
    int upd;
    T operator()(const T &x) { return { upd,upd }; }
};

const int N = 3e4 + 7;
vector<int> g[N];

HLD hld;
SegmentTree<T, F> sgt;

void node_update(int u, int w) {
    sgt.update(hld.L[u], { w });
}

int path_max(int u, int v) {
    auto &top = hld.top;
    auto &dep = hld.dep;
    auto &fat = hld.fat;
    auto &L = hld.L;
    int ans = -1e9;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ans = max(ans, sgt.query(L[top[u]], L[u]).mx);
        u = fat[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    ans = max(ans, sgt.query(L[u], L[v]).mx);
    return ans;
}

int path_sum(int u, int v) {
    auto &top = hld.top;
    auto &dep = hld.dep;
    auto &fat = hld.fat;
    auto &L = hld.L;
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ans += sgt.query(L[top[u]], L[u]).sum;
        u = fat[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    ans += sgt.query(L[u], L[v]).sum;
    return ans;
}


int main() {
    std::ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n;
    cin >> n;
    for (int i = 1;i <= n - 1;i++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    hld.init(1, vector<vector<int>>(g, g + n + 1));
    vector<T> a(n + 1);
    for (int i = 1;i <= n;i++) {
        int x;
        cin >> x;
        a[hld.L[i]] = { x,x };
    }
    sgt.init(a);

    int q;
    cin >> q;
    while (q--) {
        string op;
        cin >> op;
        if (op == "CHANGE") {
            int u, t;
            cin >> u >> t;
            node_update(u, t);
        }
        else if (op == "QMAX") {
            int u, v;
            cin >> u >> v;
            cout << path_max(u, v) << '\n';
        }
        else {
            int u, v;
            cin >> u >> v;
            cout << path_sum(u, v) << '\n';
        }
    }
    return 0;
}