【学习笔记】Splay

发布时间 2023-10-23 16:57:07作者: KingPowers

前置知识:二叉排序树(BST)。

基本操作

首先我们要维护下面这几个东西:

int fa[maxn],siz[maxn],val[maxn],ch[maxn][2],cnt[maxn],root,tot;  //fa:当前点父亲 siz:以当前点为根子树大小 val:权值 ch:左右儿子 cnt:当前权值出现次数 root:当前的根 tot:结点数

还需要记住两个基本操作:

bool get_son(int x){  //判断左右儿子
	return ch[fa[x]][1]==x;
}
void pushup(int x){  //更新size
	siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}

之后就可以开始了。

rotate

首先,旋转是分为左旋和右旋的,不妨让我们先看一张图理解一下这个过程:

根据上图,我们分析下右旋的规律:$2$ 号结点的右儿子 $5$ 号结点成了其父亲$1$ 号结点的左儿子,同时 $1$ 号结点成了 $2$ 号结点的右儿子。而左旋的过程,与右旋是完全相反的,大家可以手模一下。

不过我们也不难发现,我们在程序实现中其实没必要区分左旋和右旋。还记得我们存儿子的数组吗?我们可以用 $0$ 来表示左儿子,$1$ 表示右儿子,这样我们通过异或的方式就可以左右旋二合一了!

具体实现参考代码:

void rotate(int x){  //旋转x
	int y=fa[x],z=fa[y],k=get_son(x);
	ch[z][get_son(y)]=x,fa[x]=z;
	ch[y][k]=ch[x][k^1],fa[ch[x][k^1]]=y;
	ch[x][k^1]=y,fa[y]=x;
	pushup(y);pushup(x);
}

splay

有了最基本的 rotate 操作,那么对于每一个点,我们就可以把它转到任何地方了!至于具体实现,先看代码:

void splay(int x,int f){  //将x旋转到f的儿子处
	while(fa[x]!=f){
		int y=fa[x],z=fa[y];
		if(z!=f){
			if((get_son(x))==(get_son(y))) rotate(y);
			else rotate(x);
		}
		rotate(x);
	}
	if(!f) root=x;  //f为0表示旋转到根
}

设当前要转的点为 $x$,其父亲为 $y$,爷爷为 $z$,则Splay操作一共可以分为以下几种情况:

  • $y$ 就是要转到的地方,直接旋转即可。
  • 若 $x$ 与 $y$ 作为儿子的类型相同,则需要先旋转下 $y$,再旋转 $x$。
  • 否则直接转两次 $x$ 即可。

这里需要解释下第二种情况:

考虑下我们旋转的目的是什么?是为了使树更加平衡。但是,如果你手模一下第二种情况,你会发现,你转了半天,这棵树从根出发的最长链的长度是没有变化的(最直接地,一条一直向左的链,将最深的点转到根之后还是一条链),并没有起到我们期望的优化效果,说直接点,就是你转了半天根没转一样……

对应地,解决方法就是先转一下 $y$,强行让 $x$ 和 $y$ 作为儿子的类型不同,然后继续搞就行了。

查找权值为某个值的结点

接下来我们就可以考虑查找操作了。

还记得前文提到的 BST 的性质吗:对于任意一个结点,左儿子的值一定比它小,右儿子的值一定比它大(再次加粗QAQ)。因此,我们类似二分查找,每次将当前结点的权值与我们要查找的值比较一下,判断往左子树还是右子树走就可以了。

找到之后,为了方便接下来的操作,我们直接将其转到根结点。

void find(int x){  //找到权值为x的点并将其转到根
	int now=root;
	if(!now) return;
	while(val[now]!=x&&ch[now][val[now]<x]) now=ch[now][val[now]<x];
	splay(now,0);
}

如果到这里你都看懂了,那么恭喜你,Spaly 的核心部分你基本都掌握了!

3. 其他操作

让我们看下洛谷模板题要求我们支持的操作:

1.插入 x 数
2.删除 x 数(若有多个相同的数,只删除一个)
3.查询 x 数的排名
4.查询排名为 x 的数
5.求 x 的前驱
6.求 x 的后继

接下来让我们一个一个地来分析。

插入

还是利用二叉排序树的性质,我们可以从根开始一层一层地去寻找 $x$ 数的位置,如果 $x$ 之前存在过就将对应的 $cnt$ 加一,否则新建一个结点即可。

然后为了维持树的结构,以及方便后续操作,我们还是将新插入的点旋转到根。

void insert(int x){  //插入操作
	int now=root,f=0;
	while(now&&val[now]!=x){
		f=now;
		now=ch[now][val[now]<x];
	}
	if(now) cnt[now]++;
	else{
		now=++tot;
		if(f) ch[f][val[f]<x]=now;
		cnt[now]=1,val[now]=x,fa[now]=f;
	}
	splay(now,0);
}

查询排名

啥?你问我为什么不先讲删除?往下看你就知道了。

这个操作算是最简单的了,考虑二叉排序树中一个结点的左儿子一定都是比它小的,所以我们可以直接将要查询的数旋转到根,此时左子树的大小加上一就是其排名了。

int get_rank(int x){  //查询x的排名 
	find(x);
	return siz[ch[root][0]];
}

啥?你问我为什么不加一?还是往下看你就知道了。

查询给定排名的数

如果我们要查询排名为 $x$ 的数,那么还是根据二叉排序树的性质:若当前点与左儿子的大小之和小于 $x$,则将 $x$ 减去这些值然后去右子树;若左儿子的大小大于等于当前值,那么就去左子树。否则答案就是当前结点。

