splay

发布时间 2024-01-08 20:16:21作者: liukejie

C++ 阶段性复习总结

Splay

Splay 树(伸展树)是一种自调整的二叉搜索树:每次访问一个节点后,都通过一系列旋转(zig、zig-zig、zig-zag)把该节点"伸展"到根,从而使各项操作的均摊时间复杂度为 $O(\log n)$。在 C++ 中通常用数组模拟节点来实现 Splay 树,用它高效地维护动态集合或序列(插入、删除、查排名、查第 $k$ 小、区间加、区间翻转等)。

例题(注意:第一份代码实际上是 Treap——按随机优先级旋转的树堆,并非 Splay;后三份代码才是 Splay 的实现):

#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
// Sentinel returned by pre()/kth() when there is no answer.
const int INF = 1000000000;

// One treap node. Nodes live in the static array `tree`; index 0 is the null node.
struct Treap {
    int lc, rc;     // left / right child indices (0 = none)
    int val, pri;   // key and random priority (max-heap ordered)
    int sum, size;  // copies of val in this node, total copies in the subtree
};
Treap tree[300005];

int cnt;   // index of the last allocated node
int root;  // index of the root node (0 = empty treap)
void update (int p) {
    tree[p].size = tree[p].sum + tree[tree[p].lc].size + tree[tree[p].rc].size;
}
// Right rotation: lift p's left child into p's place; p is rebound to the new subtree root.
// The lifted child inherits p's old subtree total (the subtree's contents are unchanged).
void zig (int &p) {
    int child = tree[p].lc;
    int oldSize = tree[p].size;
    tree[p].lc = tree[child].rc;
    tree[child].rc = p;
    update (p);
    tree[child].size = oldSize;
    p = child;
}
// Left rotation: lift p's right child into p's place; p is rebound to the new subtree root.
// The lifted child inherits p's old subtree total (the subtree's contents are unchanged).
void zag (int &p) {
    int child = tree[p].rc;
    int oldSize = tree[p].size;
    tree[p].rc = tree[child].lc;
    tree[child].lc = p;
    update (p);
    tree[child].size = oldSize;
    p = child;
}
int newnode (int val) {
    tree[++cnt].lc = 0;
    tree[cnt].rc = 0;
    tree[cnt].val = val;
    tree[cnt].pri = rand();
    tree[cnt].sum = 1;
    tree[cnt].size = 1;
    return cnt;
}
// Insert one copy of val into the treap rooted at p. After descending, a child whose
// priority exceeds its parent's is rotated up to restore the heap order on pri.
void ins (int &p, int val) {
    if (p == 0) {
        p = newnode (val);
        return;
    }
    ++tree[p].size;
    if (tree[p].val == val) {
        ++tree[p].sum;  // duplicate key: just count it
    } else if (val < tree[p].val) {
        ins (tree[p].lc, val);
        if (tree[p].pri < tree[tree[p].lc].pri)
            zig (p);
    } else {
        ins (tree[p].rc, val);
        if (tree[p].pri < tree[tree[p].rc].pri)
            zag (p);
    }
}
void del (int &p, int val) {
    if (!p)
        return;
    tree[p].size--;
    if (val == tree[p].val) {
        if (tree[p].sum > 1) {
            tree[p].sum--;
            return;
        }
        if (!tree[p].lc || !tree[p].rc) {
            p = tree[p].lc + tree[p].rc;
        } else if (tree[tree[p].lc].pri > tree[tree[p].rc].pri) {
            zig (p);
            del (tree[p].rc, val);
        } else {
            zag (p);
            del (tree[p].lc, val);
        }
        return;
    }
    if (val < tree[p].val)
        del (tree[p].lc, val);
    else
        del (tree[p].rc, val);
}
// Largest stored key strictly less than val, or -INF if there is none.
int pre (int val) {
    int best = -INF;
    for (int p = root; p != 0; ) {
        if (tree[p].val < val) {
            best = tree[p].val;  // candidate; look for a larger one on the right
            p = tree[p].rc;
        } else {
            p = tree[p].lc;
        }
    }
    return best;
}
// The val-th smallest key (1-based, counting copies) in the subtree rooted at p,
// or -INF when val is out of range.
int kth (int p, int val) {
    while (p) {
        int leftSize = tree[tree[p].lc].size;
        if (val <= leftSize) {
            p = tree[p].lc;
        } else if (val <= leftSize + tree[p].sum) {
            return tree[p].val;
        } else {
            val -= leftSize + tree[p].sum;
            p = tree[p].rc;
        }
    }
    return -INF;
}
// Remove every stored key strictly smaller than val, one copy at a time.
// pre(val) is evaluated once per removal (the previous loop searched twice).
void work (int val) {
    for (int victim = pre (val); victim != -INF; victim = pre (val))
        del (root, victim);
}
// True when the treap holds at least k values, i.e. the k-th largest exists for k >= 1.
bool spl (int k) {
    return tree[root].size >= k;
}
// Driver. A real value equals (stored key + addt); minn is the lower bound expressed
// in stored-key units. Commands:
//   I k : insert k if it meets the bound (each insertion is counted in ans)
//   A k : add k to every value            S k : subtract k from every value
//   F k : print the k-th largest value, or -1 if fewer than k values exist
// After A/S every value below the bound is removed. At the end, prints how many
// inserted values were removed.
int main() {
    int n, minn, addt = 0, ans = 0;
    scanf ("%d%d", &n, &minn);
    for (int step = 0; step < n; ++step) {
        char opt;
        int val;
        scanf (" %c%d", &opt, &val);
        if (opt == 'I') {
            int key = val - addt;
            if (key >= minn) {
                ++ans;
                ins (root, key);
            }
        } else if (opt == 'A') {
            minn -= val;
            addt += val;
            work (minn);
        } else if (opt == 'S') {
            minn += val;
            addt -= val;
            work (minn);
        } else if (opt == 'F') {
            if (!spl (val))
                printf ("-1\n");
            else
                printf ("%d\n", kth (root, tree[root].size + 1 - val) + addt);
        }
    }
    printf ("%d\n", ans - tree[root].size);
    return 0;
}
#include<bits/stdc++.h>
using namespace std;
// Array-based splay tree over int keys with multiplicities.
// Node 0 is the null node and is never allocated; rt == 0 means the tree is empty.
struct Splay {
    // rt: root id; tot: last allocated id. Per node x: fa[x] parent, son[x][0/1]
    // left/right child, val[x] key, cnt[x] copies of the key, sz[x] copies in the subtree.
    int rt,tot,fa[1000010],son[1000010][2],val[1000010],cnt[1000010],sz[1000010];
    // Recompute sz[x] from its children and its own count.
    void maintain(int x) {sz[x]=sz[son[x][0]]+sz[son[x][1]]+cnt[x];}
    // 1 if x is its parent's right child, else 0.
    bool get(int x) {return son[fa[x]][1]==x;}
    // Reset every field of node x.
    void clear(int x) {fa[x]=son[x][1]=son[x][0]=val[x]=cnt[x]=sz[x]=0;}
    // Lift x above its parent y. Afterwards y is x's child, so y must be maintained
    // before x (the old order computed sz[x] from a stale sz[y]).
    void rotate(int x) {
        int y=fa[x],z=fa[y],chk=get(x);
        son[y][chk]=son[x][chk^1];
        if(son[x][chk^1]) fa[son[x][chk^1]]=y;  // never give the null node a parent
        son[x][chk^1]=y;
        fa[y]=x,fa[x]=z;
        if(z) son[z][y==son[z][1]]=x;
        maintain(y),maintain(x);
    }
    // Move x to the root: zig-zig rotates the parent first, zig-zag rotates x twice.
    void splay(int x) {
        for(int f=fa[x];f=fa[x],f;rotate(x))
            if(fa[f]) rotate((get(f)==get(x))?f:x);
        rt=x;
    }
    // Insert one copy of x and splay its node to the root.
    void insert(int x) {
        if(!rt) {
            val[++tot]=x,cnt[tot]++;
            rt=tot,maintain(rt);
            return;
        }
        int cnr=rt,f=0;
        while(1) {
            if(val[cnr]==x) {
                cnt[cnr]++;
                maintain(cnr),maintain(f);
                splay(cnr);
                break;
            }
            f=cnr,cnr=son[cnr][val[cnr]<x];
            if(!cnr) {
                val[++tot]=x,cnt[tot]++;
                fa[tot]=f,son[f][val[f]<x]=tot;
                maintain(tot),maintain(f);
                splay(tot);
                break;
            }
        }
    }
    // Rank of x: 1 + number of stored copies strictly less than x. x need not be stored.
    // If x is stored its node is splayed to the root; otherwise the last node visited is.
    int rk(int x) {
        int res=0,cnr=rt,last=0;
        while(cnr) {
            last=cnr;
            if(x<val[cnr]) cnr=son[cnr][0];
            else {
                res+=sz[son[cnr][0]];
                if(x==val[cnr]) {
                    splay(cnr);
                    return res+1;
                }
                res+=cnt[cnr];
                cnr=son[cnr][1];
            }
        }
        if(last) splay(last);  // splay after an unsuccessful search as well
        return res+1;
    }
    // x-th smallest key (1-based, counting copies); its node is splayed to the root.
    // Requires x <= number of stored copies: a larger x walks into node 0 and never returns.
    int kth(int x) {
        int cnr=rt;
        while(1) {
            if(son[cnr][0]&&x<=sz[son[cnr][0]]) cnr=son[cnr][0];
            else {
                x-=sz[son[cnr][0]]+cnt[cnr];
                if(x<=0) {
                    splay(cnr);
                    return val[cnr];
                }
                cnr=son[cnr][1];
            }
        }
    }
    // Predecessor of the root: largest node in the root's left subtree, splayed to the root.
    // Returns 0 and leaves the tree untouched when the root has no left subtree.
    int pre() {
        int cnr=son[rt][0];
        if(!cnr) return 0;
        while(son[cnr][1]) cnr=son[cnr][1];
        splay(cnr);
        return cnr;
    }
    // Successor of the root: smallest node in the root's right subtree, splayed to the root.
    // Returns 0 and leaves the tree untouched when the root has no right subtree.
    int nxt() {
        int cnr=son[rt][1];
        if(!cnr) return 0;
        while(son[cnr][0]) cnr=son[cnr][0];
        splay(cnr);
        return cnr;
    }
    // Remove one copy of x; does nothing if x is not stored.
    void delete_(int x) {
        rk(x);                       // splays x's node to the root when x is stored
        if(!rt||val[rt]!=x) return;  // empty tree or x absent
        if(cnt[rt]>1) {
            cnt[rt]--;
            maintain(rt);
            return;
        }
        if(!son[rt][0]&&!son[rt][1]) {
            clear(rt);rt=0;
            return;
        }
        if(!son[rt][0]) {
            int cnr=rt;rt=son[rt][1];
            fa[rt]=0,clear(cnr);
            return;
        }
        if(!son[rt][1]) {
            int cnr=rt;rt=son[rt][0];
            fa[rt]=0,clear(cnr);
            return;
        }
        // Two children: splay the predecessor t to the root (the old root becomes t's
        // right child with no left child), then hang the old root's right subtree on t.
        int cnr=rt,t=pre();
        fa[son[cnr][1]]=t,son[t][1]=son[cnr][1];
        clear(cnr),maintain(rt);
    }
} tr;
int jkghjh,op,x;
// Driver: reads the number of operations, then "op x" pairs:
//   1 insert x   2 delete one x   3 print rank of x   4 print the x-th smallest
//   5 print the predecessor of x   6 print the successor of x
int main() {
    scanf("%d",&jkghjh);
    while(jkghjh--) {
        scanf("%d %d",&op,&x);
        switch(op) {
            case 1: tr.insert(x); break;
            case 2: tr.delete_(x); break;
            case 3: printf("%d\n",tr.rk(x)); break;
            case 4: printf("%d\n",tr.kth(x)); break;
            case 5:
            case 6:
                // insert() splays x's node to the root, so its in-order neighbour is
                // the answer; the temporary copy is removed afterwards.
                tr.insert(x);
                printf("%d\n",tr.val[op==5?tr.pre():tr.nxt()]);
                tr.delete_(x);
                break;
        }
    }
}
#include <bits/stdc++.h>

