<学习笔记> 后缀树(数)组

发布时间 2023-12-26 10:15:55作者: _bloss

后缀排序

倍增+基数排序

code
bool cmp(int x,int y,int k){ // 常数优化,使访问连续
    if(oldrk[x]==oldrk[y] && oldrk[x+k]==oldrk[y+k]) return 1;
    else return 0;
}
void get_sa(){
    m=10001;
    for(int i=1;i<=n;i++) ++cnt[rk[i]=s[i]];
    for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for(w=1;;w<<=1,m=p){
        p=0;
        for(int i=n;i+w>n;i--) id[++p]=i;
        for(int i=1;i<=n;i++){
            if(sa[i]-w>0) id[++p]=sa[i]-w;//因为看的是第二个,所以要 -w ,可以手摸
        }
        //基数排序
        memset(cnt,0,sizeof(int)*(m+1));
        for(int i=1;i<=n;i++) ++cnt[rk[id[i]]];
        for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];
        //给 rk 赋值,因为会有相等的,所以要特殊处理
        memcpy(oldrk,rk,sizeof(int)*(n+1));
        p=0;
        for(int i=1;i<=n;i++){
            if(cmp(sa[i],sa[i-1],w)) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
        if(p==n) break; // 优化
    }
}

LCP Theorem

定义 \(height[i]= \mathrm{lca}(sa[i],sa[i-1])\)

\(\mathrm{lcp}(i,j)=\min_{k=i+1}^j \mathrm{lcp}(k-1,k)=\min_{k=i+1}^j height[k]\)

求 LCP 的方法

结论: \(height[rk_i] \geq height[rk_{i-1}]-1\)

证明:

假如 \(k\)\(i-1\) 的开头字母相同,那么 \(i\)\(k+1\)\(lcp\)\(height[rk_{i-1}]-1\),根据 \(LCP Theorem\),它们的 \(lcp\)\(i\)\(k+1\) 的最小值,那么 \(height[rk_{i}] \geq height[rk_{i-1}]-1\)

如果开头字母不同,那么 \(height[rk_{i-1}]=0\) 更加成立。

code

void get_height(){
    int k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]==1) continue;
        if(k) k--;
        while(s[sa[rk[i]]+k]==s[sa[rk[i]-1]+k]) k++;
        height[rk[i]]=k;
    }
}

例题

sandy 的卡片

注意到可以转化为差分序列,然后就是要求使每个字符串都有一个相同的子串的长度最大。然后就用后缀树组,注意在每两个字符串之间差一个极大值,防止相邻串组成一个新后缀。然后进行二分找最长串,主要是 \(check\) 函数。其实只要保证存在 \(n\)\(height[i] \geq mid\),且开始字符来自不同串,注意要连续。

