*【学习笔记】(7) 线段树及高级用法

发布时间 2023-10-15 10:11:32作者: Aurora-JC

一.普通线段树

线段树(Segment Tree)几乎是算法竞赛最常用的数据结构了,它主要用于维护区间信息(要求满足结合律)。与树状数组相比,它可以实现 \(O(logn)\) 的区间修改,还可以同时支持多种操作(加、乘),更具通用性。

接下来我们用这道模板题为例,看看线段树是怎么维护区间和这一信息的。

P3372 【模板】线段树 1

题目描述

如题,已知一个数列,你需要进行下面两种操作:

  1. 将某区间每一个数加上 \(k\)
  2. 求出某区间每一个数的和。

输入格式

第一行包含两个整数 \(n, m\),分别表示该数列数字的个数和操作的总个数。
第二行包含 \(n\) 个用空格分隔的整数,其中第 \(i\) 个数字表示数列第 \(i\) 项的初始值。
接下来 \(m\) 行每行包含 \(3\)\(4\) 个整数,表示一个操作,具体如下:

  1. 1 x y k:将区间 \([x, y]\) 内每个数加上 \(k\)
  2. 2 x y:输出区间 \([x, y]\) 内每个数的和。

线段树的建立

线段树是一棵平衡二叉树。母结点代表整个区间的和,越往下区间越小。注意,线段树的每个节点都对应一条线段(区间),但并不保证所有的线段(区间)都是线段树的节点,这两者应当区分开。

如果有一个数组 \([1,2,3,4,5]\) ,那么它对应的线段树大概长这个样子:

每个节点 u 的左右节点的编号分别为 \(2 \times u\)\(2 \times u + 1\), 假如节点 u 存储区间 \([l, r]\) 的和, 设 \(mid = \lfloor \dfrac{l+ r}{2} \rfloor\), 那么两个子节点分别储存 \([l, mid]\)\([mid + 1, r]\) 的和。可以发现,左节点对应的区间长度,与右节点相同或者比之恰好多1。

那这样我们就可以递归建树了

void PushUp(int u){
	node[u].val=node[u<<1].val+node[u<<1|1].val;
}
void Build(int u,int L,int R){
	int m=(L+R)>>1;
	node[u]=(NODE){L,R,m};
	if(L>=R){
		node[u].val=a[L];
		return ;
	}
	Build(u<<1,L,m);
	Build(u<<1|1,m+1,R);
	PushUp(u);
}

区间修改

在讲区间修改前,要先引入一个“懒标记”(或延迟标记)的概念。懒标记是线段树的精髓所在。对于区间修改,朴素的想法是用递归的方式一层层修改(类似于线段树的建立),但这样的时间复杂度比较高。使用懒标记后,对于那些正好是线段树节点的区间,我们不继续递归下去,而是打上一个标记,将来要用到它的子区间的时候,再向下传递。

void Update(int u,int L,int R,int V){
	if(L<=node[u].l&&R>=node[u].r){
		node[u].lazy+=V;
		node[u].val+=V*(node[u].r-node[u].l+1);
		return ;
	}
	PushDown(u);
	int m=(node[u].l+node[u].r)>>1;
	if(L<=m) Update(u<<1,L,R,V); 
	if(R>m) Update(u<<1|1,L,R,V);
	PushUp(u);
}

更新时,我们是从最大的区间开始,递归向下处理。注意到,任何区间都是线段树上某些节点的并集。于是我们记目标区间为 ,当前区间为 , 当前节点为 ,我们会遇到三种情况:

1.当前区间与目标区间没有交集:

这时直接结束递归。

2.当前区间被包括在目标区间里:

这时可以更新当前区间,别忘了乘上区间长度:

node[u].val+=V*(node[u].r-node[u].l+1);

然后打上懒标记

node[u].lazy+=V; // 注意是 +=, 不是 = 

这个标记表示“该区间上每一个点都要加上d”。因为原来可能存在标记,所以是 += 而不是=。

3.当前区间与目标区间相交,但不包含于其中:

