SA&SAM 不怎么详细的详解

发布时间 2023-07-07 17:09:27作者: by_chance

后缀数组(SA):将一个字符串的所有后缀排序得到的数组。
算法:倍增+双关键字基数排序。
算法流程:

  • 首先对所有字符排序,记下每个位置的排名。
  • 将相邻两个字符看作一个整体,用他们的两个排名分别作为两个关键字排序。
  • 将相邻两个“两个字符”看作一个整体,用他们的两个排名分别作为两个关键字排序。
  • ……

Code:(洛谷模板P3809)

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5;
int n,m,x[N],y[N],c[N],sa[N];
char s[N];
int main(){
	scanf("%s",s+1);
	n=strlen(s+1);m=300;//m表示值域,开到ASCLL码范围之外肯定可以
	for(int i=1;i<=n;i++){x[i]=s[i];++c[x[i]];}
	for(int i=2;i<=m;i++)c[i]+=c[i-1];
	for(int i=n;i>=1;i--)sa[c[x[i]]--]=i;
	//这三行表示第一次基数排序
	//再理解一下基数排序的过程:第一个循环记下每个第一关键字的数目,
	//第二个循环求前缀和,第三个循环记录排名
	for(int k=1;k<=n;k<<=1){
		int num=0;
		for(int i=n-k+1;i<=n;i++)y[++num]=i;
		for(int i=1;i<=n;i++)if(sa[i]>k)y[++num]=sa[i]-k;
		//取出第二关键字,后半段直接当作极小值,前半段用上一次的sa更新
		for(int i=1;i<=m;i++)c[i]=0;
		for(int i=1;i<=n;i++)c[x[i]]++;
		for(int i=2;i<=m;i++)c[i]+=c[i-1];
		for(int i=n;i>=1;i--){sa[c[x[y[i]]]--]=y[i];y[i]=0;}
		//这三行是基数排序,和前面类似
		swap(x,y);num=1;x[sa[1]]=1;
		for(int i=2;i<=n;i++){
			if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])
				x[sa[i]]=num;//如果两个关键字都相同,那么排名也相同
			else x[sa[i]]=++num;
		}
		//求下一次的第一关键字
		if(num==n)break;m=num;
	}
	for(int i=1;i<=n;i++)printf("%d ",sa[i]);
	return 0;
}

应用:
一个小套路是字符串复制一遍接在后面。
更主要的应用需要联系 \(height\) 数组:\(height[i]=LCP(sa[i-1],sa[i])\),这里 \(sa\) 数组理解为字符串,LCP为最长公共前缀。
\(height\) 数组用到引理:\(height[rk[i]] \ge height[rk[i-1]]-1\),然后暴力即可。
Code:

void LCP(){
	int k=0;
	for(int i=1;i<=n;i++)rk[sa[i]]=i;//rk[i]表示i的排名
	for(int i=1;i<=n;i++){
		if(rk[i]==1)continue;if(k)--k;//至多减掉1,然后暴力向后找
		int j=sa[rk[i]-1];
		while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])++k;
		h[rk[i]]=k;
	}
}

后缀自动机(SAM):压缩一个字符串所有子串的一个自动机。

对任一结点,以它为终点的路径会构成一个子串,并且这些子串是其中最长那个的若干个连续后缀(例:aaaaba,aaaba,aaba,aba)。每个结点记:

  • maxlen,minlen:表示该点代表的最长字符串长和最短字符串长。
  • trans[]:转移函数。
  • slink:后缀链接。指向比这个点的最短字符串恰好短1的字符串(去掉开头,在上面的例子里就是ba)所在的结点。
  • endpos:该结点的所有字符串在原字符串中的出现位置的集合(记结尾)。可以理解为结点就是按照endpos拆分的。但是记集合肯定炸掉,所以实际上记的是集合大小。
  • pre:该点的最长字符串是否为原字符串的一个前缀。