code
#include<bits/stdc++.h>
using namespace std;
const int N=200005;
int a[1005][1005],len[N];
int n,m;
int mx=0,mn=N;
int s[N],top=0;
int rk[N],sa[N],p,w,cnt[N],id[N],oldrk[N];
int pos[N];
int height[N];
bool cmp(int x,int y,int k){
    if(oldrk[x]==oldrk[y] && oldrk[x+k]==oldrk[y+k]) return 1;
    return 0;
}
void get_sa(){
    for(int i=1;i<=n;++i) ++cnt[rk[i]=s[i]];
    for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;--i) sa[cnt[rk[i]]--]=i;
    p=n;
    for(w=1;;w<<=1,m=p){
        p=0;
        for(int i=n;i+w>n;--i) id[++p]=i;
        for(int i=1;i<=n;++i){
            if(sa[i]>w) id[++p]=sa[i]-w;
        }
        memset(cnt,0,sizeof(int)*(m+1));
        for(int i=1;i<=n;++i) ++cnt[rk[id[i]]];
        for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];
        
        memcpy(oldrk,rk,sizeof(int)*(n+1));
        p=0;
        for(int i=1;i<=n;++i){
            if(cmp(sa[i],sa[i-1],w)) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
        if(p==n) break;
    }
}
void get_height(){
    int k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]==1) continue;
        if(k) k--;
        while(s[sa[rk[i]]+k]==s[sa[rk[i]-1]+k]) k++;
        height[rk[i]]=k; 
    }
}
bool vis[N];
int st[N];
int kkk=0;
bool check(int mid){
    memset(vis,0,sizeof(vis));
    int cnt=0;
    for(int i=1;i<=n;i++){
        if(height[i]<mid){
            while(cnt) vis[st[cnt]]=0,cnt--;
        }
        if(!vis[pos[sa[i]]]){
            vis[pos[sa[i]]]=1;
            st[++cnt]=pos[sa[i]];
            if(cnt==kkk) return 1;
        }
    }
    return 0;
}
signed main(){
    scanf("%d",&n);
    kkk=n;
    for(int i=1;i<=n;i++){
        scanf("%d",&len[i]);
        for(int j=1;j<=len[i];j++){
            scanf("%d",&a[i][j]);
            mx=max(mx,a[i][j]-a[i][j-1]);
        }
    }
    mx++;
    for(int i=1;i<=n;i++){
        for(int j=1;j<len[i];j++){
            s[++top]=a[i][j+1]-a[i][j];
            mn=min(mn,s[top]);
            pos[top]=i;
        }
        s[++top]=mx;
    }
    n=top;
    m=mx;
    if(mn<0){
        mn=-mn;
        m+=mn;
        for(int i=1;i<=top;i++) s[i]+=mn;
    }
    get_sa();
    get_height();
    int l=0,r=101;
    int ans=-1;
    while(l<=r){
        int mid=(l+r)/2;
        if(check(mid)){
            l=mid+1;
            ans=mid;
        }
        else{
            r=mid-1;
        }
    }
    printf("%d",ans+1);
    
}

[SCOI2012] 喵星球上的点名

把询问接在姓名后
求出 height 数组
对于每次询问,找到它的排名,然后在 sa 上找到这个查询所覆盖的区间,可以二分实现
可以用 st 表将查询优化到 \(O(1)\)
如果存在区间的话,对于第一个询问就是这个区间里存在多少种不同的串
考虑离线处理,莫队,其实也可以 HH的项链
但是第二个询问如何搞?
考虑将每个串变为两个操作,加入和删除。类似于后缀和,当这个串第一次加进来的话,那么给这个串加上可能最多的出现次数,如果删除一个串后,这个串不再存在,那么给这个串减去可能最多的出现次数。