using namespace std;

// Contest-template shorthands. In this program only `rep` and `il` are used.
// #define int long long 
// rep / per: inclusive ascending / descending loops.
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
// me0 / me3: fill an array with 0 / 0x3f bytes. The ';' after "(a)" belongs to the
// replacement text, so me0(x) expands to "; memset(x, 0, sizeof x);".
#define me0(a); memset(a, 0, sizeof a);
#define me3(a); memset(a, 0x3f, sizeof a);
#define PII pair<int, int>
#define il inline

// INF (all bytes 0x3f) and MOD are not used in this program.
const int INF = 0x3f3f3f3f, MOD = 1e9 + 7;

// Fast reader: stores the next (optionally negative) decimal int from stdin in n.
// Non-digit characters are skipped; a '-' immediately before the first digit makes
// the value negative. At end of input n becomes 0 instead of looping forever.
inline void read(int &n)
{
    bool w = 0;
    int c = getchar();  // int, not char, so EOF stays distinguishable from data
    for(; c != EOF && (c < '0' || c > '9'); c = getchar())
        w = c == '-';
    for(n = 0; c >= '0' && c <= '9'; c = getchar())
        n = n * 10 + c - '0';
    n = w ? -n : n;
}

// Writes x in decimal to stdout, followed by the character a.
// The magnitude is handled as unsigned, so INT_MIN no longer overflows on negation.
inline void write(int x, char a)
{
    char digits[12];
    int len = 0;
    unsigned int u = static_cast<unsigned int>(x);
    if(x < 0) putchar('-'), u = 0u - u;
    do digits[len++] = static_cast<char>('0' + u % 10); while(u /= 10);
    while(len) putchar(digits[--len]);
    putchar(a);
}