这时把当前区间一分为二,分别进行处理。如果存在懒标记,要先把懒标记传递给子节点(注意也是+=,因为原来可能存在懒标记):

void PushDown(int u){
	if(node[u].lazy){
		node[u<<1].val+=node[u].lazy*(node[u<<1].r-node[u<<1].l+1);
		node[u<<1|1].val+=node[u].lazy*(node[u<<1|1].r-node[u<<1|1].l+1);
		node[u<<1].lazy+=node[u].lazy;
		node[u<<1|1].lazy+=node[u].lazy;
		node[u].lazy=0;
	}
}
void Update(int u,int L,int R,int V){
	if(L<=node[u].l&&R>=node[u].r){
		node[u].lazy+=V;
		node[u].val+=V*(node[u].r-node[u].l+1);
		return ;
	}
	PushDown(u);
	int m=(node[u].l+node[u].r)>>1;
	if(L<=m) Update(u<<1,L,R,V); 
	if(R>m) Update(u<<1|1,L,R,V);
	PushUp(u);
}

传递完标记后,再递归地去处理左右两个子节点。

至于单点修改,只需要令左右端点相等即可。

区间查询

类似的

int Query(int u,int L,int R){
	if(L<=node[u].l&&R>=node[u].r) return node[u].val;
	PushDown(u);
	int m=(node[u].l+node[u].r)>>1,ans=0;
	if(L<=m) ans+=Query(u<<1,L,R);
	if(R>m) ans+=Query(u<<1|1,L,R);
	return ans;
}

整体代码

#include<bits/stdc++.h>
#define N 200005
#define int long long
using namespace std;
int n,Q,a[N],opt,L,R,V;
struct NODE{
	int l,r,m,val,lazy;
}node[N<<2];
void PushUp(int u){
	node[u].val=node[u<<1].val+node[u<<1|1].val;
}
void PushDown(int u){
	if(node[u].lazy){
		node[u<<1].val+=node[u].lazy*(node[u<<1].r-node[u<<1].l+1);
		node[u<<1|1].val+=node[u].lazy*(node[u<<1|1].r-node[u<<1|1].l+1);
		node[u<<1].lazy+=node[u].lazy;
		node[u<<1|1].lazy+=node[u].lazy;
		node[u].lazy=0;
	}
}
void Build(int u,int L,int R){
	int m=(L+R)>>1;
	node[u]=(NODE){L,R,m};
	if(L>=R){
		node[u].val=a[L];
		return ;
	}
	Build(u<<1,L,m);
	Build(u<<1|1,m+1,R);
	PushUp(u);
}
void Update(int u,int L,int R,int V){
	if(L<=node[u].l&&R>=node[u].r){
		node[u].lazy+=V;
		node[u].val+=V*(node[u].r-node[u].l+1);
		return ;
	}
	PushDown(u);
	int m=(node[u].l+node[u].r)>>1;
	if(L<=m) Update(u<<1,L,R,V); 
	if(R>m) Update(u<<1|1,L,R,V);
	PushUp(u);
}
int Query(int u,int L,int R){
	if(L<=node[u].l&&R>=node[u].r) return node[u].val;
	PushDown(u);
	int m=(node[u].l+node[u].r)>>1,ans=0;
	if(L<=m) ans+=Query(u<<1,L,R);
	if(R>m) ans+=Query(u<<1|1,L,R);
	return ans;
}
signed main(){
	scanf("%lld%lld",&n,&Q);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	Build(1,1,n);
	while(Q--){
		scanf("%lld",&opt);
		if(opt==1){
			scanf("%lld%lld%lld",&L,&R,&V);
			Update(1,L,R,V);
		}else{
			scanf("%lld%lld",&L,&R);
			printf("%lld\n",Query(1,L,R));
		}
	}
	return 0;
}

二.线段树的高级用法

1.动态开点线段树

当线段树数量过多或维护的下标值域过大时,我们就没有办法再使用左儿子两倍,右儿子两倍加一的编号方法了,因为空间不允许我们开 \(10^9\)的数组。

