线段树合并 && 分裂

发布时间 2023-10-05 22:21:45作者: Aisaka_Taiga

线段树合并

引入

线段树合并就是把两颗线段树合并起来。

比如:

线段树 \(a\) 维护 \([1,1,2,0,0,2]\)

线段树 \(b\) 维护 \([0,0,2,5,1,2]\)

合并后的线段树 \(c\) 所维护的序列就是 \([1,1,4,5,1,4]\)

解决问题

目前我所见到的线段树合并的题目,一般都是维护区间内众数之类的操作,所以都是搭配权值线段树来使用。

数据范围一般在 \(1e5\)

可能是我的错觉,这种题目主席树也可以切。。。

合并

我们在合并的时候,比如区间求和,假设我们需要将线段树 \(a\) 合并到线段树 \(b\) 上去。

我们在递归合并的过程中有以下两种情况:

  • 当前位置节点两个线段树有一个是空节点。

这种情况最好办了,我们直接返回不为 \(0\) 的那个节点编号即可。

  • 当前两棵线段树都有节点。

因为我们是把 \(b\) 合并到 \(a\) 上,所以我们到了叶子节点就直接将两棵线段树维护的信息合并即可,比如我们一般是用到权值线段树,这个时候直接加和就好。对于非叶子节点,我们在更新完叶子节点的时候用 push_up 函数来更新即可。

inline int merge(int a, int b, int l, int r)
{
    if(!a || !b) return a + b;//返回非空的节点编号
    if(l == r)//如果到了叶子节点
    {
        e[a].sum += e[b].sum;//数量累加
        return a;//返回合并后的点的编号
    }
    int mid = l + r >> 1;
    e[a].l = merge(e[a].l, e[b].l, l, mid);//向左子树内合并
    e[a].r = merge(e[a].r, e[b].r, mid + 1, r);//向右子树内合并
    push_up(a);//更新当前点的信息
    return a;//返回当前点的编号
}

CF600E Lomsat gelral - 洛谷

我们考虑如果是直接做的话不好处理,我们想到在 dfs 的回溯过程中将子树的权值线段树与父节点的权值线段树合并,然后处理信息。

值得注意的是这里用的是权值线段树。

我们线段树节点维护两个东西,当前点代表的区间内颜色出现次数最多的次数,用 \(sum\) 表示,还有就是当前区间对根节点答案的贡献,用 \(ans\) 表示。

我们直接在建完边之后进行 dfs,首先先把当前点的颜色合并到当前点的权值线段树上,然后再回溯的过程中,顺带把当前点所有的子树都合并到当前点上,再合并的过程中用 push_up 函数来保证信息的正确性。

当前点的答案即为现在当前点的权值线段树的根节点的 \(ans\)

/*
 * @Author: Aisaka_Taiga
 * @Date: 2023-10-05 19:23:37
 * @LastEditTime: 2023-10-05 20:24:05
 * @LastEditors: Aisaka_Taiga
 * @FilePath: \Desktop\CF600E.cpp
 * The heart is higher than the sky, and life is thinner than paper.
 */
#include <bits/stdc++.h>

#define int long long
#define ls e[x].l
#define rs e[x].r
#define N 1000100
#define M 100010
#define endl '\n'

using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
    while(c <= '9' && c >= '0') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return x * f;
}

struct tree{int l, r, sum, ans;}e[N << 2];
int rt[M], cl[M], cnt, n, ans[M];
vector<int> g[M];

inline void push_up(int x)
{
    if(e[ls].sum > e[rs].sum) e[x].sum = e[ls].sum, e[x].ans = e[ls].ans;
    if(e[ls].sum < e[rs].sum) e[x].sum = e[rs].sum, e[x].ans = e[rs].ans;
    if(e[ls].sum == e[rs].sum) e[x].sum = e[ls].sum, e[x].ans = e[ls].ans + e[rs].ans;
    return ;
}

inline void add(int &x, int l, int r, int p, int v)
{
    if(!x) x = ++ cnt;
    if(l == r)
    {
        e[x].sum += v;//累加个数
        e[x].ans = l;//对答案的贡献
        return ;
    }
    int mid = l + r >> 1;
    if(p <= mid) add(ls, l, mid, p, v);//向左插入
    else add(rs, mid + 1, r, p, v);//向右儿子插入
    push_up(x);//更新x的信息
    return ;
}

inline int merge(int a, int b, int l, int r)
{
    if(!a || !b) return a + b;//返回非空的节点编号
    if(l == r)//如果到了叶子节点
    {
        e[a].sum += e[b].sum;//数量累加
        e[a].ans = l;//答案贡献为l
        return a;//返回合并后的点的编号
    }
    int mid = l + r >> 1;
    e[a].l = merge(e[a].l, e[b].l, l, mid);//向左子树内合并
    e[a].r = merge(e[a].r, e[b].r, mid + 1, r);//向右子树内合并
    push_up(a);//更新当前点的信息
    return a;//返回当前点的编号
}

