CF613E Puzzle Lover 思考--zhengjun

发布时间 2023-07-28 20:29:14作者: A_zjzj

题很简单,一遍写对却比较困难。

犯的错误:

  • 预处理 \({base}^i\) 时应该要处理到 \(\max\{n,m\}\)

  • 去重的时候(reduce 函数)特判 \(m=1,2\)

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=2e3+10,mod=1e9+7,base=23333;
int n,m;
char a[2][N],b[N];
int pw[N],f[2][N],g[2][N],pre[N],suf[N];
int Hash1(int t,int l,int r){
	return (f[t][r]+1ll*(mod-pw[r-l+1])*f[t][l-1])%mod;
}
int Hash2(int t,int l,int r){
	return (g[t][l]+1ll*(mod-pw[r-l+1])*g[t][r+1])%mod;
}
int merge(int x,int y,int len){
	return (1ll*x*pw[len]+y)%mod;
}
int ans,dp[N][N][2];
void solve(){
	memset(dp,0,sizeof dp);
	for(int i=1;i<=m;i++){
		pre[i]=(1ll*pre[i-1]*base+b[i])%mod;
	}
	for(int i=m;i>=1;i--){
		suf[i]=(suf[i+1]+1ll*pw[m-i]*b[i])%mod;
	}
	for(int i=0;i<=n;i++)dp[i][0][0]=dp[i][0][1]=1;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=m;j++){
			for(int t=0;t<2;t++){
				if(b[j]!=a[t][i])continue;
				(dp[i][j][t]+=dp[i-1][j-1][t])%=mod;
				if(j>1&&b[j-1]==a[!t][i])(dp[i][j][t]+=dp[i-1][j-2][!t])%=mod;
			}
		}
		for(int t=0;t<2;t++){
			for(int len=2;len<=i;len++){
				int l=i-len+1,r=i,j=len*2;
				if(len*2>m)continue;
				if(merge(Hash2(!t,l,r),Hash1(t,l,r),len)==pre[j])
					++dp[i][j][t]%=mod;
			}
		}
	}
	for(int i=1;i<=n;i++){
		for(int t=0;t<2;t++){
			(ans+=dp[i][m][t])%=mod;
			for(int len=2;len<=n-i+1;len++){
				int l=i,r=i+len-1;
				if(len*2>m)continue;
				if(merge(Hash1(t,l,r),Hash2(!t,l,r),len)==suf[m-len*2+1])
					(ans+=dp[i-1][m-len*2][t])%=mod;
			}
		}
	}
}
void reduce(){
	if(m>1){
		if(m&1)return;
		for(int i=1;i<=n;i++){
			for(int t=0;t<2;t++){
				int len=m/2,l=i,r=i+len-1;
				if(r>n)continue;
				if(merge(Hash1(t,l,r),Hash2(!t,l,r),len)==pre[m])ans--;
			}
		}
		if(m==2)return;
		for(int i=1;i<=n;i++){
			for(int t=0;t<2;t++){
				int len=m/2,l=i-len+1,r=i;
				if(l<1)continue;
				if(merge(Hash2(!t,l,r),Hash1(t,l,r),len)==pre[m])ans--;
			}
		}
	}else{
		for(int i=1;i<=n;i++){
			for(int t=0;t<2;t++){
				ans-=b[m]==a[t][i];
			}
		}
	}
	(ans+=mod)%=mod;
}
int main(){
	freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
	scanf("%s%s%s",a[0]+1,a[1]+1,b+1);
	n=strlen(a[0]+1),m=strlen(b+1);
	for(int i=pw[0]=1;i<=max(n,m);i++)pw[i]=1ll*pw[i-1]*base%mod;
	for(int t=0;t<2;t++){
		for(int i=1;i<=n;i++){
			f[t][i]=(1ll*f[t][i-1]*base+a[t][i])%mod;
		}
		for(int i=n;i>=1;i--){
			g[t][i]=(1ll*g[t][i+1]*base+a[t][i])%mod;
		}
	}
	solve();
	reverse(b+1,b+1+m);
	solve();
	reduce();
	cout<<ans;
	return 0;
}