CF1827C Palindrome Partition 题解

发布时间 2023-11-29 18:19:32作者: Farmer_D

题目链接

点击打开链接

题目解法

首先考虑一个朴素的 \(dp\)
\(f_i\) 表示以 \(i\) 结尾的合法子串的个数
为了不重不漏,我们令 \(le_i\) 表示以 \(i\) 为右端点,离 \(i\) 最近的偶回文串的左端点,然后不难得到转移为 \(f_i=f_{le_i-1}+1\)
为什么不会漏?考虑以 \(i\) 为右端点,且比最短回文串长的子串一定可以拆分成若干个回文串

\(le_i\) 可以二分哈希求出最长的回文串 \([i-l,i+l)\),然后线段树覆盖即可维护
时间复杂度 \(O(n\log n)\)

#include <bits/stdc++.h>
#define F(i,x,y) for(int i=(x);i<=(y);i++)
#define DF(i,x,y) for(int i=(x);i>=(y);i--)
#define ms(x,y) memset(x,y,sizeof(x))
#define SZ(x) (int)x.size()-1
#define pb push_back
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
typedef pair<int,int> pii;
template<typename T> void chkmax(T &x,T y){ x=max(x,y);}
template<typename T> void chkmin(T &x,T y){ x=min(x,y);}
inline int read(){
    int FF=0,RR=1;
    char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
    for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
    return FF*RR;
}
const int N=500100;
int n,seg[N<<2],f[N];
char str[N];
struct Hash{
    int P;
    ull p[N],h[2][N];
    void build(){
        p[0]=1;
        F(i,1,n) p[i]=p[i-1]*P; 
        F(i,1,n) h[0][i]=h[0][i-1]*P+str[i];
        h[1][n+1]=0;
        DF(i,n,1) h[1][i]=h[1][i+1]*P+str[i];
    }
    ull gethash(int seq,int l,int r){
        if(!seq) return h[seq][r]-h[seq][l-1]*p[r-l+1];
        else return h[seq][l]-h[seq][r+1]*p[r-l+1];
    }
}ha[2];
void pushdown(int x){ if(seg[x]!=-1) seg[x<<1]=seg[x<<1^1]=seg[x],seg[x]=-1;}
void modify(int l,int r,int x,int L,int R,int v){
    if(L<=l&&r<=R){ seg[x]=v;return;}
    pushdown(x);
    int mid=(l+r)>>1;
    if(mid>=L) modify(l,mid,x<<1,L,R,v);
    if(mid<R) modify(mid+1,r,x<<1^1,L,R,v);
}
int le[N];
void get_seg(int l,int r,int x){
    if(l==r){ le[l]=seg[x];return;}
    pushdown(x);
    int mid=(l+r)>>1;
    get_seg(l,mid,x<<1),get_seg(mid+1,r,x<<1^1);
}
const int inf=1e9;
void work(){
    n=read();scanf("%s",str+1);
    ha[0].build(),ha[1].build();
    F(i,1,n<<2) seg[i]=-1;
    F(i,2,n){
        int l=0,r=n+1;
        while(l<r-1){
            int mid=(l+r)>>1;
            if(!(i-mid>=1&&i+mid-1<=n)){ r=mid;continue;}
            if(ha[0].gethash(0,i,i+mid-1)==ha[0].gethash(1,i-mid,i-1)&&
               ha[1].gethash(0,i,i+mid-1)==ha[1].gethash(1,i-mid,i-1)) l=mid;
            else r=mid;
        }
        // cout<<"!!!"<<i<<' '<<l<<'\n';
        if(l) modify(1,n,1,i,i+l-1,i);
    }
    get_seg(1,n,1);
    // F(i,1,n) cout<<le[i]<<' ';cout<<'\n';
    F(i,1,n) f[i]=0;
    F(i,1,n) if(le[i]!=-1) f[i]=f[le[i]*2-i-2]+1;
    LL ans=0;
    F(i,1,n) ans+=f[i];
    printf("%lld\n",ans);
}
int main(){
    ha[0].P=131,ha[1].P=131;
    int T=read();
    while(T--) work();    
    fprintf(stderr,"%d ms\n",int(1e3*clock()/CLOCKS_PER_SEC));
    return 0;
}