Splay

发布时间 2023-07-01 10:31:29作者: 我是浣辰啦

概念

Splay 树(伸展树),是一种平衡BST

它通过伸展操作不断将某个节点旋转到根节点,使得整棵树仍然满足BST的性质,能够在均摊 \(O(\log n)\) 时间内完成插入,查找和删除操作,并且保持平衡而不至于退化为链。

实现

rotate

其保证

  • 不破坏BST的性质
  • 不破坏节点维护的信息
  • root必须指向旋转后的根节点

在Splay中旋转分为左旋和右旋

具体分析旋转过程(令需要旋转的节点为 \(x\) ,其父亲为 \(y\) ,以右旋为例)

  1. \(y\) 的左儿子指向 \(x\) 的右儿子,且 \(x\) 的右儿子的父亲指向 \(y\)
  2. \(x\) 的右儿子指向 \(y\) ,且 \(y\) 的父亲指向 \(x\)
  3. \(y\) 的父亲 \(z\) 指向 \(y\) 的儿子的信息指向 \(x\) 并将 \(x\) 的父亲指向 \(z\)
inline void rotate(int x) {
	int y = fa[x], z = fa[y], d = dir(x);
	ch[y][d] = ch[x][d ^ 1];
	
	if (ch[x][d ^ 1])
		fa[ch[x][d ^ 1]] = y;
	
	ch[x][d ^ 1] = y;
	fa[y] = x, fa[x] = z;
	
	if (z)
		ch[z][y == ch[z][1]] = x;
		
	pushup(x), pushup(y);
}

splay

定义:每次访问一个节点后都多次使用 splay 操作强制旋转到根

splay操作步骤有三种,具体分为六种情况

  1. zig: \(y\) 是根节点。直接旋转即可

  2. zig-zig:\(x,y\) 都是其父亲的左儿子或右儿子。先把 \(y\) 旋上去,再把 \(x\) 旋上去

  1. zig-zag: \(x, y\) 不都是其父亲的左儿子或右儿子。把 \(x\) 旋上去两次即可

代码实现:

inline void splay(int x) {
	for (int f = fa[x]; f; rotate(x), f = fa[x])
		if (fa[f])
			rotate(dir(x) == dir(f) ? f : x);
	
	root = x;
}

合并

合并两棵splay树

设两棵树的根节点为 \(x, y\) ,令 \(x\) 中的最大值小于 \(y\) 中的最小值

\(x\) 树的最大值splay到根,将其右子树设为 \(y\) 即可

插入

设插入值为 \(k\)

  • 若树空,则直接插入根并退出
  • 若当前节点权值等于 \(k\) ,则增加当前节点大小并更新信息
  • 否则按照 BST 的性质向下找,找到空节点插入即可
inline void insert(int k) {
	if (!root) {
		val[++tot] = k;
		cnt[tot] = 1;
		root = tot;
		pushup(root);
		return ;
	}
	
	int cur = root, f = 0;
	
	for (;;) {
		if (val[cur] == k) {
			++cnt[cur];
			pushup(cur), pushup(f);
			splay(cur);
			break;
		}
		f = cur, cur = ch[cur][val[cur] < k];
		
		if (!cur) {
			val[++tot] = k;
			cnt[tot] = 1;
			fa[tot] = f;
			ch[f][val[f] < k] = tot;
			pushup(tot), pushup(f);
			splay(tot);
			break;
		}
	}
}

查询 \(x\) 的排名

与 BST 类似

inline int rnk(int k) {
	int res = 0, cur = root;
	
	for (;;) {
		if (k < val[cur])
			cur = ch[cur][0];
		else {
			res += siz[ch[cur][0]];
			
			if (k == val[cur]) {
				splay(cur);
				return res + 1;
			}
			
			res += cnt[cur], cur = ch[cur][1];
		}
	}
}

查询排名为 \(x\) 的数

与 BST 类似

inline int kth(int k) {
	int cur = root;
	
	for (;;) {
		if (ch[cur][0] && k <= siz[ch[cur][0]])
			cur = ch[cur][0];
		else {
			k -= siz[ch[cur][0]] + cnt[cur];
			
			if (k <= 0) {
				splay(cur);
				return val[cur];
			}
			
			cur = ch[cur][1];
		}
	}
}

查询前驱 / 后继

前驱定义为小于 \(x\) 的最大数,那么我们可以先插入 \(x\) ,前驱即为 \(x\) 左子树中最右节点,最后删除 \(x\) 即可

后继定义为大于 \(x\) 的最小数,查询方法类似前驱: \(x\) 的右子树中的最左节点

