平衡树专栏

发布时间 2023-07-21 18:44:35作者: Rose_Lu

普通平衡树【Treap】

平衡树可以实现很多操作,而且时间都是在 \(O(log_n)\) 级别的

  • 芝士复杂度

为什么可以稳定在 \(O(log_n)\) 呢?是因为它不仅是二叉查找树,而且还是一个堆(堆的值是用随机函数实现的),所以它的深度就不会被卡,稳定在 \(O(log_n)\) 的深度,所以时间也就是 \(O(log_n)\) 了。

  • 芝士操作

那么 \(treap\) 都包括哪些操作呢? 插入,删除,根据编号找值,根据值找编号,找前驱和后继,由此可见其也是非常的强悍

  • 芝士实现

1.如果它插入或者删除之后不符合二叉查找树的性质了怎么办,那就
旋分为左旋和右旋,具体来说左旋就是把当前节点和它的两个儿子逆时针旋转一下,再计算值,右旋相反。这是一个非常重要的操作。
2.如何插入?如果插入的值在原树上已经有了,那我们就可以直接加进去,如果没有就新开一个节点给他弄进树,如果当前节点不是要找的插入位置那就根据情况判断是往左查找还是往右。
3.如何删除?删除的操作其实和插入大差不差,但是多了的操作,在删除节点之前一下,将要删除的点到叶子节点,这样就会让删除更简单!这里可以画一张图结合下面代码仔细想想。
4.如何根据编号找值?我们只需要在树中再开一个数组记录以它为根的这棵树的 \(size\),然后再通过编号和树的 \(size\) 作比较,然后向下递归直到找到要找的节点就好。
5.如何根据值找编号?因为我们的这棵树是二叉查找树,所以查找就像普通二叉查找树的查找一样(其实和删除操作也差不多),向下递归就好!
6.如何找前驱和后继?前驱就是比当前节点的值小的值里的最大的值,后继就是比当前节点大的值里最小的值,我们既然是二叉查找树,那这还不简单,前驱只要在它的左子树一直向右找到叶子节点就好,后继相反。

到现在为止,您就已经学会 \(treap\) 了!

  • 芝士代码
#include <iostream>
#include <cstdlib>

using namespace std;
const int Max = 1e5+10;
const int inf = 0x7fffffff;

int n , root , tot;

struct zx{
	int l , r , size , cnt , dat , pre;
}a[Max];

int add(int pre) {
	tot ++;
	a[tot].pre = pre;
	a[tot].dat = rand();
	a[tot].cnt = a[tot].size = 1;
	return tot;
}

void updata(int x) {
	a[x].size = a[a[x].l].size + a[a[x].r].size + a[x].cnt;
}

void build() {
	root = add(-inf);
	a[root].r = add(inf);
	updata(root);
}

void zig(int &p) {
	int q = a[p].l;
	a[p].l = a[q].r;
	a[q].r = p;
	p = q;
	updata(a[p].r) , updata(p);
}

void zag(int &p) {
	int q = a[p].r;
	a[p].r = a[q].l;
	a[q].l = p;
	p = q;
	updata(a[p].l) , updata(p);
}

void insert(int &x , int pre) {
	if(x == 0) {
		x = add(pre);
		return ;
	}
	if(a[x].pre == pre) {
		a[x].cnt ++;
		updata(x);
		return;
	}
	if(pre < a[x].pre) {
		insert(a[x].l , pre);
		if(a[a[x].l].dat > a[x].dat) zig(x);
	}
	else {
		insert(a[x].r , pre);
		if(a[a[x].r].dat > a[x].dat) zag(x);
	}
	updata(x);
}

void delet(int &x , int pre) {
	if(x == 0) return;
	if(a[x].pre == pre) {
		if(a[x].cnt > 1) {
			a[x].cnt --;
			updata(x);
			return;
		}
		if(a[x].l || a[x].r) {
			if(!a[x].r || a[a[x].l].dat > a[a[x].r].dat) zig(x) , delet(a[x].r , pre);
			else zag(x) , delet(a[x].l , pre);
			updata(x);
		}
		else x = 0;
	}
	if(pre > a[x].pre) delet(a[x].r , pre);
	else delet(a[x].l , pre);
	updata(x);
}

int get_pre(int p, int rank) {
	if (p == 0) return 0;
	if (a[a[p].l].size >= rank) return get_pre(a[p].l, rank);
	if (a[a[p].l].size + a[p].cnt >= rank) return a[p].pre;
	return get_pre(a[p].r, rank - a[p].cnt - a[a[p].l].size); 
}

int get_rank(int x, int pre) {
	if(x == 0) return 0;
	if(a[x].pre == pre) return a[a[x].l].size + 1;
	if(a[x].pre > pre) return get_rank(a[x].l , pre);
	return get_rank(a[x].r , pre) + a[a[x].l].size + a[x].cnt;
}

int get_front(int pre) {
	int ans = 0 , x = root;
	while(x) {
		if(a[x].pre >= pre) x = a[x].l;
		else ans = a[x].pre , x = a[x].r;
	}
	return ans;
}

int get_next(int pre) {
	int ans = 0 , x = root;
	while(x) {
		if(a[x].pre <= pre) x = a[x].r;
		else ans = a[x].pre , x = a[x].l;
	}
	return ans;
}

signed main() {
	build();
	cin >> n;
	for(long long i = 1; i <= n; i++) {
		long long opt , x;
		cin >> opt >> x;
		if(opt == 1) insert(root , x);
		else if(opt == 2) delet(root , x);
		else if(opt == 3) cout << get_rank(root , x) - 1 << endl;
		else if(opt == 4) cout << get_pre(root, x + 1) << endl;
		else if(opt == 5) cout << get_front(x) << endl;
		else cout << get_next(x) << endl;
	}
	return 0;
}