学习笔记:splay树

发布时间 2023-08-12 23:08:47作者: g1ove

0.前言

只有基础操作,题目传送门:click here

1.概念

splay树是一棵平衡二插查找树 保证左边子树的值比当前的值小并且右边子树的值比当前的值大 而且左右子树也是二插搜索树

根据以前的学习 我们知道 \(Treap\)算法是引入随机值的概念控制树高降低时间复杂度\(O(nlogn)\)\(Splay\)呢也是神奇到看不懂证明的均摊时间复杂度\(O(nlogn)\)

2.引入

用过输入法吧?

每输入拼音找到了词语,下一次再输入这个拼音时这个词语就到了最前面,这就是使用\(Splay\)实现的

哇,真的太神奇了

3.算法步骤

1.旋转操作

在一棵二叉查找树中,我们要使用旋转操作使某个子节点旋转到它的父亲节点的同时保证\(BST\)的性质

在这里引入左旋和右旋

直接拿 OIwiki的图

eg

分析右旋:1节点通过左旋到了根节点 怎么做到的呢?

过程:

  • 1.对于原根节点:右儿子不变,左儿子变成原左儿子的右儿子
  • 2.对于原左儿子:左儿子不变,右儿子变为根节点.

很抽象,难理解 多理解几遍就明白了

Code:

int chk(int x)
{
	return ch[fa[x]][1]==x;
}
void Pushup(int x)
{
	size[x]=size[ch[x][0]]+size[ch[x][1]]+sum[x];
}
void rotate(int x)
{
	int y=fa[x],z=fa[y],k=chk(x),w=ch[x][k^1];
	ch[y][k]=w;fa[w]=y;
	ch[z][chk(y)]=x;fa[x]=z;
	ch[x][k^1]=y;fa[y]=x;
	Pushup(y);
	Pushup(x);
}

代码里函数解析:

  • 1.\(chk\) 判断当前节点是父节点的左/右儿子
  • 2.\(Pushup\) 更新size数组 后面做\(Kth\)
  • 3.\(rotate\) 将当前儿子转到父亲节点

引入以下数组/变量:

  • 1.\(root\) 根节点
  • 2.\(ch[x][0/1]\) 0表示x的左儿子,1为右儿子
  • 3.\(size[x]\) 左右子树以及自身总数量的和
  • 4.\(fa[x]\) 记录父亲节点
  • 5.\(sum[x]\) 当前节点数的个数
  • 6.\(val[x]\) 当前节点的值

2.Splay操作

这一部分是\(Splay\)树的核心

目的:把一个节点伸展到它的某个祖父

最朴素的想法就是一次一次旋转

但这样必定\(TLE\)

所以Tarjan老爷子%%%就开发出了新的旋转方法:双旋

根据以下情况分类:

  • 1.距离目标点只差一步 一次旋转即可
  • 2.双旋:三个节点一条链,先旋转中间,在旋转最后的(一字型)
  • 3.双旋:三个节点非一条链,先旋转最后的,再旋转中间(之字型)

结合代码理解
Code:

void splay(int x,int goal)
{
	while(fa[x]!=goal)
	{
		int y=fa[x];
		int z=fa[y];
		if(z!=goal)
			if(chk(x)==chk(y)) rotate(y);
			else rotate(x);
		rotate(x);
	}
	if(!goal) root=x;
}

其中, 目标是将x旋转到goal节点的儿子
当其为0时,目标即旋转到根节点

3.插入操作

根据\(BST\)的性质,左小右大,直接先查找这个点,即使找不到找到的也是一个可以插入这个节点的叶子

判断一下:

  • 1.有这个节点 sum++即可
  • 2.没这个节点,再建一个点
void insert(int x)
{
	int now=root,p=0;
	while(now&&val[now]!=x)
	{
		p=now;
		now=ch[now][x>val[now]];
	}
	if(now) sum[now]++;
	else
	{
		now=++tot;
		if(p) ch[p][x>val[p]]=now;
		ch[now][0]=ch[now][1]=0;
		val[now]=x;fa[now]=p;
		sum[now]=size[now]=1;
	}
	splay(now,0);
}

注意,最后一句splay是一定要加的 是整个算法真正精髓

4.查找排名

这个最简单,引入Find函数,目标在\(BST\)中找到这个对应点(没有的话返回前驱或者后继)

void find(int x)
{
	if(!root) return ;
	int now=root;
	while(ch[now][x>val[now]]&&val[now]!=x)
		now=ch[now][x>val[now]];
	splay(now,0);
}

分类讨论一下

  • 1.如果找到前驱,就是左儿子的size加上根节点的sum
  • 2.如果找到后继或者本身,直接返回左儿子的size
int rank(int x)
{
	find(x);
	if(val[root]>=x) return size[ch[root][0]];
	else return size[ch[root][0]]+sum[root];
}

这里会疑问为什么没有rank+1

因为程序一开始要插入极小值和极大值来维护整棵树 有极小值可视为自动加一

肯定又有人想 这么麻烦,不如直接把后继找到然后输出根节点左子树的大小? 肯定是不行的 因为Find函数是在找不到原先值的情况下返回的前驱/后继 和这个不一样

当然 还有一种方法就是类似Kth一样树上二分 但是那样麻烦点

5.后继/前驱

这两一个东西

先像上面一样,Find到这个点 然后Splay到根

如果原先没这个点那么根肯定就是前驱或者后继之一

如果不是那么就分类:

  • 1.前驱 在根节点的左子树一路往右
  • 2.后继 在根节点的右子树一路往左
int next(int x,int p)//0前1后 
{
	find(x);
	int now=root;
	if(val[now]>x&&p) return now;
	if(val[now]<x&&!p) return now;
	now=ch[now][p];
	while(ch[now][p^1]) now=ch[now][p^1];
	splay(now,0);
	return now;
}

