树剖1(原理&模板&例题)

发布时间 2023-05-02 22:16:16作者: ZZM_248

引入

树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息

具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。

By OI-Wiki

大多数情况下,「树链剖分」都指「重链剖分」,本文就只讲一下「重链剖分」。

前置芝士(重链剖分中的一些定义 \(\&\) 性质)

对于树上的任意一个结点,

  • 重儿子 表示其子结点中子树最大(子树中包含的结点个数最多)的子结点;若有多个子树最大的子结点,任取其一;若无子结点,则无重儿子。(生动形象)

  • 轻儿子 表示其除重儿子外的所有子结点。

  • 重边 表示从此结点到其重儿子的边。

  • 轻边 表示从此结点到其轻儿子的边。

  • 重链 表示由若干条首尾连接的重边构成的链(落单的节点也算作一条重链)。

P1(来自OI-Wiki)

用一下 Wiki 上的图。


如图,可以发现,在以上定义过后,整棵树就可以被剖分成若干条重链。

而且有一些有趣的性质:

  1. 轻儿子必在一重链顶端。

  2. 树上的任意一条路径可以被划分成不超过 \({\cal O}(\log n)\) 条连续的链。

证明先咕着。

应用

树剖主要针对的是如下问题:

P3384 【模板】重链剖分/树链剖分

已知一棵包含 \(N\) 个结点的树(连通且无环),每个结点上包含一个数值,需要支持以下操作:

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有结点的值都加上 \(z\)

  • 2 x y,表示求树从 \(x\)\(y\) 结点最短路径上所有结点的值之和。

  • 3 x z,表示将以 \(x\) 为根节点的子树内所有结点值都加上 \(z\)

  • 4 x 表示求以 \(x\) 为根节点的子树内所有结点值之和。

对于 \(100\%\) 的数据: \(1\le N \leq {10}^5\)\(1\le M \leq {10}^5\)

是不是很像线段树之类的题?

没错,树剖就是这个思想,将树上的操作映射到一段连续的序列上,变成区间操作,再用线段树解决。

代码中的一些定义

  • \({\bf root}\) 表示树的根节点。

对于树上的任意节点 \(u\)

  • \({\bf fa}[u]\) 表示 \(u\) 的父亲编号。
  • \({\bf dep}[u]\) 表示 \(u\) 在树中的深度。
  • \({\bf sz}[u]\) 表示以 \(u\) 为根的子树中的结点个数。
  • \({\bf son}[u]\) 表示 \(u\)儿子。
  • \({\bf top}[u]\) 表示 \(u\) 所在重链的顶点(深度最小的结点)
  • \({\bf id}[u]\) 表示 \(u\) 映射到新序列上的位置。
  • \({\bf nw}[{\bf id}[u]]\) 表示 \(u\) 映射到新序列上后的点权(即 \({\bf nw}[i]\) 表示新序列第 \(i\) 个位置的权值,就是存下新的序列)。

对于本题,需要的代码如下:

const int N = 1e5 + 10, M = N << 1;

int n, m, root, mod;
int h[N], ne[M], e[M], w[M], idx;
int id[N], nw[N], cnt;
int dep[N], top[N], sz[N], fa[N], son[N];

树剖(重链剖分)代码实现

首先,常用两个 \(dfs\) 解决。

之后,就是线段树模板。

最后,根据题目需要实现一些将询问转化为新序列的区间操作的函数。

\(\large \text{Part 1}\)

第一个 \(dfs\) 标记所有重儿子,顺便处理深度、父亲等。

\(\large \color{gray} \cal{Code}\) (详解看注释

void dfs1(int u, int father, int depth){
	dep[u] = depth, fa[u] = father, sz[u] = 1;
	// 维护当前节点的dep, fa, sz
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == father) continue;
		// 跳过返祖的边
		dfs1(j, u, depth + 1);
		// 遍历以子结点为根的子树,记得传对参数
		sz[u] += sz[j];
		// 因为此子结点j以下的节点都遍历完了,可以直接更新当前节点u的sz
		if(sz[son[u]] < sz[j]) son[u] = j;
		// 如果当前儿子的sz大于了之前的重儿子,更新重儿子
	}
}
// ...
dfs(root, -1, 1) // 主函数调用,注意从根节点开始

第二个 \(dfs\) 处理出重链,并将树上的点映射到新的序列中。

