<学习笔记> SAM

发布时间 2024-01-01 08:25:30作者: _bloss

SAM

定义

字符串 \(s\)\(SAM\) 是一个接受 \(s\) 的所有后缀的最小 \(DFA\)(确定性有限自动机或确定性有限状态自动机)。

  • \(\mathrm{endpos(t)}:\) 子串 \(t\) 在原串 \(s\) 中所有出现位置(最后一个字符位置)的集合。

  • \(\mathrm{len(u)}:\) \(u\) 这个节点所代表的等价类中最大子串长。

  • \(\mathrm{link(u)}:\) 后缀链接。

  • \(\mathrm{Parent} 树:\) 由后缀链接构成的树。

性质

引理一:字符串 \(s\) 的两个非空子串 \(u\)\(w\) (假设 $|u| \leq |w| $)的 \(\operatorname{endpos}\) 相同,当且仅当字符串 \(u\)\(s\) 中的每次出现,都是以 \(w\) 后缀的形式存在的。

引理二: 考虑两个非空子串 \(u\)\(w\) (假设 \(\left|u\right|\le \left|w\right|\) )。那么要么 \(\operatorname{endpos}(u)\cap \operatorname{endpos}(w)=\varnothing\) ,要么 \(\operatorname{endpos}(w)\subseteq \operatorname{endpos}(u)\) ,取决于 \(u\) 是否为 \(w\) 的一个后缀:
\(\begin{cases} \mathrm{endpos}(w) \subseteq \mathrm{endpos}(u) & \mathrm{if} \ u\ \mathrm{is\ a\ suffix\ of}\ w \\ \mathrm{endpos}(u) \cup \mathrm{endpos}(w) = \varnothing & \mathrm{otherwise} \end{cases}\)

引理三: 考虑一个 \(\operatorname{endpos}\) 等价类,将类中的所有子串按长度非递增的顺序排序。每个子串都不会比它前一个子串长,与此同时每个子串也是它前一个子串的后缀。换句话说,对于同一等价类的任一两子串,较短者为较长者的后缀,且该等价类中的子串长度恰好覆盖整个区间 \([x,y]\)

引理四: 所有后缀链接构成一棵根节点为 \(t_0\) 的树。

  • 根据引理三可知,\(len[u]=len[link[u]]+1\),所以后缀树上这个点到根的并为 \([1,len[u]]\)

  • 又根据引理二,后缀树上 \(u\) 的祖先满足均为 \(u\) 的后缀,所以两个前缀的最长公共后缀为这两个点在后缀树上 \(lca\) 的长度。根据这个就可以求 \(\mathrm{lcp}\)

  • 状态数 \(2n-1\);转移数 \(3n-4\)

  • 后缀树的叶子节点为原串的一个前缀,每个节点的 \(endpos\) 集合为其子树中前缀的个数。

构建

现在,任务转化为实现给当前字符串添加一个字符 \(c\) 的过程。

  • \(\textit{last}\) 为添加字符 \(c\) 之前,整个字符串对应的状态(一开始我们设 \(\textit{last}=0\) ,算法的最后一步更新 \(\textit{last}\) )。

  • 创建一个新的状态 \(\textit{cur}\) ,并将 \(\operatorname{len}(\textit{cur})\) 赋值为 \(\operatorname{len}(\textit{last})+1\) ,在这时 \(\operatorname{link}(\textit{cur})\) 的值还未知。

  • 现在我们按以下流程进行(从状态 \(\textit{last}\) 开始)。如果还没有到字符 \(c\) 的转移,我们就添加一个到状态 \(\textit{cur}\) 的转移(相当于那些所代表的后缀都可以加一个字符变成一个新的后缀),遍历后缀链接。如果在某个点已经存在到字符 \(c\) 的转移,我们就停下来,并将这个状态标记为 \(p\)

  • 如果没有找到这样的状态 \(p\) ,我们就到达了虚拟状态 \(-1\) ,我们将 \(\operatorname{link}(\textit{cur})\) 赋值为 \(0\) 并退出。

  • 假设现在我们找到了一个状态 \(p\) ,其可以通过字符 \(c\) 转移。我们将转移到的状态标记为 \(q\)

  • 如果 \(\operatorname{len}(p)+1=\operatorname{len}(q)\) ,我们只要将 \(\operatorname{link}(\textit{cur})\) 赋值为 \(q\) 并退出。(这样其实就是 \(q\) 这个节点所代表的状态均可以作为 \(cur\) 的后缀,可以证明一下)。