inline void dfs(int x, int f)
{
    add(rt[x], 1, 100000, cl[x], 1);//给cl[x]加1的权值
    for(auto v : g[x])
    {
        if(v == f) continue;
        dfs(v, x);//回溯的时候合并线段树,一层一层合并上去
        merge(rt[x], rt[v], 1, 100000);//将v合并到x上去,区间是1,1e5,颜色最大值不超过1e5
    }
    ans[x] = e[rt[x]].ans;//赋答案
    return ;
}

signed main()
{
    n = read();
    for(int i = 1; i <= n; i ++)
        cl[i] = read(), rt[i] = i, cnt ++;
    for(int i = 1; i <= n - 1; i ++)
    {
        int u = read(), v = read();
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; i ++)
        cout << ans[i] << " ";
    return 0;
}

P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并 - 洛谷

自己亲手一遍切的紫/hsh

做完上面那道题目再来做这个就看起来简单多了。

我们首先想到路径修改只有两个方法,树剖和树上差分。

那肯定选简单的树上差分啊/hsh

对于一条路径 \((u,v)\),设 \(lca = LCA(u,v)\),则我们只需要在 \(u,v\) 的权值线段树上给对应救济粮编号加 \(1\) ,在 \(lca,fa(lca)\) 的权值线段树减 \(1\) 就好,把路径拆成两半,把上面的操作对应拆出的两条链就能理解了。

然后就是直接上线段树合并。

/*
 * @Author: Aisaka_Taiga
 * @Date: 2023-10-05 21:59:23
 * @LastEditTime: 2023-10-05 22:09:30
 * @LastEditors: Aisaka_Taiga
 * @FilePath: \Desktop\P4556.cpp
 * The heart is higher than the sky, and life is thinner than paper.
 */
#include <bits/stdc++.h>

#define int long long
#define ls e[u].l
#define rs e[u].r
#define M 1000100
#define N 100010
#define endl '\n'

using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
    while(c <= '9' && c >= '0') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return x * f;
}

int n, a[N], ans[N], rt[N], cnt, m, mcol = 1e5, f[N][21], dep[N];
struct node{int l, r, sum, ans;}e[M << 4];
vector<int> g[N];

inline void push_up(int u)
{
    if(ls == 0){e[u].sum = e[rs].sum, e[u].ans = e[rs].ans; return ;}
    if(rs == 0){e[u].sum = e[ls].sum, e[u].ans = e[ls].ans; return ;}
    if(e[ls].sum < e[rs].sum) e[u].sum = e[rs].sum, e[u].ans = e[rs].ans;
    else e[u].sum = e[ls].sum, e[u].ans = e[ls].ans;
    return ;
}

inline void add(int &u, int l, int r, int p, int v)
{
    if(!u) u = ++ cnt;
    if(l == r)
    {
        e[u].sum += v;
        e[u].ans = l;
        return ;
    }
    int mid = l + r >> 1;
    if(p <= mid) add(ls, l, mid, p, v);
    else add(rs, mid + 1, r, p, v);
    push_up(u);
    return ;
}

inline int merge(int a, int b, int l, int r)
{
    if(!a || !b) return a + b;
    if(l == r) return e[a].sum += e[b].sum, a;
    int mid = l + r >> 1;
    e[a].l = merge(e[a].l, e[b].l, l, mid);
    e[a].r = merge(e[a].r, e[b].r, mid + 1, r);
    push_up(a);
    return a;
}

inline int LCA(int u, int v)
{
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 20; ~i; i --)
        if(dep[f[u][i]] >= dep[v]) u = f[u][i];
    if(u == v) return u;
    for(int i = 20; ~i; i --)
        if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
    return f[u][0];
}

inline void dfs1(int u, int fa)
{
    dep[u] = dep[fa] + 1;
    f[u][0] = fa;
    for(int i = 1; i <= 20; i ++)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for(auto v : g[u])
        if(v != fa) dfs1(v, u);
    return ;
}

inline void dfs2(int u, int fa)
{
    for(auto v : g[u])
    {
        if(v == fa) continue;
        dfs2(v, u);
        merge(rt[u], rt[v], 1, mcol);
    }
    ans[u] = e[rt[u]].ans;
    if(e[rt[u]].sum == 0) ans[u] = 0;
    return ;
}

signed main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i ++) rt[i] = i, cnt ++;
    for(int i = 1; i <= n - 1; i ++)
    {
        int u = read(), v = read();
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1, 0);
    for(int i = 1; i <= m; i ++)
    {
        int u = read(), v = read(), w = read();
        int lca = LCA(u, v);
        add(rt[u], 1, mcol, w, 1);
        add(rt[v], 1, mcol, w, 1);
        add(rt[lca], 1, mcol, w, -1);
        add(rt[f[lca][0]], 1, mcol, w, -1);
    }
    dfs2(1, 0);
    for(int i = 1; i <= n; i ++)
        cout << ans[i] << endl;
    return 0;
}