好办,对于每个节点,维护其左儿子和右儿子的编号,向下传递时若对应儿子为空就新建一个节点,可以通过传递引用值方便地实现。查询时如果走到空节点就直接返回。

下面是一个单点修改,区间求和的动态开点线段树实现。可以发现相比于普通线段树,动态开点线段树仅仅是加了一个引用,且向下递归儿子时不再是 \(u << 1 / u << 1 + 1\) 而是 \(lc[u] / rc[u]\)

void PushUp(int u){
	node[u].val=node[node[u].lc].val+node[node[u].rc].val;
}
int Build(){
	cnt++;
	node[cnt].lc=node[cnt].rc=node[cnt].val=0;
	return cnt;
}
void insert(int u,long long x,long long L,long long R){
	if(L==R){
		node[u].val++;
		return ;
	}
	long long m=(L+R)>>1;
	if(x<=m){
		if(!node[u].lc) node[u].lc=Build();
		insert(node[u].lc,x,L,m);
	} 
	else{
		if(!node[u].rc) node[u].rc=Build();
		insert(node[u].rc,x,m+1,R);
	} 
	PushUp(u);
}
long long Query(int u,long long l,long long r,long long L,long long R){
	if(!u||!node[u].val) return 0;
	if(l<=L&&r>=R){
		return node[u].val;
	}
	long long m=(L+R)>>1,ans=0;
	if(l<=m){
		if(!node[u].lc) node[u].lc=++cnt;
		ans+=Query(node[u].lc,l,r,L,m);
	} 
	if(r>m){
		if(!node[u].rc) node[u].rc=++cnt;
		ans+=Query(node[u].rc,l,r,m+1,R);
	} 
	return ans;
}

2.标记永久化

众所周知,线段树的区间修改需要打懒标记。但对于部分特殊的线段树而言,它不支持懒标记下传。例如可持久化线段树和动态开点线段树,如果下传懒标记,则需要新建节点,使空间常数过大。

此时可以考虑使用标记永久化的技巧。简单地说就是按着懒标记不下传,查询时考虑从根到目标区间的路径上每个节点的懒标记。

例如区间加,既然修改时没有 pushdown,那么查询时就要加上当前节点的 lazy。

#include<bits/stdc++.h>
#define N 200005
#define int long long
using namespace std;
int n,Q,a[N],opt,L,R,V;
struct NODE{
	int l,r,val,lazy;
}node[N<<2];
void PushUp(int u){
	node[u].val=node[u<<1].val+node[u<<1|1].val+(node[u].r-node[u].l+1)*node[u].lazy;
}
void Build(int u,int L,int R){
	int m=(L+R)>>1;
	node[u]=(NODE){L,R};
	if(L>=R){
		node[u].val=a[L];
		return ;
	}
	Build(u<<1,L,m);
	Build(u<<1|1,m+1,R);
	PushUp(u);
}
void Update(int u,int L,int R,int V){
	if(L<=node[u].l&&R>=node[u].r){
		node[u].lazy+=V;
		node[u].val+=(node[u].r-node[u].l+1)*V;
		return ;
	}
	int m=(node[u].l+node[u].r)>>1;
	if(L<=m) Update(u<<1,L,R,V); 
	if(R>m) Update(u<<1|1,L,R,V);
	PushUp(u);
}
int Query(int u,int L,int R){
	if(L<=node[u].l&&R>=node[u].r) return node[u].val;
	int m=(node[u].l+node[u].r)>>1,ans=(min(R,node[u].r)-max(L,node[u].l)+1)*node[u].lazy;
	if(L<=m) ans+=Query(u<<1,L,R);
	if(R>m) ans+=Query(u<<1|1,L,R);
	return ans;
}
signed main(){
	scanf("%lld%lld",&n,&Q);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	Build(1,1,n);
	while(Q--){
		scanf("%lld",&opt);
		if(opt==1){
			scanf("%lld%lld%lld",&L,&R,&V);
			Update(1,L,R,V);
		}else{
			scanf("%lld%lld",&L,&R);
			printf("%lld\n",Query(1,L,R));
		}
	}
	return 0;
}