code
#include<bits/stdc++.h>
using namespace std;
const int N=1000005;
int s[N],top=0,pos[N];
int sa[N],rk[N],cnt[N],id[N],oldrk[N],p,w,height[N];
int pre[N];
int len[N];
int f[N][50],t;
int n,q,m;
struct asd{
    int x,y,id;
}st[N];
int poss[N],tt;
int ans1[N],ans2[N];
int L[N],R[N];
bool cmp(int x,int y,int k){
    if(oldrk[x]==oldrk[y] && oldrk[x+k]==oldrk[y+k]) return 1;
    else return 0;
}
void get_sa(){
    m=10001;
    for(int i=1;i<=n;i++) ++cnt[rk[i]=s[i]];
    for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for(w=1;;w<<=1,m=p){
        p=0;
        for(int i=n;i+w>n;i--) id[++p]=i;
        for(int i=1;i<=n;i++){
            if(sa[i]-w>0) id[++p]=sa[i]-w;
        }
        memset(cnt,0,sizeof(int)*(m+1));
        for(int i=1;i<=n;i++) ++cnt[rk[id[i]]];
        for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];

        memcpy(oldrk,rk,sizeof(int)*(n+1));
        p=0;
        for(int i=1;i<=n;i++){
            if(cmp(sa[i],sa[i-1],w)) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
        if(p==n) break; 
    }
}
void get_height(){
    int k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]==1) continue;
        if(k) k--;
        while(s[sa[rk[i]]+k]==s[sa[rk[i]-1]+k]) k++;
        height[rk[i]]=k;
        f[rk[i]][0]=k;
    }
}
void init(){
    t=log2(n)+1;
    for(int j=1;j<=t;j++)
        for(int i=1;i<=n;i++)
            f[i][j]=min(f[i][j-1],f[i+(1<<(j-1))][j-1]);
}
int ask(int x,int y){
    if(x>y) return 10001;
    int kl=log2(y-x+1);
    return min(f[x][kl],f[y-(1<<kl)+1][kl]);
}
bool amp(asd a,asd b){
    if(poss[a.x]==poss[b.x]) return a.y<b.y;
    else return poss[a.x]<poss[b.x];
}
bool flat[N];
signed main(){
    scanf("%d%d",&n,&q);
    int lm=n;
    for(int i=1;i<=n;i++){
        int len1,len2;
        scanf("%d",&len1);
        for(int j=1;j<=len1;j++){
            top++;
            scanf("%d",&s[top]);
            pos[top]=i;
        }
        s[++top]=10001;
        scanf("%d",&len2);
        for(int j=1;j<=len2;j++){
            top++;
            scanf("%d",&s[top]);
            pos[top]=i;
        }
        s[++top]=10001;
    }
    for(int i=1;i<=q;i++){
        scanf("%d",&len[i]);
        pre[i]=top+1;
        for(int j=1;j<=len[i];j++){
            int x;
            scanf("%d",&x);
            s[++top]=x;
        }
        s[++top]=10001;
    }
    n=top;
    get_sa();
    get_height();
    init();
    top=0;
    for(int i=1;i<=q;i++){
        int l=0,r=rk[pre[i]],anss1=0,anss2=-1;
        while(l<=r){
            int mid=(l+r)/2;
            if(ask(mid+1,rk[pre[i]])>=len[i]){
                anss1=mid;
                r=mid-1;
            }
            else l=mid+1;
        }
        l=rk[pre[i]]-1,r=n;
        while(l<=r){
            int mid=(l+r)/2;
            if(ask(rk[pre[i]]+1,mid)>=len[i]){
                anss2=mid;
                l=mid+1;
            }
            else r=mid-1;
        }
        if(anss1<=anss2){
            st[++top]={anss1,anss2,i};
        }
    }
    tt=log2(q);
    for(int i=1;i<=tt;i++){
        L[i]=(i-1)*tt+1;
        R[i]=i*tt;
    }
    if(R[tt]<q) tt++,L[tt]=R[tt-1]+1,R[tt]=top;
    for(int i=1;i<=tt;i++){
        for(int j=L[i];j<=R[i];j++){
            poss[j]=i;
        }
    }
    sort(st+1,st+q+1,amp);
    memset(cnt,0,sizeof(cnt));
    int ls=1,rs=0;
    int sum=0;
    for(int i=1;i<=q;i++){
        while(rs<st[i].y){
            rs++;
            if(cnt[pos[sa[rs]]]==0 && pos[sa[rs]]>0){
                sum++;
                ans2[pos[sa[rs]]]+=(q-i+1);
            }
            cnt[pos[sa[rs]]]++;
        }
        while(rs>st[i].y){
            cnt[pos[sa[rs]]]--;
            if(cnt[pos[sa[rs]]]==0 && pos[sa[rs]]>0){
                sum--;
                ans2[pos[sa[rs]]]-=(q-i+1);
            }
            rs--;
        }
        while(ls>st[i].x){
            ls--;
            if(cnt[pos[sa[ls]]]==0 && pos[sa[ls]]>0){
                sum++;
                ans2[pos[sa[ls]]]+=(q-i+1);
            }
            cnt[pos[sa[ls]]]++;
        }
        while(ls<st[i].x){
            cnt[pos[sa[ls]]]--;
            if(cnt[pos[sa[ls]]]==0 && pos[sa[ls]]>0){
                sum--;
                ans2[pos[sa[ls]]]-=(q-i+1);
            }
            ls++;
        }
        ans1[st[i].id]=sum;
    }
    for(int i=1;i<=q;i++){
        printf("%d\n",ans1[i]);
    }
    for(int i=1;i<=lm;i++){
        printf("%d ",ans2[i]);
    }
}

字符串