// Child accessors: tr[p][0] is p's left child, tr[p][1] its right child.
#define ls(p) tr[p][0]
#define rs(p) tr[p][1]
const int MAXN = 1e6 + 10;

// Node ids are the element ids read in main() (which also uses n + 1 and n + 2 as
// extra front/back nodes). fa: parent (0 = none); sz: subtree node count;
// val[i]: id placed at initial position i, consumed by build(); tot: unused;
// root: current root id; n, m: element and command counts read in main().
int tr[MAXN][2], fa[MAXN], sz[MAXN], val[MAXN], tot, root, n, m;
// Subtree size of p: the node itself plus both child subtrees.
void up(int p)
{
    int total = 1;
    total += sz[ls(p)];
    total += sz[rs(p)];
    sz[p] = total;
}
// True when x is the right child of its parent.
bool check(int x)
{
    const int parent = fa[x];
    return x == rs(parent);
}
// Single rotation lifting x above its parent y while preserving in-order order.
void rotate(int x) 
{
    const int y = fa[x];
    const int z = fa[y];
    const int side = check(x);          // which child of y x is
    const int ySide = check(y);         // which child of z y is (unused when z == 0)
    const int moved = tr[x][side ^ 1];  // x's inner subtree, re-hung under y
    // z <-> x
    fa[x] = z;
    if(z) tr[z][ySide] = x;
    // x <-> y
    fa[y] = x;
    tr[x][side ^ 1] = y;
    // y <-> moved
    tr[y][side] = moved;
    if(moved) fa[moved] = y;
    up(y);
    up(x);
}