\(last \rightarrow T\) 后缀路径上的前一个状态为 \(p'\)。根据操作,可知 $$ 有一条转移边。则此时 \(\mathrm{minlen}(cur) = \mathrm{minlen}(p') + 1 = (\mathrm{len}(p) + 1) + 1 = \mathrm{len}(q) + 1\),说明 \(q\) 恰好与 \(cur\) 的后缀链接的定义相匹配。

  • 否则就会有些复杂。需要复制状态 \(q\) :我们创建一个新的状态 \(\textit{clone}\) ,复制 \(q\) 的除了 \(\operatorname{len}\) 的值以外的所有信息(后缀链接和转移)。我们将 \(\operatorname{len}(\textit{clone})\) 赋值为 \(\operatorname{len}(p)+1\)
    复制之后,我们将后缀链接从 \(\textit{cur}\) 指向 \(\textit{clone}\) ,也从 \(q\) 指向 \(\textit{clone}\)
    最终我们需要使用后缀链接从状态 \(p\) 往回走,只要存在一条通过 \(p\) 到状态 \(q\) 的转移,就将该转移重定向到状态 \(\textit{clone}\)
    \(minlen(clone)=minlen(q)\)\(len(clone)=len(p)+1\)
    \(minlen(q)=len(clone)+1\)

新建点的原因是因为加入 \(c\) 这个字符,使得 \(q\)\(endpos\) 集合变成两类,一类比另一类多一个 \(n\)。所以要新建一个节点将这个区分出来,所以一开始将 \(q\) 的状态大多数赋给 \(clone\),因为是由一个点拆出来的,需要将 \(len(u) \leq len(p)\) 的所有存在 \(u \rightarrow q\) 的转移边的全部改为指向 \(clone\),这样就可以将 \(endpos(q)\) 合法。

code
void insert(int c){
    int cur=++tot;
    len[cur]=len[last]+1;
    int p=last;
    while(p!=-1 && !ch[p][c]){
        ch[p][c]=cur;
        p=_link[p];
    }
    if(p==-1) _link[cur]=0;
    else{
        int q=ch[p][c];
        if(len[q]==len[p]+1) _link[cur]=q;
        else{
            int clone=++tot;
            len[clone]=len[p]+1,_link[clone]=_link[q];
            for(int i=0;i<26;i++) ch[clone][i]=ch[q][i];
            while(p!=-1 && ch[p][c]==q){
                ch[p][c]=clone;
                p=_link[p];
            }   
            _link[cur]=_link[q]=clone;
        }
    }
    last=cur;
}

应用

两个字符串的最长公共子串