3.线段树的合并

考虑合并两棵线段树,分别记为 \(T_x\)\(T_y\)
首先从 \(T_x\)\(T_y\) 的根 \(rt_x,rt_y\) 开始递归。设当前区间分别为 x,y。

  • 如果 \(x,y\) 至少有一个为空,则返回另一个。
  • 否则 将 \(x\) 视为新节点,将 \(y\) 的信息合并到 \(x\) 上,递归 \(ls[x]/ls[y] , rs[x]/rs[y]\)
  • 注意,当递归到叶子节点时需要直接合并。大部分情况下叶子节点的合并是容易的,因为这只涉及两个长度为 1
    的区间的信息。但 合并叶子处的信息与合并左右儿子的信息 是两个不同种类的合并操作,需要分来开考虑。例如可以是叶子处相加,左右儿子取 max。
int merge(int x, int y, int l, int r){
	if(!x || !y) return x + y;
	int mid = (l + r) >> 1;
	ls[x] = merge(ls[x], ls[y], l, mid);
	rs[x] = merge(rs[x], rs[y], mid + 1, r);
	pushup(x);
	return x;
}

Ⅰ. P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并

将链修改转化为树上差分,最后线段树合并。注意每合并完一个节点就直接查询该节点的答案,不然需要可持久化,空间开销过大。

点击查看代码
#include<bits/stdc++.h>
#define N 100005
using namespace std;
int Head[N],Next[N<<1],to[N<<1],tot;
int n,m;
int num,B[N],x[N],y[N],z[N];
int d[N],f[N][55],t;
int ans[N],root[N],cnt;
int lc[N<<6],rc[N<<6],sum[N<<6],tree[N<<6];
void add(int u,int v){
	to[++tot]=v,Next[tot]=Head[u],Head[u]=tot;
}
void bfs(int s){
	queue<int> q;
	q.push(s),d[s]=1;
	while(!q.empty()){
		int x=q.front();
		q.pop();
		for(int i=Head[x];i;i=Next[i]){
			int y=to[i];
			if(d[y]) continue;
			d[y]=d[x]+1;
			f[y][0]=x;
			for(int j=1;j<=t;j++) f[y][j]=f[f[y][j-1]][j-1];
			q.push(y);
		}
	}
}
int lca(int x,int y){
	if(d[x]>d[y]) swap(x,y);
	for(int i=t;i>=0;i--){
		if(d[f[y][i]]>=d[x]){
			y=f[y][i];
		}
	}
	if(x==y) return x;
	for(int i=t;i>=0;i--){
		if(f[x][i]!=f[y][i]){
			x=f[x][i],y=f[y][i];
		}
	}
	return f[x][0];
}
void PushUp(int x){
	if(sum[lc[x]]==0){
		sum[x]=sum[rc[x]];
		tree[x]=tree[rc[x]];
		return ;
	}
	if(sum[rc[x]]==0){
		sum[x]=sum[lc[x]];
		tree[x]=tree[lc[x]];
		return ;
	}
	if(sum[lc[x]]>=sum[rc[x]]){
		sum[x]=sum[lc[x]];
		tree[x]=tree[lc[x]];
	}else{
		sum[x]=sum[rc[x]];
		tree[x]=tree[rc[x]];
	}
}
int merge(int x,int y,int l,int r){
	if(x==0||y==0) return x+y;
	if(l==r){
		sum[x]+=sum[y];
		return x;
	}
	int m=(l+r)>>1;
	lc[x]=merge(lc[x],lc[y],l,m);
	rc[x]=merge(rc[x],rc[y],m+1,r);
	PushUp(x);
	return x;
}
void UpDate(int &u,int l,int r,int x,int v){
	if(!u) u=++cnt;
	if(l==r){
		sum[u]+=v;
		tree[u]=x;
		return ;
	}
	int mid=(l+r)>>1;
	if(x<=mid) UpDate(lc[u],l,mid,x,v);
	else UpDate(rc[u],mid+1,r,x,v);
	PushUp(u);
}
void dfs(int x,int fa){
	for(int i=Head[x];i;i=Next[i]){
		int y=to[i];
		if(y==fa) continue;
		dfs(y,x);
		root[x]=merge(root[x],root[y],1,N);
	}
	ans[x]=tree[root[x]];
	if(sum[root[x]]==0) ans[x]=0;
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v),add(v,u);
	}
	t=(int)(log(n)/log(2))+1;
	bfs(1);
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&x[i],&y[i],&z[i]);
		B[i]=z[i];
	}
	sort(B+1,B+1+m);
	num=unique(B+1,B+1+m)-B-1;
	for(int i=1;i<=m;i++){
		z[i]=lower_bound(B+1,B+1+num,z[i])-B;
		int tt=lca(x[i],y[i]);
		UpDate(root[x[i]],1,N,z[i],1);
		UpDate(root[y[i]],1,N,z[i],1);
		UpDate(root[tt],1,N,z[i],-1);
		UpDate(root[f[tt][0]],1,N,z[i],-1);
	}
	dfs(1,0);
	for(int i=1;i<=n;i++) printf("%d\n",B[ans[i]]);
	return 0;
} 

