主席树的区间修改

发布时间 2023-08-18 07:43:52作者: ZnPdCo

因为以前搞的主席树基本都忘了,故写一篇帮助记忆。

前置芝士:

主席树

我发现网上的大部分代码码风和我不同,我希望主席树的打法和线段树差不多,所以我找到了一个和线段树差不多的打法。

首先,主席树如果涉及到区间修改,会稍麻烦一些。为了不占用过多空间,我们常常使用一种叫标记永久化的技术。我们不再向下传递标记,相反,我们在查询时带着标记传递。

但是如果我们按照基础的线段树写出下传的代码,会打出:

ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
	pos = clone(pos);
	t[pos].v += (r - l + 1) * k;
	if(nl <= l && r <= nr) {
		t[pos].mark += k;
		return pos;
	}
	
	ll mid = (l + r) >> 1;
	
	if(nl <= mid)
		t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
	if(mid < nr)
		t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
	
	return pos;
}

这样的话就会出现一个问题,假如我们修改的是下图的橙色区间,然鹅当前的pos为下图的蓝色区间,那么就会多修改下图的红色区间:

也就是说,假如我们要修改 \(\text{[nl, nr]}\),但是当前为 \(\text{[l,r]}\) ,为了保证不出现上图情况,我们可以修改 \(\text{[max(nl,l),min(nr,r)]}\)(感性理解)

那么我们只需要修改得出:

ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
	pos = clone(pos);
	t[pos].v += (min(nr, r) - max(nl, l) + 1) * k; //⭐⭐
	if(nl <= l && r <= nr) {
		t[pos].mark += k;
		return pos;
	}
	
	ll mid = (l + r) >> 1;
	
	if(nl <= mid)
		t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
	if(mid < nr)
		t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
	
	return pos;
}

对的,只修改了赋值部分。

那么查询就是普通的写法:

ll query(ll nl, ll nr, ll l, ll r, ll pos, ll mark) {
	if(nl <= l && r <= nr) {
		return t[pos].v + (r - l + 1) * mark;
	}
	
	ll mid = (l + r) >> 1;
	ll res = 0;
	if(nl <= mid)
		res += query(nl, nr, l, mid, t[pos].ls, mark + t[pos].mark);
	if(mid < nr)
		res += query(nl, nr, mid + 1, r, t[pos].rs, mark + t[pos].mark);
	return res;
}

看,一点也不难写。

最后我贴出例题和代码,大家可以去打一下这道题(洛谷如果交不了可以试试vjudge):

TTM - To the moon

题面翻译

一个长度为 \(N\) 的数组 \(\{A\}\)\(4\) 种操作 :

  • C l r d:区间 \([l,r]\) 中的数都加 \(d\) ,同时当前的时间戳加 \(1\)

  • Q l r:查询当前时间戳区间 \([l,r]\) 中所有数的和 。

  • H l r t:查询时间戳 \(t\) 区间 \([l,r]\) 的和 。

  • B t:将当前时间戳置为 \(t\)

  所有操作均合法 。

ps:刚开始时时间戳为 \(0\)

输入格式,一行 \(N\)\(M\),接下来 \(M\) 行每行一个操作

输出格式:对每个查询输出一行表示答案

数据保证:\(1\le N,M\le 10^5\)\(|A_i|\le 10^9\)\(1\le l \le r \le N\)\(|d|\le10^4\)。在刚开始没有进行操作的情况下时间戳为 \(0\),且保证 B 操作不会访问到未来的时间戳。

由 @bztMinamoto @yzy1 提供翻译

题目描述

输入格式

n m
A1 A2 ... An
... (here following the m operations. )

输出格式

... (for each query, simply print the result. )

样例 #1

样例输入 #1

10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4

样例输出 #1

4
55
9
15

样例 #2

样例输入 #2

2 4
0 0
C 1 1 1
C 2 2 -1
Q 1 2
H 1 2 1

样例输出 #2

0
1

代码:

#include <cstdio>
#include <algorithm>
#define ll long long
#define N 200000
using namespace std;

ll n, m;
ll a[N + 10];
ll rt[N + 10];
ll time = 0;

struct node {
	ll v, ls, rs, mark;
} t[(N << 5) + 10];
ll tot;

ll build(ll l, ll r, ll pos) {
	pos = ++tot;
	if(l == r) {
		t[pos].v = a[l];
		return pos;
	}
	ll mid = (l + r) >> 1;
	t[pos].ls = build(l, mid, pos);
	t[pos].rs = build(mid + 1, r, pos);
	
	t[pos].v = t[t[pos].ls].v + t[t[pos].rs].v;
	return pos;
}

ll clone(ll pos) {
	t[++tot] = t[pos];
	return tot;
}

ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
	pos = clone(pos);
	t[pos].v += (min(nr, r) - max(nl, l) + 1) * k;
	if(nl <= l && r <= nr) {
		t[pos].mark += k;
		return pos;
	}
	
	ll mid = (l + r) >> 1;
	
	if(nl <= mid)
		t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
	if(mid < nr)
		t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
	
	return pos;
}

ll query(ll nl, ll nr, ll l, ll r, ll pos, ll mark) {
	if(nl <= l && r <= nr) {
		return t[pos].v + (r - l + 1) * mark;
	}
	
	ll mid = (l + r) >> 1;
	ll res = 0;
	if(nl <= mid)
		res += query(nl, nr, l, mid, t[pos].ls, mark + t[pos].mark);
	if(mid < nr)
		res += query(nl, nr, mid + 1, r, t[pos].rs, mark + t[pos].mark);
	return res;
}

int main() {
	scanf("%lld %lld", &n, &m);
	
	for(ll i = 1; i <= n; i++) {
		scanf("%lld", &a[i]);
	} 
	
	rt[0] = build(1, n, 0);
	
	for(ll i = 1; i <= m; i++) {
		char op[5];
		ll l, r, d;
		
		scanf("%s", op);
		
		if(op[0] == 'C') {
			scanf("%lld %lld %lld", &l, &r, &d);
			rt[time+1] = update(l, r, 1, n, rt[time], d);
			time++;
		}
		else if(op[0] == 'Q') {
			scanf("%lld %lld", &l, &r);
			printf("%lld\n", query(l, r, 1, n, rt[time], 0));
		}
		else if(op[0] == 'H') {
			scanf("%lld %lld %lld", &l, &r, &d);
			printf("%lld\n", query(l, r, 1, n, rt[d], 0));
		}
		else if(op[0] == 'B') {
			scanf("%lld", &time);
		}
	}
}