Codeforces 1817F - Entangled Substrings(SA)

发布时间 2023-05-06 20:55:53作者: tzc_wk

为什么赛时不开串串题?为什么赛时不开串串题?为什么赛时不开串串题?为什么赛时不开串串题?为什么赛时不开串串题?

一种 SA 做法,本质上和 SAM 做法等价,但是说来也丢人,一般要用到 SAM 的题我都是拿 SA 过的/wul

考虑将 \(ac\) 看作一个整体。记 \(\text{occ}(S)\)\(S\) 出现位置的集合,那么根据题目条件,对于一组符合条件的字符串 \((a,b,c)\)(显然 \(ab\) 确定之后 \(c\) 唯一),\(\text{occ}(ac)\)\(\text{occ}(b)\) 的差分数组是相同的。而对于 \(S\) 的每一个子串 \(s\)\(\text{occ}(s)\) 最多只有 \(O(n)\) 种本质不同的集合。考虑求出这 \(n\) 种本质不同的集合的差分数组,具体方法是,并查集维护若干个集合,类似于 NOI2015 品酒大会,从大到小枚举 \(k\),对于每个 \(ht_i=k\)\(i\) 合并 \(i\)\(i-1\) 所在集合,同时取一个质数 \(p\) 作为哈希模数和一个哈希底数 \(B\),在集合内部维护 \(H=\sum_{x}B^{sa_x}\) 以及 \(M=\min_{x}sa_x\),这样 \(H·x^{-M}\) 就是这个集合对应差分数组的哈希值。

然后对于每一类哈希值,我们考虑其中的所有等价字符串,显然我们可以将它们分为若干类,每一类表示为,以 \(x\) 开始长度为 \([l,r]\) 的子串(如果有多个 \(x\),我们取最小的那个,具体实现见代码),显然类中的字符串不可能产生贡献,考虑两类字符串 \((x_1,l_1,r_1),(x_2,l_2,r_2)\),它们能够产生多少对合法的 \(a,b,c\),不妨设 \(x_1<x_2\),那么显然我们要选择一个 \(i\in[l_1,r_1],j\in[l_2,r_2]\) 使得 \(ac=s[x_1...x_1+i-1],b=s[x_2...x_2+j-1]\),为了让 \(ac\) 能够接上 \(b\),显然 \(i\) 必须要等于 \(x_2-x_1\),而 \(j\) 可以任意取,选择了 \(i\) 以后,\(a\)\(i-l_1+1\) 种可能,也就是说两类字符串 \((x_1,l_1,r_1),(x_2,l_2,r_2)\) 对答案的贡献为:

\[\begin{cases} 0&(x_2-x_1\notin[l,r])\\ (r_2-l_2+1)·(x_2-x_1-l_1+1)&(x_2-x_1\in[l,r]) \end{cases} \]

二分前缀和即可。这是 trivial 的。

(wjz 锐评这个题是板子题,其实上面很多部分都是 SAM 模板,如果用 SAM 写的可能还真是板子/cy)

时间复杂度 \(O(n\log n)\)