Ⅱ. P3899 [湖南集训] 更为厉害

点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 3e5 + 67;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int n, q, cnt;
int sz[N], dep[N], rt[N], ls[N * 40], rs[N * 40];
ll val[N * 40];
vector<int> e[N];

void modify(int &x, int l, int r, int p, int v){
	if(!x) x = ++cnt;
	val[x] += v;
	if(l == r) return ;
	int mid = (l + r) >> 1;
	if(p <= mid) modify(ls[x], l, mid, p, v);
	else modify(rs[x], mid + 1, r, p, v);
}

ll query(int x, int l, int r, int ql, int qr){
	if(ql <= l && r <= qr) return val[x];
	int mid = (l + r) >> 1; ll ans = 0;
	if(ql <= mid) ans += query(ls[x], l, mid, ql, qr);
	if(qr > mid) ans += query(rs[x], mid + 1, r, ql, qr);
	return ans;
}

int merge(int x, int y){
	if(!x || !y) return x + y;
	int z = ++cnt; 
	ls[z] = merge(ls[x], ls[y]), rs[z] = merge(rs[x], rs[y]);
	val[z] = val[x] + val[y];
	return z;
}

void dfs(int x, int ff){
	sz[x] = 1, dep[x] = dep[ff] + 1;
	for(int y : e[x]) if(y ^ ff) dfs(y, x), sz[x] += sz[y], rt[x] = merge(rt[x], rt[y]);
	modify(rt[x], 1, n, dep[x], sz[x] - 1);
}

bool _v;

int main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
	
	n = read(), q = read();
	for(int i = 1; i < n; ++i){
		int u = read(), v = read();
		e[u].push_back(v), e[v].push_back(u);
	}
	dfs(1, 0);
	while(q--){
		int p = read(), k = read();
		printf("%lld\n", 1ll * min(dep[p] - 1, k) * (sz[p] - 1) + query(rt[p], 1, n, dep[p] + 1, min(dep[p] + k, n)));
	}
	return 0;
} 

Ⅲ. P3521 [POI2011] ROT-Tree Rotations

这道题让我们求出逆序对个数最小值,并且允许我们随意交换一个节点的两棵子树。

考虑一个任意的节点,它的子树先序遍历后的逆序对显然只有三种组成:

  1. 左子树中

  2. 右子树中

  3. 跨越左右子树

对子树的交换显然不会影响第1,2类,因此我们只需要计算出第三类的最小值即可。

至于计算则没有什么难度。由于我们维护的是值域,因此右儿子必定比左儿子大,那么我们就用左儿子大小乘以右儿子大小即可得出交换前逆序对个数。交换后同理之。

这里需要注意,我们能够这样计算是因为无论左右儿子怎么交换,影响的都只有当前部分的逆序对个数,而不会影响深度更浅的节点的值。