int find_rank(int x){  //查询给定排名的数
	int now=root;
	if(siz[now]<x) return -1;
	while(now){
		if(x>siz[ch[now][0]]+cnt[now]){
			x-=siz[ch[now][0]]+cnt[now];
			now=ch[now][1];
		}
		else if(x<=siz[ch[now][0]]) now=ch[now][0];
		else return val[now];
	}
}

查询前驱

我们继续利用二叉排序树的性质,先找到要查询的数并将其转到根,然后再跳一次左子树,之后在不断地跳右子树,跳到最后就是其前驱。

稍微解释下吧,首先左子树的数一定都比当前的数要小,然后一直往右子树跳就能找到比当前数小的数中最大的了。

代码先不着急。

查询后继

跟上面原理一样,将对应的结点转到根,然后先跳右子树再一直跳左子树就可以了。

然后我们发现这两个操作其实也完全可以写到一起,于是就有了这么份代码:

int get_nxt(int x,int k){  //k为0表示找前驱,为1表示找后继
	find(x);
	int now=root;
	if((val[now]<x&&!k)||(val[now]>x&&k)) return now;
	now=ch[now][k];
	while(ch[now][k^1]) now=ch[now][k^1];
	return now;
}

有意思的是,查询前驱与后继的操作有时可能会因为找不到而出现一些神奇的错误,不过解决方法也很简单,提前插入进去个正负无穷就可以解决,也就是在主函数(或者构造函数)加上这么两行:

insert(inf);
insert(-inf);

这也是查询排名的时候没有加一的原因。

删除

删除操作稍微麻烦一点。

我们首先需要找到要删除的数的前驱与后继(这就是为什么删除最后再讲的原因),然后将前驱旋转到根节点,将后继旋转到前驱的儿子,此时后继的左儿子就是要删除的数了(比后继大比前驱小的数还能是哪个)。如果要删除的数的 $cnt$ 大于一直接减去一即可,否则的话就直接删除。

void erase(int x){  //删除
	int pre=get_nxt(x,0),nxt=get_nxt(x,1);
	splay(pre,0);splay(nxt,pre);
	int del=ch[nxt][0];
	if(cnt[del]>1){
		cnt[del]--;
		splay(del,0);
	}
	else{
		ch[nxt][0]=0;
		splay(nxt,0);
	}
}

好的,到此为止,模板题的所有操作你就都已经学会了,快去复制粘贴切掉吧!

4. 代码

这里使用了struct封装 。

#include<bits/stdc++.h>
#define int long long
#define inf 0x7fffffff
#define PII pair<int,int>
#define fx first
#define fy second
#define mk_p make_pair
#define Set(a,b) memset(a,b,sizeof(a))
using namespace std;
const int maxn=1e5+5;
int read(){
	int ans=0,flag=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')flag=-1;ch=getchar();}
	while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
	return ans*flag;
}
struct Splay{
	int fa[maxn],siz[maxn],val[maxn],ch[maxn][2],cnt[maxn],root,tot;
	bool get_son(int x){
		return ch[fa[x]][1]==x;
	}
	void pushup(int x){
		siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
	}
	void rotate(int x){
		int y=fa[x],z=fa[y],k=get_son(x);
		ch[z][get_son(y)]=x,fa[x]=z;
		ch[y][k]=ch[x][k^1],fa[ch[x][k^1]]=y;
		ch[x][k^1]=y,fa[y]=x;
		pushup(y);pushup(x);
	}
	void splay(int x,int f){
		while(fa[x]!=f){
			int y=fa[x],z=fa[y];
			if(z!=f){
				if((get_son(x))==(get_son(y))) rotate(y);
				else rotate(x);
			}
			rotate(x);
		}
		if(!f) root=x;
	}
	void find(int x){
		int now=root;
		if(!now) return;
		while(val[now]!=x&&ch[now][val[now]<x]) now=ch[now][val[now]<x];
		splay(now,0);
	}
	void insert(int x){
		int now=root,f=0;
		while(now&&val[now]!=x){
			f=now;
			now=ch[now][val[now]<x];
		}
		if(now) cnt[now]++;
		else{
			now=++tot;
			if(f) ch[f][val[f]<x]=now;
			cnt[now]=1,val[now]=x,fa[now]=f;
		}
		splay(now,0);
	}
	int get_nxt(int x,int k){ 
		find(x);
		int now=root;
		if((val[now]<x&&!k)||(val[now]>x&&k)) return now;
		now=ch[now][k];
		while(ch[now][k^1]) now=ch[now][k^1];
		return now;
	}
	void erase(int x){
		int pre=get_nxt(x,0),nxt=get_nxt(x,1);
		splay(pre,0);splay(nxt,pre);
		int del=ch[nxt][0];
		if(cnt[del]>1){
			cnt[del]--;
			splay(del,0);
		}
		else{
			ch[nxt][0]=0;
			splay(nxt,0);
		}
	}
	int get_rank(int x){
		find(x);
		return siz[ch[root][0]];
	}
	int find_rank(int x){
		int now=root;
		if(siz[now]<x) return -1;
		while(now){
			if(x>siz[ch[now][0]]+cnt[now]){
				x-=siz[ch[now][0]]+cnt[now];
				now=ch[now][1];
			}
			else if(x<=siz[ch[now][0]]) now=ch[now][0];
			else return val[now];
		}
	}
}T;
signed main(){
	T.insert(inf);
	T.insert(-inf);
	int n=read();
	for(int i=1;i<=n;i++){
		int opt=read(),x=read();
		if(opt==1) T.insert(x);
		else if(opt==2) T.erase(x);
		else if(opt==3) printf("%lld\n",T.get_rank(x));
		else if(opt==4) printf("%lld\n",T.find_rank(x+1));
		else if(opt==5) printf("%lld\n",T.val[T.get_nxt(x,0)]);
		else if(opt==6) printf("%lld\n",T.val[T.get_nxt(x,1)]);
	}
	return 0;
}