dfs序线段树

发布时间 2023-08-18 11:02:33作者: ljfyyds

dfs序线段树

1.树上操作

思路

遍历一整棵树,记录节点 \(u\) 所对应的子树的节点数 \(siz_u\) 以及节点 \(u\) 的 \(dfs\) 序 \(dfn_u\)

根据整棵树的 \(dfs\) 序,我们可以把树变成一个序列;由于 \(dfs\) 会先连续访问完一棵子树再离开,节点 \(u\) 的子树恰好对应序列中的连续区间 \([dfn_u,dfn_u+siz_u-1]\)

再维护线段树,\(\text{update}(l,r,x)\) 表示将区间 \([l,r]\) 上的每个值都加上 \(x\).

\(\text{query}(l,r)\) 表示区间 \([l,r]\) 上的元素之和

操作 \(1\) 执行 \(\text{update}(dfn_a,dfn_a+siz_a-1,x)\)

操作 \(2\) 执行 \(\text{query}(dfn_a,dfn_a+siz_a-1)\)(代码中用 \(st_a\) 表示 \(dfn_a\),用 \(ed_a\) 表示 \(dfn_a+siz_a-1\))

#include <bits/stdc++.h>
// Children of segment-tree node p (1-based heap layout). These expand in terms of
// a variable named p, so they are only meaningful inside functions that have a p.
#define ls (p << 1)
#define rs (p << 1 | 1)
using namespace std;
using ll = long long;

// N: capacity of the per-vertex arrays and of the segment-tree node array.
// M: capacity of the adjacency storage (each undirected edge is stored twice).
// NOTE(review): a 1-based segment tree over n leaves can touch indices up to ~4n,
// so with tr sized N the usable n is roughly N / 4 (about 1e6), while the
// per-vertex arrays are sized for far more — confirm against the problem limits.
const int N = 4e6 + 10, M = N * 2;

// Adjacency lists stored as linked edges:
// e[i] = target of edge i, ne[i] = next edge in the same list,
// h[v] = first edge of v (the traversal treats -1 as "no edge"),
// idx = next unused edge slot.
int e[M], ne[M], idx, h[N];
// w[v] = initial weight of vertex v.
int w[N];
// n vertices, m operations, root r, x = amount read for an update operation.
int n, m, r, x;
// dfn[k] = weight of the vertex visited k-th by the DFS (it stores weights, not
// positions); cnt = number of positions assigned so far; vis[v] = 1 once v is seen.
int dfn[N], cnt, vis[N];
// Vertex v's subtree occupies the DFS positions [st[v], ed[v]].
int st[N], ed[N];

// Segment-tree node: s = sum over the node's range, add = pending increment that
// has not yet been pushed to the children.
struct node
{
    ll s, add;
};
node tr[N];
// Insert the directed edge a -> b at the front of a's adjacency list.
// Consumes one edge slot (idx); callers add both directions for an undirected edge.
void add(int a, int b)
{
    const int slot = idx++;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
}

void dfs(int u)
{
    vis[u] = 1;
    dfn[++cnt] = w[u];
    st[u] = cnt;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (!vis[j])
            dfs(j);
    }
    ed[u] = cnt;
}

// Build segment-tree node p over DFS positions [l, r]: leaf k takes dfn[k],
// and every internal node stores the sum of its two children.
void build(int p, int l, int r)
{
    if (l != r)
    {
        const int split = (l + r) >> 1;
        build(ls, l, split);
        build(rs, split + 1, r);
        tr[p].s = tr[ls].s + tr[rs].s;
        return;
    }
    tr[p].s = dfn[l];
}

void pushdown(int p, int l, int r)
{
    int mid = (l + r) >> 1;
    tr[ls].s += tr[p].add * (mid - l + 1);
    tr[ls].add += tr[p].add;
    tr[rs].s += tr[p].add * (r - (mid + 1) + 1);
    tr[rs].add += tr[p].add;
    tr[p].add = 0;
}

// Add x to every position of [ql, qr] inside node p, which covers [l, r].
// A fully covered node absorbs the change in O(1) and records it as a lazy tag.
// A partially covered node first pushes its own tag down, then recurses into the
// overlapping children and recomputes its sum.
void update(int p, int l, int r, int ql, int qr, int x)
{
    if (l >= ql && r <= qr)
    {
        tr[p].add += x;
        tr[p].s += 1LL * x * (r - l + 1);
        return;
    }
    pushdown(p, l, r);
    const int split = (l + r) >> 1;
    if (split >= ql)
        update(ls, l, split, ql, qr, x);
    if (split < qr)
        update(rs, split + 1, r, ql, qr, x);
    tr[p].s = tr[ls].s + tr[rs].s;
}

// Return the sum of positions [ql, qr] restricted to node p, which covers [l, r].
// Pending tags on partially covered nodes are pushed down before descending.
ll query(int p, int l, int r, int ql, int qr)
{
    if (l >= ql && r <= qr)
        return tr[p].s;
    pushdown(p, l, r);
    const int split = (l + r) >> 1;
    const ll left_sum = (split >= ql) ? query(ls, l, split, ql, qr) : 0;
    const ll right_sum = (split < qr) ? query(rs, split + 1, r, ql, qr) : 0;
    return left_sum + right_sum;
}

// Read the tree and the operations from stdin; print one line per query operation.
// Input: n m r, then n vertex weights, then n-1 undirected edges, then m operations:
//   "1 a x" adds x to every vertex in a's subtree;
//   any other "k a" prints the weight sum of a's subtree.
int main()
{
    // Input can be large (one line per operation): decouple iostreams from C stdio
    // and stop cin from flushing cout before every read.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::memset(h, -1, sizeof(h));  // -1 marks an empty adjacency list
    std::cin >> n >> m >> r;
    for (int i = 1; i <= n; i++)
        std::cin >> w[i];
    for (int i = 1; i < n; i++)
    {
        int a, b;
        std::cin >> a >> b;
        add(a, b), add(b, a);
    }
    // Lay the tree out in DFS order, then build the segment tree over positions 1..n.
    dfs(r);
    build(1, 1, n);
    for (int i = 1; i <= m; i++)
    {
        int k, a;
        std::cin >> k >> a;
        if (k == 1)
        {
            std::cin >> x;
            update(1, 1, n, st[a], ed[a], x);
        }
        else
            std::cout << query(1, 1, n, st[a], ed[a]) << '\n';
    }
    return 0;
}