AT_joisc2015_h 题解

发布时间 2023-12-29 09:53:58作者: Xttttr

传送门

题意:给定长为 \(n\) 的字符串 \(s\),你可以选择一个区间,将区间内的字符从小到大排序,求可以得到的最长回文子串长度,字符集大小为 \(n\)

很有意思的题目。

首先容易做到 \(O(n^3)\)。考虑怎么优化。

我们先考察排序的区间和回文区间的关系。

  1. 如果两个区间无交,那么显然排序不会对回文串长度有影响。
  2. 如果排序区间包含了回文区间,那么答案就是最多的相同字符数,容易求出。
  3. 剩下的情况可以根据两个区间的中点的关系分成两种,这两种可以通过把回文串翻转来转化,于是只用考虑一种。

不妨设排序区间的中点在回文串中点右侧,根据排序区间和回文串右半部分的关系,可以分为 4 种,但这 4 种可以一起考虑。(下面的图都来自 官方题解
img

具体地,对于前两种情况,我们枚举回文串中心,暴力往外拓展找最长的回文串,然后考虑两边的贡献;对于后两种情况,我们枚举排序区间左侧的前一个数作为的左端点,往右找到最近的比左端点前的数小的第一个数,那么最终回文串中心的一段都是这个数,然后再考虑两边的贡献。这时两种情况后面的处理大致相同了。

我们从左端往左找最长的不下降连续段,记录每个数的出现次数,然后从右端点开始往右扫。对于前两种情况,遇到小于左端点的数就停止;对于后两种情况,直到遇到比第一个遇到的比左端点小的数更小就停止。对于前两种情况,在扫描的过程中,两边的贡献就是右边从小到大第一个出现次数超过左边的数之前的数的个数;对于后两种情况,贡献就是比左端点小的数的个数加上右边从小到大第一个出现次数超过左边的数之前的数的个数(不算比左端点小的数)。

说起来比较抽象,看图就好理解了。
img
左侧有一个 \(2\)、两个 \(4\)、三个 \(6\),右边有两个 \(1\)、一个 \(2\)、三个 \(4\),右边的三个 \(4\) 已经比左侧的两个 \(4\) 要多了,那么能做贡献的就只有在中间的两个 \(1\),和两边的一个 \(2\)、两个 \(4\),回文串长度就是 \(3+2+3=8\)

需要注意的是对于第一种情况,排序区间的右端点往右的部分也可能和左边形成回文串,提前预处理出 \(f[i][j]\) 表示从 \(i\) 往左,\(j\) 往右最长的相等串的长度即可。

复杂度 \(O(n^2)\)

代码也还好。

const int N=3051;
int n,c;
int a[N],ans;
int f[N][N],cntl[N],cntr[N];
inline int calc(int l,int r,int lim=0){
    memset(cntl,0,sizeof(cntl));memset(cntr,0,sizeof(cntr));
    int L=l,maxn=0,cntmid=0,ansl=1,ansr=0;
    cntl[a[l]]++;
    while(a[L-1]>=a[L])cntl[a[--L]]++,ansl++;
    int posl=0,posr=a[L];
    for(int i=r;i<=n;i++){
        cntr[a[i]]++;
        if(a[i]<lim)return maxn;
        if(a[i]==lim)cntmid++;
        else if(cntr[a[i]]>cntl[a[i]]){
            while(posr>a[i]){
                ansl-=cntl[posr--];
            }
        }
        else{
            while(posl<=c&&cntl[posl]<=cntr[posl]){
                ansr+=cntl[posl++];
            }
        }
        int cur=min(ansr+cntr[posl],ansl);
        if(cur+cntmid==i-r+1)cur+=f[l-cur][i+1];
        maxn=max(maxn,2*cur+cntmid);
    }
    return maxn;
}
inline void solve(){
    for(int i=1;i<=n;i++){
        for(int j=i;j<=n;j++){
            if(a[i]==a[j])f[i][j]=f[i-1][j+1]+1;
            else f[i][j]=0;
        }
    }
    for(int i=1;i<=n;i++){
        ans=max(ans,2*f[i][i]-1+calc(i-f[i][i],i+f[i][i]));
    }
    for(int i=1;i<n;i++){
        ans=max(ans,2*f[i][i+1]+calc(i-f[i][i+1],i+1+f[i][i+1]));
    }
    for(int i=1;i<=n;i++){
        int minn=0;
        for(int j=i+1;j<=n;j++){
            if(a[j]<a[i]){
                minn=a[j];
                break;
            }
        }
        ans=max(ans,calc(i,i+1,minn));
    }
}
int main(){
    ios::sync_with_stdio(false);cin.tie(0),cout.tie(0);
    cin>>n>>c;
    for(int i=1;i<=n;i++)cin>>a[i];
    solve();
    reverse(a+1,a+n+1);
    for(int i=1;i<=n;i++)a[i]=c-a[i]+1;
    solve();
    sort(a+1,a+n+1);
    int cnt=0;
    for(int i=1;i<=n;i++){
        if(a[i]!=a[i-1])cnt=1;
        else cnt++,ans=max(ans,cnt);
    }
    cout<<ans<<endl;
    return 0;
}