// Splay p to the root. When p and its parent are same-side children (zig-zig),
// the parent is rotated first; otherwise p is rotated.
void splay(int p) 
{
    while(fa[p])
    {
        const int f = fa[p];
        const bool zigZig = f != root && check(p) == check(f);
        if(zigZig) rotate(f);
        rotate(p);
    }
    root = p;
}
// Detach node x from the sequence, leaving it as a lone node of size 1.
// Returns how many nodes preceded x in in-order. x must have a successor.
int Delete(int x) 
{
    splay(x);
    const int before = sz[ls(x)];
    int succ = rs(x);
    while(ls(succ)) succ = ls(succ);
    // succ is x's in-order successor: after splaying it, x is succ's left child
    // with an empty right subtree, so x can be bypassed.
    splay(succ);
    const int leftPart = ls(x);
    ls(succ) = leftPart;
    fa[leftPart] = succ;
    ls(x) = rs(x) = 0;
    fa[x] = 0;
    sz[x] = 1;
    up(succ);
    return before;
}
// Hang node num (expected to be detached, e.g. by Delete) directly after the
// x-th node in in-order.
void insert(int x, int num) 
{
    int p = root;
    while(x != sz[ls(p)] + 1)
    {
        if(x <= sz[ls(p)]) p = ls(p);
        else
        {
            x -= sz[ls(p)] + 1;
            p = rs(p);
        }
    }
    splay(p);
    int succ = rs(p);
    while(ls(succ)) succ = ls(succ);
    splay(succ);  // p becomes succ's left child with an empty right slot
    rs(p) = num;
    fa[num] = p;
    up(p);
    up(succ);
}

// Node id at in-order position k (1-based) within the subtree rooted at p.
int findKth(int p, int k)
{
    const int leftSize = sz[ls(p)];
    if(k <= leftSize) return findKth(ls(p), k);
    return k == leftSize + 1 ? p : findKth(rs(p), k - leftSize - 1);
}

