manacher算法

发布时间 2023-10-23 16:16:57作者: Aisaka_Taiga

manacher算法

斯♥哈♥学长的博客https://www.cnblogs.com/luckyblock/p/17044694.html#5140558

为什么老师叫他马拉车算法/yiw

简介

我们都知道,求最长回文子串可以枚举每一个开始的点,然后直接一个一个比较就完事,但这样的复杂度是接近 \(O(n^{2})\) 的,不出意外你就会image,那么我们这个时候就需要用到 manacher 算法了,他和 KMP 一样都是基于暴力的思想改进的一种算法,利用了已求得的回文半径加速了比较的过程,使复杂度降到了 \(O(n)\) 级别。

过程

上面已经提到 manacher 算法是解决最长回文子串的问题的。

回文串可以分为两种类型:一种是长度为奇数的回文串,比如 \(12321\),另一种是长度为偶数的回文串,比如: \(123321\),这个时候我们要是想分类讨论的话也太麻烦了,所以我们就开始用一些奇♂技♂淫♂巧:我们在每一个串的首尾和每一个字符之间插♂入一个字符串,就像下面这样:\(!1!2!3!2!1!\)\(!1!2!3!3!2!1!\),这个时候能发现,这俩串都成了长度为奇数的回文串了,那么我们就可以不用分类讨论了。

我们定义 \(p[i]\) 为以 \(i\) 为中心的回文串的最长半径,比如串 \(\%a\%b\%a\%\),其中 \(b\) 为第四个字符,则 \(p[4]=4\)(因为以他为中心的最长回文串是 \(\%a\%b\%a\%\),而这个回文串的半径为 \(4\))。所以我们在原字符串中的这个回文字串长度就是 \(p[i]-1\),为什么嘞?因为我们原字符串中是这样的 \(aba\),设上面修改过的串的长度为 \(n\),则原字符串的长度为 \(\frac{n-1}{2}\),而我们的半径刚好是 \(\frac{n+1}{2}\),半径减去原字符串长度则为 \(\frac{n+1-n+1}{2}\),也就是 \(1\),所以我们就可以知道半径减去 \(1\),就是原字符串的长度啦。(你问我为什么给你证这个?当然是后面的我不会证)接下来我们就要来快速求出 \(p[i]\) 数组了,我们假设 \(pos\) 是一个已经记录的值,R$ 为以 \(pos\) 为中心的回文串的最长右边界,然后现在要求 \(p[i]\)\(j\)\(i\) 关于 \(pos\) 的对称点,也就是说 \(j=2\times pos-i\)

这个时候就分两种情况了,第一种情况就是 \(j\) 的回文串没有跳出 \(L\),也就是这样:

image

因为 \(L\)\(R\) 是一个回文串,所以 \(i\) 的回文和 \(j\) 的回文相同,即 \(p[i]=p[j]\).

第二种情况就是 \(j\) 的最长回文越过了 \(L\),也就是下面这样:

image

如果在这种情况下,\(j\) 的最长回文就不一定是 \(i\) 的最长回文了。然后黄色那块肯定还是回文。

所以综上,\(p[i]=min(p[2\times pos-1],R-i)\)

如果可以的话,继续暴力拓展即可。

复杂度不会证,看上面学长的博客吧。

P3805 【模板】manacher 算法

#include<bits/stdc++.h>
#define int long long
#define N 20001000
using namespace std;
int p[N<<1],ans,cnt;
char s[N<<1];
signed main()
{
	char c=getchar();
	s[0]='~';s[++cnt]='%';
	while(c<'a'||c>'z')c=getchar();
	while(c>='a'&&c<='z')s[++cnt]=c,s[++cnt]='%',c=getchar();
	int r=0,mid=0;
	for(int t=1;t<=cnt;t++)
	{
		if(t<=r)p[t]=min(p[mid*2-t],r-t+1);//如果当前的中心是小于r的话,取两头的min 
		while(s[t-p[t]]==s[t+p[t]])++p[t];//向两边扩展 
		if(p[t]+t>r)r=p[t]+t-1,mid=t;//如果超界就要更新r和mid的值 
		ans=max(p[t],ans);//如果要是当前的半径比答案大就替换 
	}
	cout<<ans-1<<endl;
	return 0;
}

P4555 [国家集训队]最长双回文串