找根节点的前驱后继代码:

inline int near(int sign) {
	int cur = ch[root][sign];
	
	if (!cur)
		return cur;
	
	while (ch[cur][sign ^ 1])
		cur = ch[cur][sign ^ 1];
	
	splay(cur);
	return cur;
}

删除

首先将 \(x\) 旋转到根

  • 若有不止一个 \(x\) ,则直接将该点数量减 \(1\) 即可
  • 否则,合并左右子树即可
inline void remove(int k) {
	rnk(k);
	
	if (cnt[root] > 1)
		--cnt[root], pushup(root);
	else if (!ch[root][0] && !ch[root][1])
		clear(root), root = 0;
	else if (!ch[root][0]) {
		int cur = root;
		root = ch[root][1];
		fa[root] = 0;
		clear(cur);
	} else if (!ch[root][1]) {
		int cur = root;
		root = ch[root][0];
		fa[root] = 0;
		clear(cur);
	} else {
		int cur = root, x = near(0);
		fa[ch[cur][1]] = x;
		ch[x][1] = ch[cur][1];
		clear(cur);
		pushup(root); 
	}
}

应用

P3369 【模板】普通平衡树

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;

namespace Splay {
int ch[N][2];
int fa[N], siz[N], cnt[N];
int val[N];

int root, tot;

inline void pushup(int x) {
	siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];
}

inline int get(int x) {
	return x == ch[fa[x]][1];
}

inline void clear(int x) {
	ch[x][0] = ch[x][1] = fa[x] = val[x] = siz[x] = cnt[x] = 0;
}

inline void rotate(int x) {
	int y = fa[x], z = fa[y], d = get(x);
	ch[y][d] = ch[x][d ^ 1];
	
	if (ch[x][d ^ 1])
		fa[ch[x][d ^ 1]] = y;
	
	ch[x][d ^ 1] = y;
	fa[y] = x, fa[x] = z;
	
	if (z)
		ch[z][y == ch[z][1]] = x;
		
	pushup(x), pushup(y);
}

inline void splay(int x) {
	for (int f = fa[x]; f = fa[x], f; rotate(x))
		if (fa[f])
			rotate(get(x) == get(f) ? f : x);
	
	root = x;
}

inline void insert(int k) {
	if (!root) {
		val[++tot] = k;
		cnt[tot] = 1;
		root = tot;
		pushup(root);
		return ;
	}
	
	int cur = root, f = 0;
	
	for (;;) {
		if (val[cur] == k) {
			++cnt[cur];
			pushup(cur), pushup(f);
			splay(cur);
			break;
		}
		f = cur, cur = ch[cur][val[cur] < k];
		
		if (!cur) {
			val[++tot] = k;
			cnt[tot] = 1;
			fa[tot] = f;
			ch[f][val[f] < k] = tot;
			pushup(tot), pushup(f);
			splay(tot);
			break;
		}
	}
}

inline int rnk(int k) {
	int res = 0, cur = root;
	
	for (;;) {
		if (k < val[cur])
			cur = ch[cur][0];
		else {
			res += siz[ch[cur][0]];
			
			if (k == val[cur]) {
				splay(cur);
				return res + 1;
			}
			
			res += cnt[cur], cur = ch[cur][1];			
		}
	}
}

inline int kth(int k) {
	int cur = root;
	
	for (;;) {
		if (ch[cur][0] && k <= siz[ch[cur][0]])
			cur = ch[cur][0];
		else {
			k -= siz[ch[cur][0]] + cnt[cur];
			
			if (k <= 0) {
				splay(cur);
				return val[cur];
			}
			
			cur = ch[cur][1];
		}
	}
}

inline int near(int sign) {
	int cur = ch[root][sign];
	
	if (!cur)
		return cur;
	
	while (ch[cur][sign ^ 1])
		cur = ch[cur][sign ^ 1];
	
	splay(cur);
	return cur;
}

inline void remove(int k) {
	rnk(k);
	
	if (cnt[root] > 1)
		--cnt[root], pushup(root);
	else if (!ch[root][0] && !ch[root][1])
		clear(root), root = 0;
	else if (!ch[root][0]) {
		int cur = root;
		root = ch[root][1];
		fa[root] = 0;
		clear(cur);
	} else if (!ch[root][1]) {
		int cur = root;
		root = ch[root][0];
		fa[root] = 0;
		clear(cur);
	} else {
		int cur = root, x = near(0);
		fa[ch[cur][1]] = x;
		ch[x][1] = ch[cur][1];
		clear(cur);
		pushup(root); 
	}
}
}