// Build a balanced tree over positions l..r, where val[i] is the node id placed at
// position i. Returns the subtree's root id (0 for an empty range).
int build(int l, int r) 
{
    if(l > r) return 0;
    const int mid = (l + r) >> 1;
    const int node = val[mid];
    const int left = build(l, mid - 1);
    const int right = build(mid + 1, r);
    ls(node) = left;
    rs(node) = right;
    fa[left] = node;
    fa[right] = node;
    up(node);
    return node;
}
main() 
{
    read(n); read(m);
    rep(i, 1, n) cin >> val[i + 1];
    val[1] = n + 1, val[n + 2] = n + 2;
    root = build(1, n + 2);
    for(int x, t; m; -- m)
	{
        char op[10];
        cin >> op;
        if(op[0] == 'T') read(x), Delete(x), insert(1, x);
        if(op[0] == 'B') read(x), Delete(x), insert(n, x);
        if(op[0] == 'I') read(x), read(t), insert(Delete(x) + t, x);
        if(op[0] == 'A') read(x), splay(x), write(sz[ls(x)] - 1, '\n');
        if(op[0] == 'Q') read(x), write(findKth(root, x + 1), '\n');
    }
}
#include <bits/stdc++.h>

using namespace std;

// Contest-template shorthands. In this program only `il` and `INF` are used.
// #define int long long 
// rep / per: inclusive ascending / descending loops.
#define rep(i, l, r) for(int i = l; i <= r; ++ i)
#define per(i, r, l) for(int i = r; i >= l; -- i)
// me0 / me3: fill an array with 0 / 0x3f bytes. The ';' after "(a)" belongs to the
// replacement text, so me0(x) expands to "; memset(x, 0, sizeof x);".
#define me0(a); memset(a, 0, sizeof a);
#define me3(a); memset(a, 0x3f, sizeof a);
#define PII pair<int, int>
#define il inline

// INF (all bytes 0x3f) is the "minus infinity" magnitude for subtree maxima; MOD is unused.
const int INF = 0x3f3f3f3f, MOD = 1e9 + 7;

// Fast reader: stores the next (optionally negative) decimal int from stdin in n.
// Non-digit characters are skipped; a '-' immediately before the first digit makes
// the value negative. At end of input n becomes 0 instead of looping forever.
inline void read(int &n)
{
    bool w = 0;
    int c = getchar();  // int, not char, so EOF stays distinguishable from data
    for(; c != EOF && (c < '0' || c > '9'); c = getchar())
        w = c == '-';
    for(n = 0; c >= '0' && c <= '9'; c = getchar())
        n = n * 10 + c - '0';
    n = w ? -n : n;
}

// Writes x in decimal to stdout, followed by the character a.
// The magnitude is handled as unsigned, so INT_MIN no longer overflows on negation.
inline void write(int x, char a)
{
    char digits[12];
    int len = 0;
    unsigned int u = static_cast<unsigned int>(x);
    if(x < 0) putchar('-'), u = 0u - u;
    do digits[len++] = static_cast<char>('0' + u % 10); while(u /= 10);
    while(len) putchar(digits[--len]);
    putchar(a);
}

// Child accessors: tr[p][0] is p's left child, tr[p][1] its right child.
#define ls(p) tr[p][0]
#define rs(p) tr[p][1]
const int MAXN = 1e6 + 10;