处理重链实际上就是标记每个结点 \(u\) 所在重链的顶点 \({\bf top}[u]\)(因为后面只会用到这个)。

映射的部分有些需要注意的地方(重点),具体看代码:

\(\large \color{gray} \cal{Code}\) (详解看注释

	// u是当前结点,t是当前结点所在重链的顶点
	id[u] = ++cnt, nw[cnt] = w[u], top[u] = t;
	// 在序列上新建一个结点,下标为cnt
	// 将点权映射至新序列,标记top[u]
	if(!son[u]) return;
	// 如果是叶节点,即没有重儿子,跳过
	dfs2(son[u], t);
	// 这里要注意,为了保证重链在映射的新序列中是连续的,向下遍历时一定要优先遍历重儿子
	// 重儿子必定不在重链的顶端,重儿子所在重链的顶端和其父节点是一样的

	// 随后处理轻儿子
    for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == son[u] || j == fa[u]) continue;
		// 重儿子前面遍历过了,跳过;返租边照样跳过
		dfs2(j, j);
		// 轻儿子肯定在一条重链的顶端
	}
}

\(\large \text{Part 2}\)

线段树板子,只需基础的区间加,求区间和即可。

注意维护的是映射的新序列, \(\rm build\) 时要用新序列 \(\bf nw\),代码就不展示了,最后的整合代码里有。

还不会线段树看这篇博客 \(\large \Longrightarrow\) Link

\(\large \text{Part 3}\)

实现此题的四种操作。

  • 1 x y z,表示将树从 \(x\)\(y\) 结点最短路径上所有结点的值都加上 \(z\)

因为树上的任意一条路径可以被划分成不超过 \({\cal O}(\log n)\) 条连续的链,而每一条重链映射到新序列上都是一段连续的区间,于是就转化成了 \({\cal O}(\log n)\) 个区间操作,具体过程还是有许多的细节,见下:

\(\large \color{gray} \cal{Code}\) (详解看注释

void update_path(int u, int v, int k){
	// 表示将树从u到v的最短路径上所有结点的值都加上k
	// 这里将u和v一直往上跳,每次跳到所在重链的顶端再上一个结点,直到跳到两个点的LCA位为止
	while(top[u] != top[v]){ // 这里注意,两点跳到同一条重链上时就结束,后面单独分析
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		// 注意一定是要比较top[u]和top[v]的深度,否则有可能跳到LCA上面,画个图看看就知道了
		update(1, id[top[u]], id[u], k);
		// 每跳过一条重链,就区间修改一次
		// 注意映射时深度越浅在越前面,左右别搞反了
		u = fa[top[u]];
		// 注意要跳到所在重链的顶点的父节点,不然一直在原地跳
	}
	// 最后在同一条重链上时单独处理
	if(dep[u] < dep[v]) swap(u, v);
	update(1, id[v], id[u], k);
	// 还是注意顺序
}
  • 3 x z,表示将以 \(x\) 为根节点的子树内所有结点值都加上 \(z\)

这个其实很简单,因为映射新序列时是按照的搜索序,一个子树中所有结点必定是连续的。

代码很简单:

\(\large \color{gray} \cal{Code}\)

void update_tree(int u, int k){
	update(1, id[u], id[u] + sz[u] - 1, k);
	// 右边界 + sz[u] - 1 即可
}

另外两个查询的操作类似,这里也就不赘述了。

完整代码

最后是完整的巨长 \(\color{red}{\cal 210}\)\(\color{gray}{\cal Code}\)

可能是我的一些 Max/Min/read 模板太臃肿了。

不开 \(\rm long \ long\) 见祖宗哦

#include <map>
#include <queue>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

namespace oi{
	template <class T>
	inline T Abs(T x){return x > 0 ? x : -x;}

	template <class T>
	inline T Max(T a, T b){return a > b ? a : b;}
	template <class T, class... TT>
	inline T Max(T a, TT... b){
		T res = Max(b...);
		return Max(a, res);
	}

	template <class T>
	inline T Min(T a, T b){return a < b ? a : b;}
	template <class T, class... TT>
	inline T Min(T a, TT... b){
		T res = Min(b...);
		return Min(a, res);
	}

	template <class T>
	inline void read(T &x){
		x = 0;
		char ch = getchar();
		bool flag = 0;
		while(ch < '0' || ch > '9') flag |= ch == '-', ch = getchar();
		while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
		flag ? x = -x : 0;
	}
	template <class T, class... TT>
	inline void read(T &x, TT &...xx){
		read(x), read(xx...);
	}
}

using namespace std;
using namespace oi;

typedef long long ll;
typedef unsigned long long ull;

const int N = 1e5 + 10, M = N << 1;
const int INF = 0x3f3f3f3f;

int n, m, root, mod;
int h[N], ne[M], e[M], w[M], idx;
int id[N], nw[N], cnt;
int dep[N], top[N], sz[N], fa[N], son[N];

struct Segment_Tree{
	int l, r;
	ll add, sum;
}tr[N << 2];

void init(){
	memset(h, -1, sizeof h);
}

inline void add(int a, int b){
	e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs1(int u, int father, int depth){
	dep[u] = depth, fa[u] = father, sz[u] = 1;
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == father) continue;
		dfs1(j, u, depth + 1);
		sz[u] += sz[j];
		if(sz[son[u]] < sz[j]) son[u] = j;
	}
}