对一个串建 \(SAM\),在另一个串上跑匹配,一位一位考虑,在 \(SAM\) 上维护一个指针 \(p\),当加入 \(i\) 时,如果存在 \(q=ch[p][c_i]\) 那么 \(mx_i=mx_{i-1}+1\),而不能为 \(len_q\),因为会存在 \(len_{q} > len_{p}+1\)。如果不存在,那我们跳 \(link_p\) 直到跳到 \(ch[p'][c_i]\) 存在,此时的 \(mx_i=len[p']+1\),这是因为 \(len[p'] \leq minlen[p] \leq mx[i-1]\),所以是合法的。

code

    int v=0,l=0;
    for(int i=1;i<=m;i++){
        int c=b[i]-'a';
        if(ch[v][c]){
            v=ch[v][c];
            l++;
        }
        else{
            int p=v;
            while(p!=-1 && !ch[p][c]) p=_link[p];
            if(p==-1) v=l=0;
            else{
                v=ch[p][c];
                l=len[p]+1;
            }
        }
        ans=max(ans,l);
    }

不同子串个数

方法一:其实就是求从 \(t_0\) 出发的所有路径,设 \(d_{v}\) 可以表示为所有 \(v\) 的转移的末端的和。则有转移 $$ d_{v}=1+\sum_{w:(v,w,c)\in DAWG}d_{w} $$。
答案为 \(d_{t_0}-1\)

方法二:\(ans=\sum \operatorname{len}(i)-\operatorname{len}(\operatorname{link}(i))\)

如果求子串可以重复,然后你要求解 \(d_{v}\) 那么就变成了 $$ d_{v}=endpos(v)+\sum_{w:(v,w,c)\in DAWG}d_{w} $$,相当于有那么多的位置可以选。

字典序第 k 大子串

其实就是从 \(t_0\) 出发找第 \(k\) 大的路径,所以求出 \(d_{v}\) 之后直接在树上找就可以。

例题

差异

两个前缀的公共后缀就是 \(link\) 树上 \(lca\) 的长度,所以建反串跑 \(dp\)

code
#include<bits/stdc++.h>
using namespace std;
const int N=2*5*1e5+5;
int ch[N][26],len[N],_link[N],tot,last;
char s[N];
void init(){
    len[0]=0,_link[0]=-1;
    tot=last=0;
}
int siz[N];
void insert(int c){
    int cur=++tot;
    len[cur]=len[last]+1;
    int p=last;
    siz[tot]=1;
    while(p!=-1 && !ch[p][c]){
        ch[p][c]=cur;
        p=_link[p];
    }
    if(p==-1) _link[cur]=0;
    else{
        int q=ch[p][c];
        if(len[p]+1==len[q]) _link[cur]=q;
        else{
            int clone=++tot;
            len[clone]=len[p]+1,_link[clone]=_link[q];
            for(int i=0;i<26;i++) ch[clone][i]=ch[q][i];
            while(p!=-1 && ch[p][c]==q){
                ch[p][c]=clone;
                p=_link[p];
            }
            _link[cur]=_link[q]=clone;
        }
    }
    last=cur;
}
int head[N*2],nex[N*2],ver[N*2],idx=0;
void add(int x,int y){
    ver[++idx]=y,nex[idx]=head[x],head[x]=idx;
}
long long ans=0;
void dfs(int x){
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        dfs(y);
        ans-=2ll*len[x]*siz[x]*siz[y];
        siz[x]+=siz[y];
    }
}
signed main(){
    scanf("%s",s+1);
    int n=strlen(s+1);
    init();
    for(int i=n;i>=1;i--) insert(s[i]-'a');
    for(int i=1;i<=tot;i++) add(_link[i],i);
    for(int i=1;i<=n;i++) ans+=1ll*i*(n-1);
    dfs(0);
    printf("%lld",ans);
}

熟悉的文章

首先这个有单调性,所以可以二分。然后对于判断可以进行一个 \(dp\),设 \(mx_i\) 表示以 \(i\) 结尾的最长匹配后缀,这个可以对作文库建广义 \(SAM\),然后对作文匹配。设 \(dp_i\) 表示以 \(i\) 结尾的作文最多可以匹配多少,那么有转移 \(dp_{i}=max(dp_{i-1},dp_j+i-j) (i-mx_i \leq j \leq i-L)\) 其中 \(L\) 为二分的,发现 \(i-mx_i\) 单调不减,具有决策单调性,形式很符合单调队列优化 \(dp\) 的尿性。

然后如何找最长匹配后缀。一位一位考虑,在 \(SAM\) 上维护一个指针 \(p\),当加入 \(i\) 时,如果存在 \(q=ch[p][c_i]\) 那么 \(mx_i=mx_{i-1}+1\),而不能为 \(len_q\),因为会存在 \(len_{q} > len_{p}+1\)。如果不存在,那我们跳 \(link_p\) 直到跳到 \(ch[p'][c_i]\) 存在,此时的 \(mx_i=len[p']+1\),这是因为 \(len[p'] \leq minlen[p] \leq mx[i-1]\),所以是合法的。

code
#include<bits/stdc++.h>
using namespace std;
const int N=2*1e6+10;
char a[N];
int tr[N][2],fa[N],str[N],idx=0;
int ch[N][2],len[N],_link[N],last[N],tot=0;
void init(){
    _link[0]=-1,tot=0;len[0]=0;
}
queue<int> q;
void ins(char *s){
    int n=strlen(s+1);
    int u=0;
    for(int i=1;i<=n;i++){
        int c=s[i]-'0';
        if(!tr[u][c]) tr[u][c]=++idx;
        fa[tr[u][c]]=u,str[tr[u][c]]=c;
        u=tr[u][c];
    }
}
int insert(int las,int c){
    int cur=++tot;
    len[cur]=len[las]+1;
    int p=las;
    while(p!=-1 && !ch[p][c]){
        ch[p][c]=cur;
        p=_link[p];
    }
    if(p==-1) _link[cur]=0;
    else{
        int q=ch[p][c];
        if(len[q]==len[p]+1) _link[cur]=q;
        else{
            int clone=++tot;
            len[clone]=len[p]+1,_link[clone]=_link[q];
            for(int i=0;i<=1;i++) ch[clone][i]=ch[q][i];
            while(p!=-1 && ch[p][c]==q){
                ch[p][c]=clone;
                p=_link[p];
            }
            _link[q]=_link[cur]=clone;
        }
    }
    return cur;
}
int mx[N];
int dp[N];
int n;
int st[N],sl,sr;
bool check(int L){
    memset(dp,0,sizeof(int)*(n+1));
    sl=1,sr=0;
    for(int i=1;i<=n;i++){
        if(i-L>=0){
            while(sl<=sr && dp[i-L]-(i-L)>dp[st[sr]]-st[sr]) sr--;
            st[++sr]=i-L;

        }
        while(st[sl]<i-mx[i]  && sl<=sr) sl++;
        dp[i]=dp[i-1];
        if(sl<=sr) dp[i]=max(dp[i-1],dp[st[sl]]-st[sl]+i);
    }
    if(dp[n]*10>=n*9) return 1;
    return 0;
}
signed main(){
    int Q,m;
    scanf("%d%d",&Q,&m);
    for(int i=1;i<=m;i++){
        scanf("%s",a+1);
        ins(a);
    }
    for(int i=0;i<=1;i++) if(tr[0][i]) q.push(tr[0][i]);
    init();
    while(!q.empty()){
        int x=q.front();
        q.pop();
        last[x]=insert(last[fa[x]],str[x]);
        for(int i=0;i<=1;i++) if(tr[x][i]) q.push(tr[x][i]);
    }
    int ans=N;
    for(int op=1;op<=Q;op++){
        scanf("%s",a+1);
        int p=0;
        n=strlen(a+1);
        for(int i=1;i<=n;i++){
            int c=a[i]-'0';
            if(ch[p][c]){// 注意 到达的点未必是 len[q]=len[p]+1
                mx[i]=mx[i-1]+1;
                p=ch[p][c];
            }
            else{
                while(p!=-1 && !ch[p][c]) p=_link[p];
                if(p==-1) p=0,mx[i]=0;
                else{
                    mx[i]=len[p]+1;
                    p=ch[p][c];
                }   
            }
        }
        int l=0,r=n;
        while(l<r){
            int mid=(l+r+1)/2;
            if(check(mid)) l=mid;
            else r=mid-1;
        }
        printf("%d\n",l);
    }
}

广义 SAM

就是求多个串的 \(SAM\) 问题。

首先建一课 \(trie\),然后 \(bfs\) 加点,加的过程中注意记录 \(last\)

code
void ins(char *s){
    int n=strlen(s+1);
    int u=0;
    for(int i=1;i<=n;i++){
        int c=s[i]-'a';
        if(!tr[u][c]) tr[u][c]=++idx;
        fa[tr[u][c]]=u,str[tr[u][c]]=c;
        u=tr[u][c];
    }
}
int _link[N],ch[N][26],len[N],last[N],tot;
void init(){
    _link[0]=-1;
    tot=0;
}
int insert(int las,int c){
    int cur=++tot;
    len[cur]=len[las]+1;
    int p=las;
    while(p!=-1 && !ch[p][c]){
        ch[p][c]=cur;
        p=_link[p];
    }
    if(p==-1) _link[cur]=0;
    else{
        int q=ch[p][c];
        if(len[q]==len[p]+1) _link[cur]=q;
        else{
            int clone=++tot;
            len[clone]=len[p]+1,_link[clone]=_link[q];
            for(int i=0;i<26;i++) ch[clone][i]=ch[q][i];
            while(p!=-1 && ch[p][c]==q){
                ch[p][c]=clone;
                p=_link[p];
            }
            _link[q]=_link[cur]=clone;
        }
    }
    return cur;
}

int main{
     for(int i=0;i<26;i++) if(tr[0][i]) q.push(tr[0][i]);
    init();
    while(!q.empty()){
        int x=q.front();
        q.pop();
        last[x]=insert(last[fa[x]],str[x]);
        for(int i=0;i<26;i++){
            if(tr[x][i]){
                q.push(tr[x][i]);
            }
        }
    }
}

多个串的 LCS

我们需要对每个节点建立一个长度为 \(k\) 的数组 \(flag\)

字典树插入的时候将被操作的点标记,然后可以发现在后缀树上,这个节点可以被它子树里面所有的标记所标记。所以 \(dfs\) 合并的时候假如找到一个点对于所有的标记 \(flat\) 均为 \(1\),则这个长度可以贡献答案。

code
#include<bits/stdc++.h>
using namespace std;
const int N=2*1e6+5;
char s[15][N];
queue<int> q;
bitset<20> v[N],w[N];
int tr[N][26],idx=0,fa[N],str[N];
void ins(char *s){
    int n=strlen(s+1);
    int u=0;
    for(int i=1;i<=n;i++){
        int c=s[i]-'a';
        if(!tr[u][c]) tr[u][c]=++idx;
        fa[tr[u][c]]=u,str[tr[u][c]]=c;
        u=tr[u][c];
    }
}
int _link[N],ch[N][26],len[N],last[N],tot;
void init(){
    _link[0]=-1;
    tot=0;
}
int insert(int las,int c){
    int cur=++tot;
    len[cur]=len[las]+1;
    int p=las;
    while(p!=-1 && !ch[p][c]){
        ch[p][c]=cur;
        p=_link[p];
    }
    if(p==-1) _link[cur]=0;
    else{
        int q=ch[p][c];
        if(len[q]==len[p]+1) _link[cur]=q;
        else{
            int clone=++tot;
            len[clone]=len[p]+1,_link[clone]=_link[q];
            for(int i=0;i<26;i++) ch[clone][i]=ch[q][i];
            while(p!=-1 && ch[p][c]==q){
                ch[p][c]=clone;
                p=_link[p];
            }
            _link[q]=_link[cur]=clone;
        }
    }
    return cur;
}
int head[N*2],ver[N*2],nex[N*2],tot_1=0;
void add(int x,int y){
    ver[++tot_1]=y,nex[tot_1]=head[x],head[x]=tot_1;
}
int ans=0;
int ps=0;
void dfs(int x){
    for(int i=head[x];i;i=nex[i]){
        int y=ver[i];
        dfs(y);
        v[x]|=v[y];
    }
    if(v[x].count()==ps) ans=max(ans,len[x]);
}
signed main(){
    while(scanf("%s",s[++ps]+1)!=EOF){
        ins(s[ps]);
    }
    ps--;
    for(int i=0;i<26;i++) if(tr[0][i]) q.push(tr[0][i]);
    init();
    while(!q.empty()){
        int x=q.front();
        q.pop();
        last[x]=insert(last[fa[x]],str[x]);
        for(int i=0;i<26;i++){
            if(tr[x][i]){
                q.push(tr[x][i]);
            }
        }
    }
    for(int q=1;q<=ps;q++){
        int u=0;
        int n=strlen(s[q]+1);
        v[u]|=(1<<q);
        for(int i=1;i<=n;i++){
            int c=s[q][i]-'a';
            u=ch[u][c];
            v[u]|=(1<<q);
        }
    }
    for(int i=1;i<=tot;i++){
        add(_link[i],i);
    }
    dfs(0);
    printf("%d",ans);
}