求一个区间的字符串的子串和一个区间的字符串的最长公共前缀,直接暴力的话可以将这个区间取出来,然后在后面接上后一个串,然后考虑后面那个串相邻两个的公共前缀,复杂度 \(O(n^2 \log n)\)
如果还是这样应用到全局上的话,那么就要求起点和终点在区间范围内,然后二分长度。其实就是 lcp 是一个单峰的,且自己为峰顶,那么可以二分出左右边界,然后在这个边界里查找是否存在左端点在 [l,r-mid+1] 内的点,主席树维护。

code
#include<bits/stdc++.h>
using namespace std;
const int N=200005;
char s[N];
int sa[N],rk[N],id[N],oldrk[N],cnt[N],p,w,height[N];
int n,m,q;
int f[N][40],t;
bool cmp(int x,int y,int k){
    if(oldrk[x]==oldrk[y] && oldrk[x+k]==oldrk[y+k]) return 1;
    else return 0;
}
void get_sa(){
    m=128;
    for(int i=1;i<=n;i++) ++cnt[rk[i]=s[i]];
    for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for(w=1;;w<<=1,m=p){
        p=0;
        for(int i=n;i+w>n;i--) id[++p]=i;
        for(int i=1;i<=n;i++){
            if(sa[i]-w>0) id[++p]=sa[i]-w;
        }
        
        memset(cnt,0,sizeof(int)*(m+1));
        for(int i=1;i<=n;i++) ++cnt[rk[id[i]]];
        for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];
        
        memcpy(oldrk,rk,sizeof(int)*(n+1));
        p=0;
        for(int i=1;i<=n;i++){
            if(cmp(sa[i],sa[i-1],w)) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
        if(p==n) break;
    }
}
void get_height(){
    int k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]==1) continue;
        if(k) k--;
        while(s[sa[rk[i]]+k]==s[sa[rk[i]-1]+k]) k++;
        height[rk[i]]=k;
        f[rk[i]][0]=k;
    }
}
void init(){
    t=log(n)+1;
    for(int j=1;j<=t;j++)
        for(int i=1;i<=n;i++)
            f[i][j]=min(f[i][j-1],f[i+(1<<(j-1))][j-1]);
}
int ask(int x,int y){
    if(x>y) return 100005;
    int kl=log2(y-x+1);
    return min(f[x][kl],f[y-(1<<kl)+1][kl]);
}
int rt[N],idx=0;
struct tree{
    int l,r,sum;
}tr[N*70];
void change(int &p,int q,int l,int r,int wh){
    p=++idx;
    tr[p]=tr[q];
    if(l==r){
        tr[p].sum++;
        return;
    }
    int mid=(l+r)/2;
    if(wh<=mid) change(tr[p].l,tr[q].l,l,mid,wh);
    else change(tr[p].r,tr[q].r,mid+1,r,wh);
    tr[p].sum=tr[tr[p].l].sum+tr[tr[p].r].sum;
}
int query(int p,int q,int l,int r,int ls,int rs){
    if(ls>rs) return 0;
    if(l>=ls && r<=rs) return tr[q].sum-tr[p].sum;
    int mid=(l+r)/2;
    int sum=0;
    if(ls<=mid) sum+=query(tr[p].l,tr[q].l,l,mid,ls,rs);
    if(rs>mid) sum+=query(tr[p].r,tr[q].r,mid+1,r,ls,rs);
    return sum;
}
bool check(int len,int c,int a,int b){
    int l=0,r=rk[c],ans1=0,ans2=0;
    while(l<=r){
        int mid=(l+r)/2;
        if(ask(mid+1,rk[c])>=len){
            r=mid-1;
            ans1=mid;
        }
        else l=mid+1;
    }
    l=rk[c]-1,r=n;
    while(l<=r){
        int mid=(l+r)/2;
        if(ask(rk[c]+1,mid)>=len){
            l=mid+1;
            ans2=mid;
        }
        else r=mid-1;
    }
    int sum=query(rt[ans1-1],rt[ans2],1,n,a,b-len+1);
    return (sum>0);
}
signed main(){
    scanf("%d%d",&n,&q);
    scanf("%s",s+1);
    get_sa();
    get_height();
    init();
    for(int i=1;i<=n;i++) change(rt[i],rt[i-1],1,n,sa[i]);
    for(int op=1;op<=q;op++){
        int a,b,c,d;
        scanf("%d%d%d%d",&a,&b,&c,&d);
        int l=0,r=d-c+1;
        while(l<r){
            int mid=(l+r+1)/2;
            if(check(mid,c,a,b)) l=mid;
            else r=mid-1;
        }
        printf("%d\n",l);
    }
}