点击查看代码
#include<bits/stdc++.h>
#define N 200005*22
#define LL long long
using namespace std;
int n,sum,tot;
int lc[N],rc[N],val[N];
LL ans1,ans2,ans;
int UpDate(int l,int r,int pos){
	int rt=++tot;
	if(l==r){
		val[rt]++;
		return rt;
	}
	int mid=(l+r)>>1;
	if(pos<=mid) lc[rt]=UpDate(l,mid,pos);
	else rc[rt]=UpDate(mid+1,r,pos);
	val[rt]=val[lc[rt]]+val[rc[rt]];
	return rt;
}
int merge(int x,int y,int l,int r){
	if(x==0||y==0) return x|y;
	if(l==r){
		val[x]+=val[y];
		return x;
	}
	ans1+=1ll*val[rc[x]]*val[lc[y]];
	ans2+=1ll*val[lc[x]]*val[rc[y]];
	int mid=(l+r)>>1;
	lc[x]=merge(lc[x],lc[y],l,mid);
	rc[x]=merge(rc[x],rc[y],mid+1,r);
	val[x]+=val[y];
	return x;
}
int dfs(){
	int x=0,v;
	scanf("%d",&v);
	if(v==0){
		int ls=dfs(),rs=dfs();
		ans1=ans2=0;
		x=merge(ls,rs,1,n);
		ans+=min(ans1,ans2);
	}else{
		x=UpDate(1,n,v);
	}
	return x;
}
int main(){
	scanf("%d",&n);
	int rt=dfs();
	printf("%lld\n",ans);
	return 0;
}

Ⅳ.CF1051G Distinctification

点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5 + 67;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int cnt, rt[N], ls[N << 5], rs[N << 5], sz[N << 5];
ll val[N << 5], sum[N << 5];
ll n, ans, fa[N], a[N], b[N], l[N], tot[N];

int find(int x){return fa[x] == x ? x : fa[x] = find(fa[x]);}

void pushup(int u){
	val[u] = val[ls[u]] + val[rs[u]] + sum[ls[u]] * sz[rs[u]];
	sum[u] = sum[ls[u]] + sum[rs[u]], sz[u] = sz[ls[u]] + sz[rs[u]];
}

void modify(int &u, int l, int r, int v){
	if(!u) u = ++cnt;
	if(l == r) return ++sz[u], sum[u] += v, void();
	int mid = (l + r) >> 1;
	if(v <= mid) modify(ls[u], l, mid, v);
	else modify(rs[u], mid + 1, r, v);
	pushup(u);
}

int merge(int x, int y, int l, int r){
	if(!x || !y) return x + y;
	int mid = (l + r) >> 1;
	ls[x] = merge(ls[x], ls[y], l, mid);
	rs[x] = merge(rs[x], rs[y], mid + 1, r);
	pushup(x);
	return x;
}

ll calc(int x){
	ll res = l[x] * sum[rt[x]] - tot[x];
	return res + val[rt[x]];
}

void lian(int x, int y, bool opt){
	if(x > 2e5 || y > 2e5) return ;
	x = find(x), y = find(y);
	if(x == y || (!rt[x] || !rt[y]) && !opt) return ;
	ans -= calc(x), ans -= calc(y);
	fa[y] = x, l[x] = min(l[x], l[y]), tot[x] += tot[y];
	rt[x] = merge(rt[x], rt[y], 1, n);
	ans += calc(x);
}

bool _v;

signed main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
	
	n = read();
	for(int i = 1; i <= 2e5; ++i) fa[i] = l[i] = i;
	for(int i = 1; i <= n; ++i){
		a[i] = read(), b[i] = read();
		int p = find(a[i]);
		ans -= calc(p); modify(rt[p], 1, n, b[i]), tot[p] += a[i] * b[i];
		ans += calc(p);
		if(sz[rt[p]] == 1) lian(p, a[i] - 1, 0), lian(p, a[i] + 1, 0);
		else lian(p, l[p] + sz[rt[p]] - 1, 1), lian(p, l[p] + sz[rt[p]], 0);
		printf("%lld\n", ans);
	}
	return 0;
} 