构造算法:增量法。

  • 先给一个新点 \(z\),从代表整个(已完成的)字符串的点 \(u\) 开始,按照slink往前跳。如果相应的转移函数 \(trans[u][c]\) 为空,就指向加入的新点,然后接着跳。跳到自动机的起点,或不再满足条件时停止。如果跳到起点,那么新点 \(z\) 的minlen为1,slink指向起点。
  • 如果没有跳到起点,在停下的位置 \(v\) 检查 该点的对应转移函数 \(trans[v][c]\) 指向的结点 \(x\) 的最长字符串 是否为 停下结点 \(v\) 的最长字符串加上转移函数的字符(即长度是否恰好多1)。如果是,直接将 \(z\) 的slink指向 \(x\) 即可。
  • 剩下的情况,需要在 \(x\) 中拆出一部分 \(y\),让 \(v\) 指向 \(y\)。同时,让 \(v\) 沿着slink向前跳,满足 \(trans[v][c]==x\) 的都改为指向 \(y\)。至于 \(y\) 的参数,maxlen[y]=maxlen[v]+1,trans[y]=trans[x],slink[y]=slink[x],minlen[y]=maxlen[slink[y]]+1。以及minlen[x]=minlen[z]=maxlen[y]+1,slink[x]=slink[z]=slink[y]

下面来求endpos。显然slink会构成一棵树,而一个点 \(u\) 的endpos肯定包含了所有slink指向 \(u\) 的点的endpos。这可能不完全,但是会发现:只有当 \(pre[u]=1\) 时还需要再加一,否则不需要。那么可以构造完SAM后,自底向上做拓扑排序,递推即可。

Code:(洛谷模板P3804)

#include<bits/stdc++.h>
using namespace std;
const int N=2e6+5;
char s[N];
int n,m,st,ans[N];long long res;
int deg[N];queue<int> Q;
int maxlen[N],minlen[N],trans[N][30],slink[N],pre[N],endpos[N];
int new_state(int _maxlen,int _minlen,int*_trans,int _slink,int _pre){
    maxlen[m]=_maxlen;minlen[m]=_minlen;
    for(int i=1;i<=26;i++){
        if(_trans==NULL)trans[m][i]=-1;
        else trans[m][i]=_trans[i];
	}
    slink[m]=_slink;pre[m]=_pre;
    return m++;
}
int add_char(char ch,int u){
	int c=ch-'a'+1,z=new_state(maxlen[u]+1,-1,NULL,-1,1),v=u;
	while(v!=-1&&trans[v][c]==-1){trans[v][c]=z;v=slink[v];}
	if(v==-1){minlen[z]=1;slink[z]=0;return z;}
	int x=trans[v][c];
	if(maxlen[v]+1==maxlen[x]){minlen[z]=maxlen[x]+1;slink[z]=x;return z;}
	int y=new_state(maxlen[v]+1,-1,trans[x],slink[x],0);
	minlen[x]=maxlen[y]+1;slink[x]=y;
	minlen[z]=maxlen[y]+1;slink[z]=y;
	while(v!=-1&&trans[v][c]==x){trans[v][c]=y;v=slink[v];}
	minlen[y]=maxlen[slink[y]]+1;
	return z;
}
int main(){
	scanf("%s",s+1);n=strlen(s+1);
	st=new_state(0,-1,NULL,-1,1);
	for(int i=1;i<=n;i++)st=add_char(s[i],st);--m;
	for(int i=1;i<=m;i++)deg[slink[i]]++;
	for(int i=1;i<=m;i++)if(!deg[i])Q.push(i);
	while(!Q.empty()){
		int u=Q.front();Q.pop();
		if(pre[u])endpos[u]++;
		endpos[slink[u]]+=endpos[u];
		deg[slink[u]]--;
		if(!deg[slink[u]])Q.push(slink[u]);
	}
	for(int i=1;i<=m;i++)ans[maxlen[i]]=max(ans[maxlen[i]],endpos[i]);
	for(int i=n-1;i>=1;i--){
		ans[i]=max(ans[i],ans[i+1]);
		if(ans[i]>=2)res=max(res,1ll*ans[i]*i);
	}
	printf("%lld\n",res);
	return 0;
}