后缀平衡树

发布时间 2023-08-08 16:00:34作者: Xun_Xiaoyao

一个神奇的字符串科技。
它支持:

  1. \(O(\log |S|)\) 时间在字符串 \(S\) 前插入一个字符。
  2. \(O(\log |S|)\) 时间删去字符串 \(S\) 的第一个字符。
  3. \(O(|T|\log |S|)\) 时间查询字符串 \(T\) 在所有 \(S\) 后缀中的排名。
  4. \(O(\log |S|)\) 查询 \(S\) 一个串的排名。
  5. \(O(1)\) 比较 \(S\) 的两个后缀。

理论上可以维护更多的信息,但是能在线处理这些东西就很强大了(离线问题基本可以用后缀数组或者后缀自动机代替)。

加下来我们记 \(n=|S|\)

既然它叫后缀平衡树,首先它需要一棵平衡树,具体需要什么平衡树先不管,因为可能要用到很多特殊的性质。我们先假设它就是一个支持插入和删除的平衡树!

我们先思考如何插入。既然我们说的是将 \(S\) 变成 \(cS\),我们多出的一个后缀就是 \(cS\),考虑在平衡树中插入的过程,我们会将这个串与 \(O(\log |n|)\)\(|S|\) 的后缀比较。
暴力比较可以做到单次 \(O(n)\),也就是插入一个字符复杂度是 \(O(n\log n)\),显然无法接受。
使用二分+哈希比较可以做到单次 \(O(\log n)\),那么插入一个字符的复杂度是 \(O(\log^2n)\)

但是,这些都是比较淳朴的比较方式,我们都没有利用好后缀树的性质——既然我们已经建出了后缀树,这就意味这我们可以很方便的比较出已经有的两个后缀的大小。
\(cS\) 还没有加入,但是 \(S\) 已经在后缀平衡树里面了,所以我们要比较 \(cS\)\(S\) 的某一个后缀,我们只需要比较第一位,如果相同,就可以直接比较 \(S\) 和某一个后缀即可。

如果直接从后缀平衡树上读取排名,仍然是单次 \(O(\log n)\),一个插入是 \(O(\log^2n)\)
假设我们能设计出一个权值 \(val\),它的大小关系和 \(rk\) 完全一致,也就是 \(rk_i<rk_j\to val_i<val_j\)
我们考虑这个平衡树的结构,我们将它对应一个值域为 \([0,10^{18}]\) 的区间,构造一个类似线段树的结构,如果当前节点对应的区间是 \([l,r]\),则它的权值 \(val=mid=\dfrac{l+r}{2}\),它的左儿子对应区间 \([l,mid]\),右儿子对应区间 \([mid+1,r]\),这样所有的 \(val\) 都可以和 \(rk\),对应了。

但是这样就会遇到一个问题:我们插入或删除的时候,我们需要更新 \(val\) 的节点将会是一棵子树,所以我们需要一个子树大小期望或均摊 \(O(\log n)\) 的平衡树,也就是重量平衡树,例如 FHQ Treap 和替罪羊树。

我用的是 FHQ Treap。

这样我们就能够比较轻松地实现后缀平衡树了。

Luogu P6164 【模板】后缀平衡树

代码:

#include <bits/stdc++.h>
using namespace std;
int Qread()
{
    int x=0;char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+(ch^48),ch=getchar();
    return x;
}
int get_str(char *s)
{
    int len=1;
    do s[1]=getchar();while(s[1]<'A'||s[1]>'Z');
    do s[++len]=getchar();while(s[len]>='A'&&s[len]<='Z');
    s[len]=0;
    return len-1;
}
void decode(char *s,int len,int mask)
{
    for(int i=0;i<len;i++)
    {
        mask=(mask*131+i)%len;
        swap(s[i],s[mask]);
    }
}
mt19937 zsh(20070610);
struct Node{
    long long val;
    double l,r,tag;
    int lson,rson,siz;
}p[800010];
char s[800010];int top;
void get_node(int x)
{
    p[x].val=zsh();
    p[x].lson=p[x].rson=0;
    p[x].siz=1;
}
void update(int x){p[x].siz=p[p[x].lson].siz+p[p[x].rson].siz+1;}
bool cmp(int x,int y)
{
    if(s[x]!=s[y]) return s[x]<s[y];
    return p[x-1].tag<p[y-1].tag;
}
int merge(int x,int y)
{
    if(x&&y)
    {
        if(p[x].val>p[y].val) return p[x].rson=merge(p[x].rson,y),update(x),x;
        else return p[y].lson=merge(x,p[y].lson),update(y),y;
    }
    else return x+y;
}
void split(int rt,int k,int &x,int &y)
{
    if(rt)
    {
        if(cmp(rt,k)) return x=rt,split(p[x].rson,k,p[x].rson,y),update(x);
        else return y=rt,split(p[y].lson,k,x,p[y].lson),update(y);
    }
    else return x=y=0,void();
}
void reset_tag(int pos,double l,double r)
{
    if(!pos) return;
    p[pos].l=l,p[pos].r=r,p[pos].tag=(l+r)/2;
    reset_tag(p[pos].lson,l,p[pos].tag),reset_tag(p[pos].rson,p[pos].tag,r);
}
int rt;
void insert_node(int x)
{
    get_node(x);
    int *cur=&rt;
    double l=p[*cur].l,r=p[*cur].r;
    while(*cur&&p[x].val<p[*cur].val)
    {
        p[*cur].siz++;
        if(cmp(x,*cur)) r=p[*cur].tag,cur=&p[*cur].lson;
        else l=p[*cur].tag,cur=&p[*cur].rson;
    }
    split(*cur,x,p[x].lson,p[x].rson);
    update(x);
    reset_tag(x,l,r);
    *cur=x;
}
void delete_node(int x)
{
    int *cur=&rt;
    while(*cur&&*cur!=x)
    {
        p[*cur].siz--;
        if(cmp(x,*cur)) cur=&p[*cur].lson;
        else cur=&p[*cur].rson;
    }
    assert(*cur);
    double l=p[*cur].l,r=p[*cur].r;
    *cur=merge(p[x].lson,p[x].rson);
    reset_tag(*cur,l,r);
}
bool ccmp(char *s,int ssta,char *t,int tsta)
{
    while(ssta>=0&&tsta>=0)
    {
        if(s[ssta]!=t[tsta])
            return s[ssta]<t[tsta];
        ssta--,tsta--;
    }
    return false;
}
int q,len,mask,ans;
char t[3000010];
int get_ind(int x)
{
    if(!x) return 0;
    if(ccmp(t,len,s,x)) return get_ind(p[x].lson);
    else return p[p[x].lson].siz+1+get_ind(p[x].rson);
}
int main()
{
    q=Qread();
    top=len=get_str(s);
    p[0].l=0,p[0].r=1e18;
    for(int i=1;i<=len;i++) insert_node(i);
    while(q--)
    {
        get_str(t);
        if(t[1]=='A')
        {
            len=get_str(t);
            decode(t+1,len,mask);
            for(int i=1;i<=len;i++) s[++top]=t[i],insert_node(top);
        }
        else if(t[1]=='D')
        {
            len=Qread();
            for(int i=1;i<=len;i++) delete_node(top),s[top--]=0;
        }
        else
        {
            len=get_str(t);
            decode(t+1,len,mask);
            t[0]='[';
            ans=get_ind(rt);
            t[1]--;
            ans-=get_ind(rt);
            printf("%d\n",ans);
            mask^=ans;
        }
        assert(mask>=0);
    }
    return 0;
}