4.线段树的分裂

和 FHQ Treap 分裂一样,线段树分裂有按值分裂和按排名分裂两种方法。按排名分裂的流程如下。

设当前区间为 \([l,r]\),当前节点为 \(x\),要分裂出的线段树的当前节点为 $ y\((\)y$ 为新建出来的节点),且需要保留 \(T_x\) 维护的较小的 \(k\) 个值。

\(v\) 为$ x$ 的左子树所维护的值的个数,分三种情况讨论:

  • \(k<v\),说明 \(x\) 的右子树要全部给 \(y\),并向左子树分裂。

  • \(k=v\),则将 \(x\) 的右子树全部给 \(y\) 后返回

  • \(k>v\),则向右子树分裂,并令 \(k\) 减去 \(v\)
    对于按值分裂,类似上述过程即可。设需要保留 \(T_x\) 小于等于 \(k\) 的值,且当前区间中点为 \(m\)

    • \(k<m\),说明 \(x\) 的右子树要全部给 \(y\),并向左子树分裂。
  • \(k=m\),则将 \(x\) 的右子树全部给 \(y\) 后返回。

  • \(k>m\) ,则向右子树分裂。

显然,一次分裂会新建 \(logn\) 个节点,因此时空复杂度均为线性对数。

Ⅰ. P5494 【模板】线段树分裂

对于操作 0,将 [1,n] 分裂成 [1,x−1],[x,y] 和 [y+1,n],再将 [1,x−1] 与 [y+1,n] 合并。

对于操作 1,将 p,t 两棵线段树合并。

剩下来都是权值线段树的基本操作。时空复杂度线性对数。

#include<bits/stdc++.h>
#define N 200100
#define LL long long
#define lc ch[rt][0]
#define rc ch[rt][1]
using namespace std;
int n,m,tot=1,cc;
int root[N*2],ch[N<<5][2],cnt,bac[N<<5];
LL val[N<<5];
int newnod(){
	return (cc?bac[cc--]:++cnt);
}
void del(int rt){
	bac[++cc]=rt;
	lc=rc=val[rt]=0;
}
void modify(int &rt,int l,int r,int pos,int v){
	if(!rt) rt=newnod();
	val[rt]+=v;
	if(l==r) return ;
	int mid=(l+r)>>1;
	if(pos<=mid) modify(lc,l,mid,pos,v);
	else modify(rc,mid+1,r,pos,v);
}
LL Query(int rt,int l,int r,int L,int R){
	if(L<=l&&R>=r) return val[rt];
	int mid=(l+r)>>1;
	LL ans=0;
	if(L<=mid) ans+=Query(lc,l,mid,L,R);
	if(R>mid) ans+=Query(rc,mid+1,r,L,R);
	return ans;
}
int QueryK(int rt,int l,int r,int v){
	if(l==r) return l;
	int mid=(l+r)>>1;
	if(val[lc]>=v) return QueryK(lc,l,mid,v); 
	else return QueryK(rc,mid+1,r,v-val[lc]);
}
int merge(int x,int y){
	if(!x||!y) return x+y;
	val[x]+=val[y];
	ch[x][0]=merge(ch[x][0],ch[y][0]);
	ch[x][1]=merge(ch[x][1],ch[y][1]);
	del(y);
	return x;
}
void split(int x,int &y,LL k){
	if(!x) return ;
	y=newnod();
	LL v=val[ch[x][0]];
	if(k>v) split(ch[x][1],ch[y][1],k-v);
	else swap(ch[x][1],ch[y][1]);
	if(k<v) split(ch[x][0],ch[y][0],k);
	val[y]=val[x]-k;
	val[x]=k;
	return ;
}
int main(){
	//freopen("1.in","r",stdin);
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		int x;
		scanf("%d",&x);
		modify(root[1],1,n,i,x);
	}
	while(m--){
		int p,x,y,t,q,k,opt;
		scanf("%d",&opt);
		if(opt==0){
			scanf("%d%d%d",&p,&x,&y);
			LL k1=Query(root[p],1,n,1,y),k2=Query(root[p],1,n,x,y);
			int tmp=0;
			split(root[p],root[++tot],k1-k2);
			split(root[tot],tmp,k2);
			root[p]=merge(root[p],tmp);
		}else if(opt==1){
			scanf("%d%d",&p,&t);
			root[p]=merge(root[p],root[t]);
		}else if(opt==2){
			scanf("%d%d%d",&p,&x,&q);
			modify(root[p],1,n,q,x);
		}else if(opt==3){
			scanf("%d%d%d",&p,&x,&y);
			printf("%lld\n",Query(root[p],1,n,x,y));
		}else if(opt==4){
			scanf("%d%d",&p,&k);        
			if(val[root[p]]<k) printf("-1\n");
			else printf("%d\n",QueryK(root[p],1,n,k));
		}
	}
	return 0;
}