打开标签发现有 manacher 算法所以开始考虑 manacher 算法

但是当我求出 \(p\) 数组的时候,我沉默了,两个的话不好找,所以考虑维护其他的东西。

既然我们是要求两相邻的回文串的最大长度,那么我们可以考虑维护一下以一个点为开始的回文串最大长度和以一个点为结尾的回文串最大长度,但是这就又有一个问题,就是当前点的字符是被算入了两个串里,考虑一下,我们在一开始插入了一些特殊字符,我们只要枚举这些特殊字符的不就好了吗!因为后面删去是不影响的。然后就是该如何维护,首先我们可以在求 \(p\) 的时候顺手维护一部分,因为当 \(i\) 为中点时,左端点为 \(i-p[i]+1\),右端点为 \(i+p[i]-1\),那么我们就可以先对这个求一个max,将其存放到 \(l\)\(r\) 两个数组里面去,但是由于求 \(p\) 的时候都是最大的,所以导致一些短的回文串的开头结尾的数组没有存入,所以我们需要再推一遍,首先 \(r\) 是要倒着枚举的,因为是逐渐向左移,然后,当前的 \(r\) 数组有两种选择,一种是他的右一个字符的最大回文长度减一,因为中间插入了字符,在循环的时候是 \(i+=2\) 的,然后减去一也变成了减去二,下标也是 \(i+2\),最后得出: \(r[i]=\max(r[i],r[i+2]-2)\)\(l\) 数组同理可推得:\(l[i]=\max(l[i],l[i-2]-2)\)。然后最后遍历一遍所有的特殊字符对于 \(l[i]+r[i]\) 取一个 \(\max\) 就好啦!

#include<bits/stdc++.h>
#define int long long
#define N 1000100
using namespace std;
int n,p[N<<1],cnt,l[N<<1],r[N<<1],ans;
char s[N<<1];
signed main()
{
	char c=getchar();
	s[cnt]='~',s[++cnt]='%';
	while(c<'a'||c>'z')c=getchar();
	while(c>='a'&&c<='z')s[++cnt]=c,s[++cnt]='%',c=getchar();
	s[cnt+1]=')';
	int j=0,mid=0;
	for(int i=1;i<=cnt;i++)
	{
		if(i<=j)p[i]=min(p[2*mid-i],j-i+1);
		while(s[i-p[i]]==s[i+p[i]])p[i]++;
		if(p[i]+i>j)j=p[i]+i-1,mid=i;
		l[i-p[i]+1]=max(l[i-p[i]+1],p[i]-1);
		r[i+p[i]-1]=max(r[i+p[i]-1],p[i]-1);
	}
//	cout<<"cao"<<endl;
	for(int i=cnt;i>=1;i-=2)r[i]=max(r[i],r[i+2]-2);
	for(int i=1;i<=cnt;i+=2)l[i]=max(l[i],l[i-2]-2);
	for(int i=1;i<=cnt;i+=2)if(l[i]&&r[i])ans=max(l[i]+r[i],ans);
	cout<<ans<<endl;
	return 0;
} 

P1659 [国家集训队]拉拉队排练

既然是求解回文串的问题,首先可以考虑一下 manacher 算法。

我发现题解中有的是直接在原串上跑 manacher,这样可以直接过滤掉长度为偶数的回文串,如果是要在修改后的串上跑 manacher 的话,我们在统计答案的时候可以直接特判一下,把是以原串里的字符为中心的回文串最大长度给算上,这样在原串里也是长度为奇数的回文串了。

对于排序的问题,我们可以开一个桶 \(t\) 数组来记录一下当前长度的回文串的数量,很容易发现,我们的 \(p[i]\) 里面存放的是以 \(i\) 为中心最大的回文半径,所以我们知道,像 \(abcdedcba\) 这种回文串,不仅只有这一个最长的,当去掉两头各一个字符时,就可以得到了另一个回文串 \(bcdedcb\)

现在考虑如果处理以上的情况。如果是在往桶里面存放的时候用一个循环来一直减二修改的话,像一些以当前点为回文中心的字符串很长的情况就会 T 掉了,所以我们要用到类似于线段树的懒标记的方法,我们在将当前的回文串长度也就是 \(p[i]-1\) 放入桶中的时候,先不要急着给后面的也放入,等到了查询的时候,当前的 \(t[p[i]-1]\) 计算完了以后,我们再给 \(t[p[i]-1-2]\) 加上 \(t[p[i]-1]\) 的值,就相当于长度为 \(p[i]-1\) 的回文串全都去掉了两头各一个字符,变为了 长度为 \(p[i]-1-2\) 的回文串,这样可以使我们的复杂度大大降低,跑的也很快。

