CF653F Paper task

发布时间 2023-05-06 02:41:18作者: zzxLLL

题意:给定一个仅包含左右括号的字符串,求其中合法的括号子串种数

一眼看上去没思路,先从简单的问题看起。


如何判断一个括号序列是合法的?

将左括号变成 \(1\) ,右括号变成 \(-1\) ,然后得到前缀和数组 \(pre\)。如果 \(pre\) 中没有负数说明括号序列合法。


怎么求合法的括号子串个数

考虑枚举左端点 \(L\) ,然后找其最远可能对应的右端点 \(R\) ,即不出现右括号个数大于左括号个数的情况。

怎么找?上面的结论可以转化一下:对于合法括号子串 \(S[l \cdots r]\)

1.\(S_l\)(

2.\(\forall i \in [l , r] , pre_i - pre_l \ge -1\) ,即 \(\min\limits_{i \in [l , r]}{pre_i} - pre_l \ge -1\)

3.\(pre_r - pre_l = -1\)

可以对 \(pre\) 数组维护一个 ST 表,通过第二条性质二分查询其右端点,再在左右端点之间寻找 \(pre_x = pre_l - 1\)\(x\) 的个数即可。

于是问题转化为给定区间内有多少个数等于某给定值。于是你想到了这题,并想到可以对于每一个值开一个 \(vector\)

vector<int>v[M<<1];
#define lwb(t,x) lower_bound(v[t].begin(),v[t].end(),x)
#define upb(t,x) upper_bound(v[t].begin(),v[t].end(),x)
int query(int l,int r,int val){ // [l,r] 中 val 的出现次数 
	return upb(val+M,r)-lwb(val+M,l);
}

\(v_t\) 存储的是所有等于 \(t\) 的数的下标,单调递增。\(t\) 可能为负怎么办?查询时给每个 \(t\) 加上 \(5 \times 10^5\) 即可。


然而题目要求的是种数,怎么办?

上面会得出来很多相同的字符串,我们现在要做的是去重

如何去重?全部算完再一个个去重太麻烦,考虑在计算时就不算重复的情况。这个时候你想到了字符串去重的板子,于是使用后缀数组

考虑每个后缀对答案的贡献。对于排名为 \(i\) 的后缀,它的前 \(height_i\) 个字符已经被前 \(i-1\) 个后缀计算过了,所以真正新增的字符串的右端点\([sa_i + height_i , n]\) 中。

要求新增的字符串对答案的贡献。结合上面枚举左端点 \(L\) 找右端点 \(R\) 的做法可知:令 \(r_i\)\(sa_i\) 作为左端点时找到的最远右端点,那么排名为 \(i\) 的后缀的贡献为 \([sa_i + height_i , r_i]\) 中值为 \(pre_{sa_i} - 1\) 的数的个数。将贡献累加即可。


时间复杂度 \(O(n \log{n})\) ,这题也没什么细节,十分好写。

#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int M=5e5+10;

int len,pre[M],st[M][21];
char str[M];
int query(int l,int r){
	int k=log2(r-l+1);
	return min(st[l][k],st[r-(1<<k)+1][k]);
}

int x[M],y[M],sa[M],buc[M],ht[M];
void SA(){
	int tot=127;
	for(int i=1;i<=len;i++) buc[x[i]=str[i]]++;
	for(int i=2;i<=tot;i++) buc[i]+=buc[i-1];
	for(int i=len;i>=1;i--) sa[buc[x[i]]--]=i;
	
	for(int k=1;k<=len;k<<=1){
		int cur=0;
		for(int i=len-k+1;i<=len;i++) y[++cur]=i;
		for(int i=1;i<=len;i++)
			if(sa[i]>k) y[++cur]=sa[i]-k;
		
		for(int i=1;i<=tot;i++) buc[i]=0;
		for(int i=1;i<=len;i++) buc[x[i]]++;
		for(int i=2;i<=tot;i++) buc[i]+=buc[i-1];
		for(int i=len;i>=1;i--) sa[buc[x[y[i]]]--]=y[i];
		
		swap(x,y),x[sa[1]]=cur=1;
		for(int i=2;i<=len;i++)
			if(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k]) x[sa[i]]=cur;
			else x[sa[i]]=++cur;
		if(cur==len) break;
		tot=cur;
	}
}
void getheight(){
	for(int i=1,cur=0;i<=len;i++){
		if(x[i]==0) continue;
		if(cur) cur--;
		while(str[i+cur]==str[sa[x[i]-1]+cur]) cur++;
		ht[x[i]]=cur;
	}
}

vector<int>v[M<<1];
#define lwb(t,x) lower_bound(v[t].begin(),v[t].end(),x)
#define upb(t,x) upper_bound(v[t].begin(),v[t].end(),x)
int query(int l,int r,int val){ // [l,r] 中 val 的出现次数 
	return upb(val+M,r)-lwb(val+M,l);
}

int find_r(int p){
	int lb=p,rb=len,ans=p;
	while(lb<=rb){
		int mid=(lb+rb)>>1;
		if(query(p,mid)-pre[p]>=-1) ans=mid,lb=mid+1;
		else rb=mid-1;
	}
	return ans;
}

int main(){
	scanf("%d",&len);
	scanf(" %s",str+1);
	
	for(int i=1;i<=len;i++)
		st[i][0]=pre[i]=pre[i-1]+(str[i]=='('?1:-1);
	for(int j=1;j<21;j++)
		for(int i=1;i+(1<<j)-1<=len;i++)
			st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
	
	SA(),getheight();
	for(int i=1;i<=len;i++) v[pre[i]+M].push_back(i);
	
	long long ans=0;
	for(int i=1,L,R;i<=len;i++){
		if(str[sa[i]]==')') break;
		L=sa[i]+ht[i],R=find_r(sa[i]);
		if(L<=R) ans+=query(L,R,pre[sa[i]]-1);
	}
	printf("%lld",ans);
	return 0;
}