/*
求一个区间的字符串的子串和一个区间的字符串的最长公共前缀
对于区间形式,要用到可持久化数据结构
但是好像没有办法直接维护
是不是将问题转化形式
直接暴力的话可以将这个区间取出来,然后在后面接上后一个串
然后考虑后面那个串相邻两个的公共前缀
复杂度 $O(n^2 \log n)$
如果还是这样应用到全局上的话
那么就要求起点和终点在区间范围内
可不可以二分长度
有什么要求呢?
还是要求这两个串最长前缀必须满足要求,且起点必须在范围内,终点可以大于范围
可以求出起点应该属于的区间,暴力的话就是直接扫
看了题解发现就差一点点了
其实就是 lcp 是一个单峰的,且自己为峰顶,那么可以二分出左右边界
然后在这个边界里查找是否存在左端点在 [l,r-mid+1] 内的点
然后就是一颗主席树
*/

P4248 [AHOI2013] 差异

求出左侧 height 第一个比它小和右侧第一个不大于它的,然后算贡献,用单调栈。

注意一个字符可能会使你 \(RE\)

code
#include<bits/stdc++.h>
using namespace std;
const int N=4*1e6+10;
char s[N];
int n,m;
int p,w,rk[N],sa[N],id[N],oldrk[N],cnt[N],height[N];
bool cmp(int x,int y,int k){
    if(oldrk[x]==oldrk[y] && oldrk[x+k]==oldrk[y+k]) return 1;
    else return 0;
}
void get_sa(){
    m=128;
    n=strlen(s+1);
    for(int i=1;i<=n;i++) ++cnt[rk[i]=s[i]];
    for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for(int i=n;i>=1;i--) sa[cnt[rk[i]]--]=i;
    for(w=1;;w<<=1,m=p){
        p=0;
        for(int i=n;i+w>n;i--) id[++p]=i;
        for(int i=1;i<=n;i++){
            if(sa[i]-w>0) id[++p]=sa[i]-w;
        }
        
        memset(cnt,0,sizeof(int)*(m+1));
        for(int i=1;i<=n;i++) ++cnt[rk[id[i]]];
        for(int i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for(int i=n;i>=1;i--) sa[cnt[rk[id[i]]]--]=id[i];
        
        memcpy(oldrk,rk,sizeof(int)*(n+1));
        p=0;
        for(int i=1;i<=n;i++){
            if(cmp(sa[i],sa[i-1],w)) rk[sa[i]]=p;
            else rk[sa[i]]=++p;
        }
        if(p==n) break;
    }
}
void get_height(){
    int k=0;
    for(int i=1;i<=n;i++){
        if(rk[i]<=1) continue;
        if(k) k--;
        while(s[sa[rk[i]]+k]==s[sa[rk[i]-1]+k]) k++;
        height[rk[i]]=k;
    }
}
int ls[N],rs[N];
int st[N],top=0;
signed main(){
    scanf("%s",s+1);
    get_sa();
    get_height();
    for(int i=1;i<=n;i++){
        while(top>0 && height[st[top]]>height[i]) top--;
        ls[i]=st[top];
        st[++top]=i;
    }
    top=0;
    st[0]=n+1;
    for(int i=n;i>=1;i--){
        while(top>0 && height[st[top]]>=height[i]) top--;
        rs[i]=st[top];
        st[++top]=i;
    }
    long long sum=0;
    for(int i=1;i<=n;i++){
        sum+=1ll*i*(n-1);
    }
    for(int i=1;i<=n;i++){
        int l=ls[i]+1,r=rs[i]-1;
        sum-=2ll*height[i]*(i-l+1)*(r-i+1);
    }
    printf("%lld",sum);
}

参考资料