\(k\) 的范围较大,要用快速幂,同时注意取模和判无解。

#include<bits/stdc++.h>
#define int long long
#define P 19930726
#define N 1000010
using namespace std;
int n,k,p[N<<1],t[N<<1],cnt,maxn;
char s[N<<1];
inline int ksm(int x,int y)
{
	int sum=1;
	while(y)
	{
		if(y%2==1)
		  sum=(sum*x)%P;
		x=(x*x)%P;
		y=y>>1;
	}
	return sum%P;
}
signed main()
{
	cin>>n>>k;
	char c=getchar();
	s[cnt]='!';s[++cnt]='%';
	while(c<'a'||c>'z')c=getchar();
	while(c>='a'&&c<='z')s[++cnt]=c,s[++cnt]='%',c=getchar();
	s[cnt+1]='~';
	int j=0,mid=0;
	for(int i=1;i<=cnt;i++)
	{
		if(i<=j)p[i]=min(p[2*mid-i],j-i+1);
		while(s[i+p[i]]==s[i-p[i]])p[i]++;
		if(p[i]+i>j)j=p[i]+i-1,mid=i;
		if(s[i]!='%')
		{
			int xx=p[i]-1;
			t[xx]++;
			maxn=max(maxn,p[i]-1);
		}
	}
//	for(int i=1;i<=maxn;i++)
//	  cout<<t[i]<<" ";
//	cout<<endl;
	int ans=1,sum=0;
	for(int i=maxn;i>=1;i--)
	{
		if(sum+t[i]>k)
		{
			ans=(ans*ksm(i,k-sum)%P);
			sum+=t[i];
			break;
		}
		sum+=t[i];
		t[i-2]+=t[i];
		ans=(ans*ksm(i,t[i])%P);
	}
	if(sum<k)
	{
		cout<<"-1"<<endl;
		return 0;
	}
	cout<<ans<<endl;
	return 0;
}

P9606 [CERC2019] ABB

不难发现最坏情况就是把前 \(n-1\) 个字符反转一下复制到后面。

发现后缀的回文串长度会使得添加的字符个数减少,例如样例三中 murderforajarof 实际上最后是在后面复制了 redrum,也就是说我们找到一个最右端为 \(n\) 的最长回文串,\(n\) 减去这个串的长度即为最优方案。

考虑如何计算。

求最大回文想到 manacher 算法,发现求出 \(p\) 数组后,若当前的右端点为最后一个字符的位置就可以计算,也就是当 \(p_i + i - 1 =m\) 时,答案为 \(n - p_i + 1\)


/*
 * @Author: Aisaka_Taiga
 * @Date: 2023-10-23 15:12:00
 * @LastEditTime: 2023-10-23 16:08:55
 * @LastEditors: Aisaka_Taiga
 * @FilePath: \Desktop\P9606.cpp
 * The heart is higher than the sky, and life is thinner than paper.
 */
#include <bits/stdc++.h>

#define int long long
#define N 1000100

using namespace std;

int n, ans, p[N], m;
char s[N], ss[N];

signed main()
{
    cin >> n;
    for(int i = 1; i <= n; i ++) cin >> ss[i];
    s[0] = '~', s[++ m] = '%';
    for(int i = 1; i <= n; i ++)
    {
        s[++ m] = ss[i];
        s[++ m] = '%';
    }
    int r = 0, mid = 0;
    for(int i = 1; i <= m; i ++)
    {
        if(i <= r) p[i] = min(p[2 * mid - i], r - i + 1);
        while(s[i + p[i]] == s[i - p[i]]) p[i] ++;
        if(p[i] + i > r) r = p[i] + i - 1, mid = i;
    }
    ans = n - 1;
    // for(int i = 1; i <= m; i ++) if(s[i] != '%') cout << (p[i] / 2) << " "; cout << endl;
    for(int i = 1; i <= m; i ++)
        if(p[i] + i - 1 == m)
            ans = min(ans, n - p[i] + 1);
    cout << ans << endl;
    return 0;
}