const int MAXN=2e5;
const int BS1=3;
const int BS2=5;
const int MOD1=1004535809;
const int MOD2=924844033;
int qpow(int x,int e,int mod){int ret=1;for(;e;e>>=1,x=1ll*x*x%mod)if(e&1)ret=1ll*ret*x%mod;return ret;}
const int IV1=qpow(BS1,MOD1-2,MOD1);
const int IV2=qpow(BS2,MOD2-2,MOD2);
char s[MAXN+5];int n,sa[MAXN+5],rk[MAXN+5],ht[MAXN+5],buc[MAXN+5],seq[MAXN+5];
void getsa(){
	int vmax=122,gr=0;
	for(int i=1;i<=n;i++)buc[s[i]]++;
	for(int i=1;i<=vmax;i++)buc[i]+=buc[i-1];
	for(int i=n;i;i--)sa[buc[s[i]]--]=i;
	for(int i=1;i<=n;i++){
		if(s[sa[i]]!=s[sa[i-1]])++gr;
		rk[sa[i]]=gr;
	}vmax=gr;
	for(int k=1;k<=n;k<<=1){
		static pii x[MAXN+5];
		for(int i=1;i<=n;i++){
			if(i+k<=n)x[i]=mp(rk[i],rk[i+k]);
			else x[i]=mp(rk[i],0);
		}memset(buc,0,sizeof(buc));
		static int seq[MAXN+5];int num=0;gr=0;
		for(int i=n-k+1;i<=n;i++)seq[++num]=i;
		for(int i=1;i<=n;i++)if(sa[i]>k)seq[++num]=sa[i]-k;
		for(int i=1;i<=n;i++)buc[x[i].fi]++;
		for(int i=1;i<=vmax;i++)buc[i]+=buc[i-1];
		for(int i=n;i;i--)sa[buc[x[seq[i]].fi]--]=seq[i];
		for(int i=1;i<=n;i++){
			if(x[sa[i]]!=x[sa[i-1]])++gr;
			rk[sa[i]]=gr;
		}vmax=gr;if(vmax==n)break;
	}
}
void getht(){
	int k=0;
	for(int i=1;i<=n;i++){
		if(rk[i]==1)continue;if(k)--k;int j=sa[rk[i]-1];
		while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k])++k;
		ht[rk[i]]=k;
	}
}
vector<int>vec[MAXN+5];
int pw1[MAXN+5],pw2[MAXN+5],ipw1[MAXN+5],ipw2[MAXN+5];
int f[MAXN+5],L[MAXN+5],R[MAXN+5],hs1[MAXN+5],hs2[MAXN+5],mn[MAXN+5];
int find(int x){return (!f[x])?x:f[x]=find(f[x]);}
map<pii,vector<tuple<int,int,int> > >strs;
void ins(pii hs,int x,int l,int r){if(l<=r)strs[hs].pb(mt(x,l,r));}
int main(){
	scanf("%s",s+1);n=strlen(s+1);getsa();getht();
	for(int i=2;i<=n;i++)vec[ht[i]].pb(i);
	for(int i=(pw1[0]=pw2[0]=ipw1[0]=ipw2[0]=1);i<=n;i++){
		pw1[i]=1ll*pw1[i-1]*BS1%MOD1;ipw1[i]=1ll*ipw1[i-1]*IV1%MOD1;
		pw2[i]=1ll*pw2[i-1]*BS2%MOD2;ipw2[i]=1ll*ipw2[i-1]*IV2%MOD2;
	}
	for(int i=1;i<=n;i++){L[i]=R[i]=i;mn[i]=sa[i];hs1[i]=pw1[sa[i]];hs2[i]=pw2[sa[i]];}
	for(int i=1;i<=n;i++)ins(mp(1,1),sa[i],max(ht[i],ht[i+1])+1,n-sa[i]+1);
	for(int i=n;i;i--)for(int id:vec[i]){
		int fu=find(id-1),fv=find(id);f[fu]=fv;L[fv]=L[fu];
		hs1[fv]=(hs1[fv]+hs1[fu])%MOD1;hs2[fv]=(hs2[fv]+hs2[fu])%MOD2;
		chkmin(mn[fv],mn[fu]);
		int HS1=1ll*hs1[fv]*ipw1[mn[fv]]%MOD1;
		int HS2=1ll*hs2[fv]*ipw2[mn[fv]]%MOD2;
		ins(mp(HS1,HS2),mn[fv],max(ht[L[fv]],ht[R[fv]+1])+1,i);
	}
	ll res=0;
	for(auto tt:strs){
		vector<tuple<int,int,int> >v=tt.se;sort(v.begin(),v.end());
		vector<ll>s0(v.size()),s1(v.size());
		for(int i=0;i<v.size();i++){
			s0[i]=((!i)?0:s0[i-1])+get<2>(v[i])-get<1>(v[i])+1;
			s1[i]=((!i)?0:s1[i-1])+1ll*(get<2>(v[i])-get<1>(v[i])+1)*get<0>(v[i]);
		}
//		printf("solving:\n");
//		for(tuple<int,int,int>t:v)printf("(%d,%d,%d)\n",get<0>(t),get<1>(t),get<2>(t));
		for(int i=0;i<v.size();i++){
			int x=get<0>(v[i]),L=get<1>(v[i]),R=get<2>(v[i]);
			int l=i+1,r=(int)(v.size())-1,pl=-1,pr=-1;
			while(l<=r){
				int mid=l+r>>1;
				if(get<0>(v[mid])>=x+L)pl=mid,r=mid-1;
				else l=mid+1;
			}
			l=i+1,r=(int)(v.size())-1;
			while(l<=r){
				int mid=l+r>>1;
				if(get<0>(v[mid])<=x+R)pr=mid,l=mid+1;
				else r=mid-1;
			}
			if(pl<=pr&&pl!=-1&&pr!=-1){
				res+=(s1[pr]-((!pl)?0:s1[pl-1]));
				res-=(s0[pr]-((!pl)?0:s0[pl-1]))*(x+L-1);
			}
		}
	}printf("%lld\n",res);
	return 0;
}
/*
aabaabaaabbaabbbaabaabaaabbaabbbaabaabaaabbaabbbaabaabaaabbaabbb
*/
/*
abcdedcba
*/