void dfs2(int u, int t){
	id[u] = ++cnt, nw[cnt] = w[u], top[u] = t;
	if(!son[u]) return;
	dfs2(son[u], t);
	for(int i = h[u]; ~i; i = ne[i]){
		int j = e[i];
		if(j == son[u] || j == fa[u]) continue;
		dfs2(j, j);
	}
}

inline void pushup(int u){
	tr[u].sum = tr[u << 1].sum  + tr[u << 1 | 1].sum;
}

inline void pushdown(int u){
	if(tr[u].add){
		Segment_Tree &root = tr[u], &le = tr[u << 1], &ri = tr[u << 1 | 1];
		le.add += root.add;
		ri.add += root.add;
		le.sum += root.add * (le.r - le.l + 1);
		ri.sum += root.add * (ri.r - ri.l + 1);
		root.add = 0;
	}
}

void build(int u, int l, int r){
	tr[u] = {l, r, 0, nw[l]};
	if(l == r) return;
	
	int mid = (l + r) >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}

void update(int u, int l, int r, int c){
	if(l <= tr[u].l && tr[u].r <= r){
		tr[u].add += c;
		tr[u].sum += c * (tr[u].r - tr[u].l + 1);
		return;
	}
	
	pushdown(u);
	int mid = (tr[u].l + tr[u].r) >> 1;
	if(l <= mid) update(u << 1, l, r, c);
	if(r > mid) update(u << 1 | 1, l, r, c);
	pushup(u);
}

ll query(int u, int l, int r){
	if(l <= tr[u].l && tr[u].r <= r){
		return tr[u].sum;
	}
	
	pushdown(u);
	int mid = (tr[u].l + tr[u].r) >> 1;
	ll res = 0;
	if(l <= mid) res += query(u << 1, l, r);
	if(r > mid) res += query(u << 1 | 1, l, r);
	return res;
}

void update_path(int u, int v, int k){
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		update(1, id[top[u]], id[u], k);
		u = fa[top[u]];
	}
	if(dep[u] < dep[v]) swap(u, v);
	update(1, id[v], id[u], k);
}

void update_tree(int u, int k){
	update(1, id[u], id[u] + sz[u] - 1, k);
}

ll query_path(int u, int v){
	ll res = 0;
	while(top[u] != top[v]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		res += query(1, id[top[u]], id[u]);
		u = fa[top[u]];
	}
	if(dep[u] < dep[v]) swap(u, v);
	res += query(1, id[v], id[u]);
	return res;
}

ll query_tree(int u){
	return query(1, id[u], id[u] + sz[u] - 1);
}

int main(){
	init();
	read(n, m, root, mod);
	for(int i = 1; i <= n; ++i) read(w[i]);
	int a, b;
	for(int i = 1; i < n; ++i){
		read(a, b);
		add(a, b), add(b, a);
	}
	
	dfs1(root, -1, 1);
	dfs2(root, root);
	build(1, 1, n);
	
	int op, u, v, k;
	while(m--){
		read(op, u);
		if(op == 1){
			read(v, k);
			update_path(u, v, k);
		}
		else if(op == 3){
			read(k);
			update_tree(u, k);
		}
		else if(op == 2){
			read(v);
			printf("%lld\n", query_path(u, v) % mod);
		}
		else printf("%lld\n", query_tree(u) % mod);
	}
	
	return 0;
}

进阶

/#TODO#/