Ⅱ.CF558E A Simple Task

点击查看代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e5 + 67;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int n, q, cnt;
int val[N << 7], ls[N << 7], rs[N << 7];
char ss[N];
struct seq{
	int l, r, rt, type;
	bool operator < (const seq &A) const{return l != A.l ? l < A.l : r < A.r;}
};
set<seq> s;

void modify(int &u, int l, int r, int p){
	val[u = ++cnt] = 1;
	if(l == r) return ;
	int mid = (l + r) >> 1;
	if(p <= mid) modify(ls[u], l, mid, p);
	else modify(rs[u], mid + 1, r, p);
}

int merge(int x, int y){
	if(!x || !y) return x + y;
	ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]);
	val[x] += val[y];
	return x; 
}

void split1(int x, int &y, int k){
	if(!x) return y = 0, void();
	y = ++cnt;
	if(k < val[ls[x]]) swap(rs[x], rs[y]), split1(ls[x], ls[y], k);
	else if(k == val[ls[x]]) swap(rs[x], rs[y]);
	else split1(rs[x], rs[y], k - val[ls[x]]);
	val[y] = val[x] - k, val[x] = k;
}

int query(int x, int l, int r, int k){
	if(l == r) return l;
	int mid = (l + r) >> 1;
	if(k <= val[ls[x]]) return query(ls[x], l, mid, k);
	return query(rs[x], mid + 1, r, k - val[ls[x]]);
}

void split2(int p){
	if(p == 0 || p == n) return ;
	auto it = --s.upper_bound({p, 100001, 0, 0});
	if(it -> r == p) return ;
	int l = it -> l, r = it -> r, lR = it -> rt, rR;
	if(it -> type == 0) split1(lR, rR, p - l + 1);
	else split1(lR, rR, r - p), swap(lR, rR);
	s.insert({l, p, lR, it -> type}), s.insert({p + 1, r, rR, it -> type});
	s.erase(it);
}

bool _v;

signed main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";	
	
	n = read(), q = read(); scanf("%s", ss + 1);
	for(int i = 1, x = 0; i <= n; ++i) modify(x, 1, 26, ss[i] - 'a' + 1), s.insert({i, i, x, 0});
	while(q--){
		int l = read(), r = read(), opt = 1 - read();
		split2(l - 1), split2(r);
		int rt = 0;
		while(1){
			auto it = s.lower_bound({l, 0, 0, 0});
			if(it == s.end() || it -> l > r) break;
			rt = merge(rt, it -> rt), s.erase(it);
		}
		s.insert({l, r, rt, opt});
	}
	for(int i = 1; i <= n; ++i){
		auto it = --s.upper_bound({i, 100001, 0, 0});
		printf("%c", query(it -> rt, 1, 26, !it -> type ? i - it -> l + 1 : it -> r - i + 1) + 'a' - 1);
	}
	return 0;
} 

5. 线段树分治

6.李超线段树

参考资料:算法学习笔记(14): 线段树 线段树的高级用法