int m;

signed main() {
	scanf("%d", &m);
	
	for (int op, x; m; --m) {
		scanf("%d%d", &op, &x);
		
		if (op == 1)
			Splay::insert(x);
		else if (op == 2)
			Splay::remove(x);
		else if (op == 3)
			printf("%d\n", Splay::rnk(x));
		else if (op == 4)
			printf("%d\n", Splay::kth(x));
		else if (op == 5) {
			Splay::insert(x);
			printf("%d\n", Splay::val[Splay::near(0)]);
			Splay::remove(x);
		}
		else {
			Splay::insert(x);
			printf("%d\n", Splay::val[Splay::near(1)]);
			Splay::remove(x);
		}
	}
	
	return 0;
}

扩展

区间翻转

P3391 【模板】文艺平衡树

我们以编号为下标建立一棵 Splay

当我们翻转区间 \([l, r]\) ,时,我们可以考虑利用 Splay 的性质,将 \(l - 1\) 翻转至根节点,再将 \(r + 1\) 翻转至其右儿子,这样 \(r + 1\) 的左儿子就是所有 \([l, r]\) 的数了

此时,我们对这个节点打上标记,有需要时再翻转即可

为了方便,我们在树的两端各插入 \(\pm \infty\) ,防止在翻转 \([1, n]\) 时出现问题

#include <bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 7;

namespace Splay {
int ch[N][2];
int fa[N], siz[N], val[N], tag[N];

int root, tot;

inline void pushup(int x) {
	siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + 1;
}

inline void pushdown(int x) {
	if (tag[x]) {
		tag[ch[x][0]] ^= 1;
		tag[ch[x][1]] ^= 1;
		swap(ch[x][0], ch[x][1]);
		tag[x] = 0;
	}
}

inline int get(int x) {
	return x == ch[fa[x]][1];
}

inline void rotate(int x) {
	int y = fa[x], z = fa[y], d = get(x);
	pushdown(y), pushdown(x);
	ch[y][d] = ch[x][d ^ 1];
	
	if (ch[x][d ^ 1])
		fa[ch[x][d ^ 1]] = y;
	
	ch[x][d ^ 1] = y;
	fa[y] = x, fa[x] = z;
	
	if (z)
		ch[z][y == ch[z][1]] = x;
	
	pushup(y), pushup(x);
}

inline void splay(int x, int goal = 0) {
	for (int  f = fa[x]; f != goal; rotate(x), f = fa[x])
		if (fa[f] != goal)
			rotate(get(x) == get(f) ? f : x);
	
	if (!goal)
		root = x;
}

inline void insert(int k) {
	if (!root) {
		val[++tot] = k;
		root = tot;
		pushup(root);
		return;
	}
	
	for (int x = root, f = 0;;) {
		f = x, x = ch[x][val[x] < k];
		
		if (!x) {
			val[++tot] = k;
			fa[tot] = f, ch[f][val[f] < k] = tot;
			pushup(tot), pushup(f);
			splay(tot);
			break;
		}
	}
}

inline int find(int k) {
	for (int x = root;;) {
		pushdown(x);
		
		if (k <= siz[ch[x][0]])
			x = ch[x][0];
		else if (k == siz[ch[x][0]] + 1)
			return x;
		else
			k -= siz[ch[x][0]] + 1, x = ch[x][1];
	}
}

inline void reverse(int l, int r) {
	l = find(l - 1), r = find(r + 1);
	splay(l, 0), splay(r, l);
	tag[ch[ch[root][1]][0]] ^= 1;
}

inline void dfs(int x) {
	if (!x)
		return;
	
	pushdown(x);
	dfs(ch[x][0]);
	
	if (val[x] != -inf && val[x] != inf)
		printf("%d ", val[x]);
	
	dfs(ch[x][1]);
}
}

int n, m;

signed main() {
	scanf("%d%d", &n, &m);
	Splay::insert(-inf), Splay::insert(inf);
	
	for (int i = 1; i <= n; ++i)
		Splay::insert(i);
	
	for (int l, r; m; --m) {
		scanf("%d%d", &l, &r);
		Splay::reverse(l + 1, r + 1);
	}
	
	Splay::dfs(Splay::root);
	return 0;
}

区间移动

将区间 \([l, r]\) 扔到 \(c\) 后面

首先,类似区间翻转,我们拿出区间 \([l, r]\) ,并将 \(c\) 旋转到根,将 \(c + 1\) 旋转至 \(c\) 的右儿子,接着把 \([l, r]\) 设为 \(c + 1\) 的左儿子即可