// Node id == initial position; build() creates nodes 1..n+2, where positions 1 and
// n+2 are the boundary nodes extract() uses to isolate a range.
//   tr: children, fa: parent (0 = none), sz: subtree size,
//   add: pending addition for the children (the node's own val/maxx already include it),
//   val: node value, maxx: maximum val in the subtree.
// tot is unused; root: current root; n, m: sequence length and command count.
int tr[MAXN][2], fa[MAXN], sz[MAXN], add[MAXN], val[MAXN], maxx[MAXN], tot, root, n, m;
// wh[p] != 0: p's children still have to be swapped (pending reversal).
int wh[MAXN];
void up(int p)
{
	sz[p] = sz[ls(p)] + sz[rs(p)] + 1;
	maxx[p] = max(max(maxx[ls(p)], maxx[rs(p)]), val[p]);
}
void push_up(int p)
{
	if(wh[p])
	{
		wh[ls(p)] ^= 1; wh[rs(p)] ^= 1;
		swap(ls(p), rs(p));
		wh[p] = 0;
	}
	maxx[ls(p)] += add[p];
	maxx[rs(p)] += add[p];
	val[ls(p)] += add[p];
	val[rs(p)] += add[p];
	add[ls(p)] += add[p];
	add[rs(p)] += add[p];
	add[p] = 0;
	maxx[0] = -INF;
	maxx[1] = -INF;
	maxx[n + 2] =-INF;
	sz[0] = 0;
}
// True when x is the right child of its parent.
bool check(int x)
{
	const int parent = fa[x];
	return x == rs(parent);
}
// Single rotation lifting x above its parent y while preserving in-order order.
// Callers push pending tags (ladd) before rotating.
void rotate(int x) 
{
    const int y = fa[x];
    const int z = fa[y];
    const int side = check(x);          // which child of y x is
    const int ySide = check(y);         // which child of z y is (unused when z == 0)
    const int moved = tr[x][side ^ 1];  // x's inner subtree, re-hung under y
    fa[x] = z;
    if(z) tr[z][ySide] = x;
    fa[y] = x;
    tr[x][side ^ 1] = y;
    tr[y][side] = moved;
    if(moved) fa[moved] = y;
    up(y);
    up(x);
}
void ladd(const int &p)
{
	if(fa[p]) ladd(fa[p]);
	push_up(p);
}
// Rotate p upward until its parent is aim (aim == 0 makes p the root).
// Pending tags on the root-to-p path are pushed first.
void splay(int p, int aim) 
{
    ladd(p);
    while(fa[p] != aim)
    {
        const int f = fa[p];
        if(fa[f] != aim)
        {
            // zig-zig rotates the parent first, zig-zag rotates p twice
            rotate(check(p) == check(f) ? f : p);
        }
        rotate(p);
    }
    if(aim == 0) root = p;
}

// Node id at in-order position k (1-based) within the subtree rooted at p.
// Tags are pushed on the way down because a pending reversal swaps children.
int findKth(int p, int k)
{
	push_up(p);
	const int leftSize = sz[ls(p)];
	if(k <= leftSize) return findKth(ls(p), k);
	return k == leftSize + 1 ? p : findKth(rs(p), k - leftSize - 1);
}
// Isolate user positions l..r (tree positions l+1..r+1, after the front boundary node).
// Splays the node before the range to the root and the node after it directly below,
// then returns the root of the subtree that holds exactly the range.
int extract(int l, int r)
{
	const int before = findKth(root, l);
	const int after = findKth(root, r + 2);
	splay(before, 0);
	splay(after, before);
	return ls(after);
}
// Add v to every element at positions l..r: update the range root's own val/maxx,
// leave the rest as a pending tag, then splay the range root up so every node it
// passes is re-maintained.
void change(int l, int r, int v)
{
	const int p = extract(l, r);
	val[p] += v;
	maxx[p] += v;
	add[p] += v;
	splay(p, 0);
}
void reverse(int l, int r)
{
    r += 2;
	l = findKth(root, l); r = findKth(root,r);
	splay(l, 0); splay(r, l);
	wh[ls(r)] ^= 1;
	splay(ls(r), 0);
}
// Maximum value among positions l..r (maxx of the extracted range root is up to date).
int ask(int l, int r)
{
	const int rangeRoot = extract(l, r);
	return maxx[rangeRoot];
}
// Build a balanced tree over node ids l..r (node id == position) and return its root.
int build(int l, int r) 
{
    if(l > r) return 0;
    if(l == r)
    {
        up(l);
        return l;
    }
    const int mid = (l + r) >> 1;
    const int left = build(l, mid - 1);
    const int right = build(mid + 1, r);
    ls(mid) = left;
    rs(mid) = right;
    if(left) fa[left] = mid;
    if(right) fa[right] = mid;
    up(mid);
    return mid;
}
// Driver. The sequence starts as n zeros (global arrays are zero-initialised); nodes
// 1 and n + 2 bound it. maxx[0] = -INF so missing children never win a maximum. Commands:
//   1 l r v : add v on [l, r]    2 l r : reverse [l, r]    3 l r : print max of [l, r]
// `main` must be declared with return type int (implicit int is ill-formed C++).
int main() 
{
    read(n); read(m);
    maxx[0] = -INF;
    maxx[1] = -INF;
    maxx[n + 2] = -INF;
    sz[0] = 0;
    root = build(1, n + 2);
    for(int l, r, v, op; m; -- m)
    {
        read(op);
        if(op == 1) read(l), read(r), read(v), change(l, r, v);
        if(op == 2) read(l), read(r), reverse(l, r);
        if(op == 3) read(l), read(r), write(ask(l, r), '\n');
    }
    return 0;
}