CF1223F Stack Exterminable Arrays

发布时间 2023-10-23 16:16:57作者: Candycar

CSP-S2023 T2原题在此!!!!

题意:

给一个序列进行栈操作,从左到右入栈,若当前入栈元素等于栈顶元素则栈顶元素出栈,否则当前元素入栈。若进行完操作后栈为空,这说这个序列是可以被消除的。

给你一个长度为\(n\)的序列\(a\),问\(a\)有多少子串是可以被消除的。

数据范围:

\(1\leq q\leq 3\times 10^5,1\leq n\leq 3\times 10^5\)

我们先考虑一种朴素的方法:令 \(dp_i\) 表示以 \(i\) 为结尾的合法串的个数。然后我们考虑怎么转移这个方程。然后我们发现这个东西似乎有点类似于括号匹配问题,则可以考虑令 \(nxt_i\) 表示 \(i\sim nxt_i\) 之间的串是一个合法串,且 \(nxt_i\) 的位置最小。

然后就可以列出转移方程 \(dp_i=dp_{nxt_i+1}+1\)

然后我们考虑怎么快速的求这个 \(nxt_i\):发现如果求出了 \(nxt_{i+1}\) ,若 \(a_{i}=a_{nxt_i+1}\)\(nxt_i=nxt_{i+1}+1\),不然就比较 \(nxt_{nxt_{i+1}}+1\),以此类推

然后我们考虑去优化这个东西。令 \(mp_{i,j}\) 记录 \(i\) 右边的一个 \(pos\) ,满足 [i,pos-1] 是合法的括号序列,那么 \(nxt_i=mp_{i+1,a_i}\)

点击查看代码
#include<bits/stdc++.h>
#define int long long
#define mem(a) memset(a,0,sizeof(a))
#define set(a,b) memset(a,b,sizeof(a))
#define ls p<<1
#define rs p<<1|1
#define pb(x) push_back(x)
#define pt(x) putchar(x)
#define T int t;cin>>t;while(t--)
#define rand RAND
#define LOCAL
using namespace std;
char buf[1<<20],*p1,*p2;
#define gc()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
template<class Typ> Typ &re(Typ &x){char ch=gc(),sgn=0; x=0;for(;ch<'0'||ch>'9';ch=gc()) sgn|=ch=='-';for(;ch>='0'&&ch<='9';ch=gc()) x=x*10+(ch^48);return sgn&&(x=-x),x;}
template<class Typ> void wt(Typ x){if(x<0) putchar('-'),x=-x;if(x>9) wt(x/10);putchar(x%10^48);}
const int inf=0x3f3f3f3f3f;
const int maxn=3e5+5;
const int mod=1e9+7;
int seed = 19243;
unsigned rand(){return seed=(seed*48271ll)%2147483647;}
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);return;}
int n;
int a[maxn];
int nxt[maxn];
map<int,int>mp[maxn];
int dp[maxn];
signed main(){
	T{
		cin>>n;
		for(int i=1;i<=n;i++)cin>>a[i];
		for(int i=1;i<=n+1;i++)mp[i].clear(),dp[i]=0,nxt[i]=-1;
		for(int i=n;i;i--){
			if(mp[i+1].find(a[i])!=mp[i+1].end()){
				int pos=mp[i+1][a[i]];
				nxt[i]=pos;
				swap(mp[i],mp[pos+1]);
			}
			mp[i][a[i]]=i;
		}
		dp[n]=dp[n+1]=0;
		for(int i=n-1;i>=1;i--){
			if(nxt[i]==-1)continue;
			dp[i]=dp[nxt[i]+1]+1;
		}
		int ans=0;
		for(int i=1;i<=n;i++)ans+=dp[i];
		cout<<ans<<endl;
	}
	return 0;
}