6.第k小 Kth

这个肯定是直接树上二分
想一下当时学权值线段树的时候怎么做:

  • 0.先判断整棵树大小是否大于k 大于的话一定找不到Kth
  • 1.如果当前子树\(size<=k\) 往左走
  • 2.否则就往右走 让k减去左子树的大小和当前根节点的sum 往右递归

把这个套到\(BST\)上就可以了

注意 往左走的时候要判断有没有左节点 如果同时2条件不成立 当前点就是答案

int kth(int k)
{
	int now=root;
	if(k>size[root]) return -1;
	while(1)
	{
		if(ch[now][0]&&k<=size[ch[now][0]])
			now=ch[now][0];
		else if(k>size[ch[now][0]]+sum[now])
		{
			k-=size[ch[now][0]]+sum[now];
			now=ch[now][1];
		}
		else 
		{
			splay(now,0);
			return now;
		}
	}
}

7.删除操作

思路很简单

找到当前点的前驱和后继

把前驱\(Splay\)到根

把后继\(Splay\)到根节点的右儿子

因为后继的左子树中肯定只有要删的数一个数

所以对后继的左儿子讨论

  • 1.如果左儿子个数大于1 将左儿子个数减1即可
  • 2.否则说明删去这个节点 直接清零即可
void remove(int x)
{
	int last=next(x,0),nxt=next(x,1);
	splay(last,0);splay(nxt,last);
	int del=ch[nxt][0];
	if(sum[del]>1)
	{
		sum[del]--;
		splay(del,0);
	}
	else ch[nxt][0]=0;
	Pushup(nxt);
	Pushup(root);
}

如果不存在这个点呢?那也没关系 因为它的前驱后继之间就是空的

4 完整Code

#include<bits/stdc++.h>
#define MAXN 100005
using namespace std;
int root,n,tot;
int fa[MAXN],size[MAXN];
int ch[MAXN][2],sum[MAXN],val[MAXN];
int chk(int x)
{
	return ch[fa[x]][1]==x;
}
void Pushup(int x)
{
	size[x]=size[ch[x][0]]+size[ch[x][1]]+sum[x];
}
void rotate(int x)
{
	int y=fa[x],z=fa[y],k=chk(x),w=ch[x][k^1];
	ch[y][k]=w;fa[w]=y;
	ch[z][chk(y)]=x;fa[x]=z;
	ch[x][k^1]=y;fa[y]=x;
	Pushup(y);
	Pushup(x);
}
void splay(int x,int goal)
{
	while(fa[x]!=goal)
	{
		int y=fa[x];
		int z=fa[y];
		if(z!=goal)
			if(chk(x)==chk(y)) rotate(y);
			else rotate(x);
		rotate(x);
	}
	if(!goal) root=x;
}
void find(int x)
{
	if(!root) return ;
	int now=root;
	while(ch[now][x>val[now]]&&val[now]!=x)
		now=ch[now][x>val[now]];
	splay(now,0);
}
void insert(int x)
{
	int now=root,p=0;
	while(now&&val[now]!=x)
 	{
		p=now;
		now=ch[now][x>val[now]];
	}
	if(now) sum[now]++;
	else
	{
		now=++tot;
		if(p) ch[p][x>val[p]]=now;
		ch[now][0]=ch[now][1]=0;
		val[now]=x;fa[now]=p;
		sum[now]=size[now]=1;
	}
	splay(now,0);
}
int rank(int x)
{
	find(x);
	if(val[root]>=x) return size[ch[root][0]];
	else return size[ch[root][0]]+sum[root];
}
int next(int x,int p)//0前1后 
{
	find(x);
	int now=root;
	if(val[now]>x&&p) return now;
	if(val[now]<x&&!p) return now;
	now=ch[now][p];
	while(ch[now][p^1]) now=ch[now][p^1];
	splay(now,0);
	return now;
}
int kth(int k)
{
	int now=root;
	if(k>size[root]) return -1;
	while(1)
	{
		if(ch[now][0]&&k<=size[ch[now][0]])
			now=ch[now][0];
		else if(k>size[ch[now][0]]+sum[now])
		{
			k-=size[ch[now][0]]+sum[now];
			now=ch[now][1];
		}
		else 
		{
			splay(now,0);
			return now;
		}
	}
}
void remove(int x)
{
	int last=next(x,0),nxt=next(x,1);
	splay(last,0);splay(nxt,last);
	int del=ch[nxt][0];
	if(sum[del]>1)
	{
		sum[del]--;
		splay(del,0);
	}
	else ch[nxt][0]=0;
	Pushup(nxt);
	Pushup(root);
}
int main()
{
	insert(1e8);
	insert(-1e8);
	scanf("%d",&n);
	while(n--)
	{
		int opr,x;
		scanf("%d%d",&opr,&x);
		if(opr==1) insert(x);
		if(opr==2) remove(x);
		if(opr==3) printf("%d\n",rank(x));	 
		if(opr==4) printf("%d\n",val[kth(x+1)]);
		if(opr==5) printf("%d\n",val[next(x,0)]);
		if(opr==6) printf("%d\n",val[next(x,1)]);
	}
	return 0;
}

5注意的点

  • 1.Kth完后要Splay
  • 2.insert完要Splay
  • 3.Find完要Splay
  • 4.前驱后继完也要Splay
  • 5.各个函数返回的是下标值 要加上\(val[x]\)
  • 6.使用前插入两个极值 所以数组大小一定最少要加2
  • 7.rotate顺序的y z x
  • 8.记得Pushup
  • 9.Splay时判断最顶端是否为目标
  • 10.前驱后继从子节点开始 不是根
  • 11.rank 函数里面以root为主体