Splay 伸展树扩展应用

发布时间 2023-12-07 21:02:39作者: Imcaigou

Update

2023.5.27 好吧,lxl好像已经发明过这种数据结构了(悲)。

前言

谈谈扩展Splay。(下文用KzSplay代替)

前置知识:

1.Splay,以及文艺平衡树。

2.线段树。

问题引入

请你设计一种数据结构,支持 在线 处理以下操作:

给定一个长度为 \(n\) 的序列 \(a\)

1.支持序列的区间翻转。

2.支持任意区间的交换。

(如 "aabbc" 交换 "aa" 与 "c" 变为 "cbbaa")

3.支持区间的求和。

4.支持区间整体加减。

约束条件:

\(1 \leq n \leq 10^5,1 \leq a_i \leq 10^9\)

分析

首先,看到操作1(区间翻转)就可以知道,这种数据结构需要像文艺平衡树那样,通过交换平衡树节点的左右子树(配合翻转懒标记)来实现。

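再看操作2,实际上可以发现这个操作可以分解成四次操作1:设要交换的两个区间为 \([l_1, r_1]\) 和 \([l_2, r_2]\)(其中 \(r_1 < l_2\)),先翻转整个 \([l_1, r_2]\),此时三段(第二个区间、中间部分、第一个区间)的先后顺序已经对调,但每一段内部都是倒序的;再把这三段分别翻转一次,即可恢复各段内部的顺序。例如 "abcde" 交换 "ab" 与 "e":整体翻转得到 "edcba",再分别翻转 "e"、"dc"、"ba",得到 "ecdab"。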

操作3和操作4模仿线段树懒标记下传,在Splay旋转前先下传并摆脱懒标记,具体见代码。

Code

#include <bits/stdc++.h>
using namespace std;
// Scratch variables filled by scanf in main: element count, query count,
// operation code and the operands of the individual operations.
int n, Q, opts, l, r, l1, r1, l2, r2, dt;
// Prefix sum left behind by Kth() as a side effect. With up to 1e5 elements of
// up to 1e9 each, a prefix sum can reach about 1e14, so it must be 64-bit.
long long Kth_Res;
// Capacity of the value array a[] and of the node pool tr[].
const int SIZE_OF_N = 100000 + 5;
const int SIZE_OF_KzSPLAY = 100000 + 5;
// root: index of the current tree root (0 = empty tree).
// tot:  number of nodes handed out so far by Build().
// a[i]: current value of the element whose node carries id == i.
int root, tot, a[SIZE_OF_N];
// One node of the splay tree. Index 0 of tr[] is used as the shared "null"
// node; real nodes start at index 1.
struct KzSplay {
	int size;       // number of nodes in the subtree rooted here
	int fa;         // index of the parent node, 0 if none
	long long val;  // sum of a[] over the subtree; 64-bit because sums exceed int
	int id;         // index into a[] holding this node's own value
	int son[2];     // son[0] = left child, son[1] = right child, 0 if absent
	int lazy;       // 1 when this subtree still has to be reversed
	int tag;        // pending increment that the descendants have not received yet
	// Every field starts at zero; the original constructor left `tag` unset.
	KzSplay () : size (0), fa (0), val (0), id (0), lazy (0), tag (0){
		son[0] = son[1] = 0;
	}
}tr[SIZE_OF_KzSPLAY];
void make0 (){
	tr[0].fa = tr[0].id = tr[0].lazy = tr[0].size = tr[0].son[0] = tr[0].son[1] = tr[0].tag = tr[0].val = 0;
}
void push_down (int x){ // Before doing 'push_down', the point must be already solved.
	make0 ();
	if (x == 0)
		return ;
	if (tr[x].lazy == 1){
		tr[tr[x].son[0]].lazy ^= 1;
		tr[tr[x].son[1]].lazy ^= 1;
		swap (tr[x].son[0], tr[x].son[1]);
		tr[x].lazy = 0;
	}
	if (tr[x].tag != 0){
		tr[tr[x].son[0]].tag += tr[x].tag;
		tr[tr[x].son[1]].tag += tr[x].tag;
		a[tr[tr[x].son[0]].id] += tr[x].tag;
		a[tr[tr[x].son[1]].id] += tr[x].tag;
		tr[tr[x].son[0]].val += tr[x].tag * tr[tr[x].son[0]].size;
		tr[tr[x].son[1]].val += tr[x].tag * tr[tr[x].son[1]].size;
		tr[x].tag = 0;
	}
}
void maintain (int x){
	make0 ();
	push_down (x);
	push_down (tr[x].son[0]);
	push_down (tr[x].son[1]);
    make0 ();
	tr[x].size = tr[tr[x].son[0]].size + tr[tr[x].son[1]].size + 1;
    tr[x].val = tr[tr[x].son[0]].val + tr[tr[x].son[1]].val + a[tr[x].id];
}
// Returns 0 when x is stored as the left child (son[0]) of its parent and
// 1 otherwise.
int get (int x){
	const int parent = tr[x].fa;
	return tr[parent].son[0] == x ? 0 : 1;
}
// Left rotation: lifts x, the right child of y = tr[x].fa, above y. x's former
// left subtree becomes y's right subtree and y becomes x's left child.
//
// The original also patched size/val by hand through int temporaries before
// calling maintain(); those stores were overwritten by maintain() right away,
// so only the link updates and the two maintain() calls remain.
void zag (int x){
	make0 ();
	const int y = tr[x].fa;
	const int z = tr[y].fa;
	const int inner = tr[x].son[0];
	tr[y].son[1] = inner;
	if (inner)
		tr[inner].fa = y;
	tr[x].son[0] = y;
	tr[y].fa = x;
	tr[x].fa = z;
	// Hook x into the slot of z that used to hold y (z == 0: x has no parent).
	if (z)
		tr[z].son[y != tr[z].son[0]] = x;
	// y is now below x, so it has to be recomputed first.
	maintain (y);
	maintain (x);
}
// Right rotation: lifts x, the left child of y = tr[x].fa, above y. x's former
// right subtree becomes y's left subtree and y becomes x's right child.
//
// The original also patched size/val by hand through int temporaries before
// calling maintain(); those stores were overwritten by maintain() right away,
// so only the link updates and the two maintain() calls remain.
void zig (int x){
	make0 ();
	const int y = tr[x].fa;
	const int z = tr[y].fa;
	const int inner = tr[x].son[1];
	tr[y].son[0] = inner;
	if (inner)
		tr[inner].fa = y;
	tr[x].son[1] = y;
	tr[y].fa = x;
	tr[x].fa = z;
	// Hook x into the slot of z that used to hold y (z == 0: x has no parent).
	if (z)
		tr[z].son[y != tr[z].son[0]] = x;
	// y is now below x, so it has to be recomputed first.
	maintain (y);
	maintain (x);
}
// Rotates x one level up. Before the links are touched, the pending flags of
// x's parent, of the parent's children and of their children are pushed down,
// so that left/right positions are final when the direction is chosen.
void rotate (int x){
	const int parent = tr[x].fa;
	push_down (parent);
	for (int side = 0;side < 2;++ side)
		push_down (tr[parent].son[side]);
	for (int side = 0;side < 2;++ side){
		const int child = tr[parent].son[side];
		push_down (tr[child].son[0]);
		push_down (tr[child].son[1]);
	}
	// A left child is lifted with a right rotation (zig), a right child with
	// a left rotation (zag).
	if (get (x) != 0)
		zag (x);
	else
		zig (x);
}
// Splays x all the way to the top of the tree and records it as the new root.
void splay (int x){
	while (true){
		const int parent = tr[x].fa;
		if (parent == 0)
			break;
		push_down (parent);
		push_down (x);
		if (tr[parent].fa){
			// Same side as the parent: rotate the parent first (zig-zig);
			// otherwise rotate x twice (zig-zag).
			if (get (x) == get (parent))
				rotate (parent);
			else
				rotate (x);
		}
		rotate (x);
	}
	root = x;
}
void splay_ (int x){
	for (;tr[x].fa && tr[x].fa != root;rotate (x)){
        push_down (tr[x].fa);
        push_down (x);
	}
}
// Walks down from the root to the k-th node in in-order (1-based) and returns
// its index. As a side effect Kth_Res receives the sum of a[] over the first
// k nodes in in-order. Pending flags are pushed down along the path.
int Kth (int k){
	Kth_Res = 0;
	for (int cur = root;;){
		push_down (cur);
		const int left = tr[cur].son[0];
		if (left != 0 && k <= tr[left].size){
			cur = left;
			continue;
		}
		// The whole left subtree and cur itself lie at or before position k.
		Kth_Res += tr[left].val + a[tr[cur].id];
		k -= tr[left].size + 1;
		if (k <= 0)
			return cur;
		cur = tr[cur].son[1];
	}
}
// Reverses the elements at positions l..r. Position p lives at in-order rank
// p + 1 because Build adds one extra node in front, so ranks l and r + 2 are
// the neighbours of the range: the first is splayed to the root, the second
// to just below it, and the range is then the left subtree of the latter.
void Swap (int l, int r){
	if (l >= r)
		return ;
	const int before = Kth (l);
	splay (before);
	const int after = Kth (r + 2);
	splay_ (after);
	const int below_root = tr[root].son[1];
	const int segment = tr[below_root].son[0];
	// Only flag the subtree; push_down performs the actual child swaps later.
	tr[segment].lazy ^= 1;
}
void Update (int l, int r, int x){
	splay (Kth (l));
	splay_ (Kth (r + 2));
	int u = tr[tr[root].son[1]].son[0];
	tr[u].tag += x;
	tr[u].val += tr[u].size * x;
	a[tr[u].id] += x;
}
// Recursively builds a balanced subtree for the indices l..r of a[] and hangs
// it below `father` as child number t (0 = left, 1 = right). With father == 0
// the new node becomes the tree root.
//
// The child link is only written for a real parent. The original also stored
// it into the null node tr[0] when father == 0 and relied on a later make0()
// to wipe it again.
void Build (int l, int r, int father, int t){
	if (l > r)
		return ;
	const int mid = (l + r) >> 1;
	const int p = ++ tot;
	tr[p].id = mid;
	tr[p].val = a[mid];
	tr[p].size = 1;
	if (father){
		tr[p].fa = father;
		tr[father].son[t] = p;
	}
	else
		root = p;
	// A single-element range is a leaf: size and val are already final.
	if (l == r)
		return ;
	Build (l, mid - 1, p, 0);
	Build (mid + 1, r, p, 1);
	maintain (p);
}
int main (){
	freopen ("example.in", "r", stdin);
	freopen ("example.out", "w", stdout);
	scanf ("%d%d", &n, &Q);
	for (int i = 1;i <= n;++ i)
		scanf ("%d", &a[i]);
	Build (0, n + 1, 0, 0);
	while (Q --){
		scanf ("%d", &opts);
		// 1 : Let [l, r] be upside-down.
		if (opts == 1){
			scanf ("%d%d", &l, &r);
			Swap (l, r);
		}
		// 2 : Swap [l1, r1] and [l2, r2].  ## Tips : (r1 - l1) and (r2 - l2) may not equal to each other.
		if (opts == 2){
			scanf ("%d%d%d%d", &l1, &r1, &l2, &r2);
			int d1 = r1 - l1, d2 = r2 - l2;
			Swap (l1, r2);
			Swap (l1, l1 + d2);
			Swap (l1 + d2 + 1, r2 - d1 - 1);
			Swap (r2 - d1, r2);
		}
		// 3 : Ask the sum of the number of [l, r].
		if (opts == 3){
			scanf ("%d%d", &l, &r);
			Kth (l);
			int el = Kth_Res;
			Kth (r + 1);
			int er = Kth_Res;
			printf ("%d\n", er - el);
		}
		// 4 : Add the same number dt to [l, r].
		if (opts == 4){
			scanf ("%d%d%d", &l, &r, &dt);
			Update (l, r, dt);
		}
	}
	return 0;
}