洛谷P4287 [SHOI2011]双倍回文

发布时间 2023-04-27 21:47:52作者: 文艺平衡树

题目

洛谷P4287 [SHOI2011]双倍回文

思路

回文子串题,马拉车感觉不太好做,那就把回文自动机建出来看看。

好的现在我们有了一个\(PAM\),这个\(PAM\)上储存了所有普通回文子串的信息,然后我们考虑所谓“双倍回文子串”和普通回文串有啥关系。

首先双倍回文子串一定是一个回文串,所以是存在\(PAM\)之内的。然后根据定义,它的长度一定是\(4\)的倍数,所以连长度都不满足的串就不用做进一步的判断了。

然后只要它的前一半也是一个回文串,我们就可以断定它是一个“双倍回文子串”了。

根据这个思路,我们直接检查\(PAM\)里的所有节点就可以了。

现在的瓶颈就是:对于一个结点\(P\),它代表了一个长度是\(4\)的倍数的回文串\(A\),我们怎么去取出它的前一半。

我们可以在每次插入字母时,将这个字母在\(s\)中的位置记录到它所对应的节点里(注意到一个节点虽然只对应一个本质不同的回文串,但是它的位置可以不唯一,所以我们要开个vector,我选择记录右端点)。

只需要拿这个位置减去\(|A|/2\),就得到了前半个串的右端点(如何找到这个点在\(PAM\)中的对应节点?我们只要在insert的时候加一个返回值就可以了)。

方便起见,假设这个右端点的对应节点为Q。

只要\(Q.len==|A|/2\),这个串就合法,对答案有贡献,看起来非常正确。

于是我们写完了交上去,发现错了一个点(是的我就是这么写的,然后debug了半天qwq)。

问题就在于这个条件是充分不必要条件,\(Q\)对应的串很有可能本身是一个更长的回文(设为串\(B\))。

举个例子,\(s="abbaccabbaabba",A="abbaabba",B="abbaccabba"\),这个时候\(Q.len\)等于10,上述程序就会错误地认为\(A\)不合法。

所以我们需要沿fail指针接着跳,只要\(B\)的一个后缀满足长度为\(|A|/2\)就可行。

排除了这个坑点,这道题就做完了。

代码

点击查看代码
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<vector>
#define ll long long
#define maxn 500010
using namespace std;
struct PAM{
	int tot,last;
	char *s;
	struct node{
		int len,fail;
		int ch[27];
		vector<int> right;
	} p[maxn];
	void init(char *t){
		tot=last=1;
		p[0].fail=1;
		p[1].len=-1;
		s=t;
	}
	int getfail(int pos,int tmp){
		while(tmp-p[pos].len-1<0||s[tmp-p[pos].len-1]!=s[tmp]) pos=p[pos].fail;
		return pos;
	}
	int insert(int x,int tmp){
		int cur=getfail(last,tmp);
		int now=p[cur].ch[x];
		if(!now){
			now=++tot;
			p[now].len=p[cur].len+2;
			p[now].fail=p[getfail(p[cur].fail,tmp)].ch[x];
			p[cur].ch[x]=now;
		}
		last=now;
		p[now].right.push_back(tmp);
		return now;
	}
} P;
int id[maxn];
int check(int pos){
	int i;
	PAM::node t=P.p[pos];
	if(t.len%4) return 0;
	for(i=0;i<t.right.size();++i){
		int q=t.right[i]-t.len/2;
		int z=id[q];
		while(P.p[z].len>t.len/2) z=P.p[z].fail;
		if(P.p[z].len==t.len/2) return 1;
	}
	return 0;
}
int main(){
	int n,m,i,j;
	char s[maxn];
	int ans=0;
	// freopen("test.in","r",stdin);
	// freopen("test.out","w",stdout);
	scanf("%d%s",&n,s);
	P.init(s);
	for(i=0;i<n;++i){
		id[i]=P.insert(s[i]-'a'+1,i);
	}
	for(i=2;i<=P.tot;++i){
		if(check(i)) ans=max(ans,P.p[i].len);
	}
	printf("%d",ans);
	return 0;
}