AT_abc310_e

发布时间 2023-11-12 13:20:26作者: mRXxy0o0

一种极其无脑全是套路的做法(绝对不是因为没想到正解)

关注到要求出每一个区间的值,我们可以迅速联想到从小到大枚举右端点,同时维护前面所有以这个点为右端点的区间值。

根据一般套路,大致发现线段树可以做到这一点。

观察题目给出的变换规律,可以发现:新增一个 \(0\) 等同于赋值为 \(1\);新增一个 \(1\) 等同于取反。

这两个操作都可以很好的用线段树实现。

具体地,从小到大枚举右端点,同时维护每个左端点到这里的值,新增左端点时等同于区间操作,查询也是区间求和。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N=1e6+10;
int n;
ll ans;
char a[N];
struct node{
	int l,r,sum1,sum0,laz1,laz2;
}tr[N<<2];
inline void pushup(int u){
	tr[u].sum0=tr[u<<1].sum0+tr[u<<1|1].sum0;
	tr[u].sum1=tr[u<<1].sum1+tr[u<<1|1].sum1;
}
inline void pushdown(int u){
	if(tr[u].laz1){
		swap(tr[u<<1].sum0,tr[u<<1].sum1);
		swap(tr[u<<1|1].sum0,tr[u<<1|1].sum1);
		tr[u<<1].laz1^=1;
		tr[u<<1|1].laz1^=1;
		if(~tr[u].laz2) tr[u].laz2^=1;
		tr[u].laz1=0;
	}
	if(~tr[u].laz2){
		if(tr[u].laz2){
			tr[u<<1].sum0=tr[u<<1|1].sum0=0;
			tr[u<<1].sum1=tr[u<<1].r-tr[u<<1].l+1;
			tr[u<<1|1].sum1=tr[u<<1|1].r-tr[u<<1|1].l+1;
		}
		else{
			tr[u<<1].sum1=tr[u<<1|1].sum1=0;
			tr[u<<1].sum0=tr[u<<1].r-tr[u<<1].l+1;
			tr[u<<1|1].sum0=tr[u<<1|1].r-tr[u<<1|1].l+1;
		}
		tr[u<<1].laz1=tr[u<<1|1].laz1=0;
		tr[u<<1].laz2=tr[u<<1|1].laz2=tr[u].laz2;
		tr[u].laz2=-1;
	}
}
inline void build(int u,int l,int r){
	tr[u]={l,r,0,0,0,-1};
	if(l==r) return ;
	int mid=l+r>>1;
	build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
inline void mdf(int u,int l,int r,int k){
	if(l<=tr[u].l&&tr[u].r<=r){
		if(k){
			swap(tr[u].sum0,tr[u].sum1);
			tr[u].laz1^=1;
		}
		else{
			tr[u].sum0=0;
			tr[u].sum1=tr[u].r-tr[u].l+1;
			tr[u].laz1=0;
			tr[u].laz2=1;
		}
		return ;
	}
	pushdown(u);
	int mid=tr[u].l+tr[u].r>>1;
	if(l<=mid) mdf(u<<1,l,r,k);
	if(r>mid) mdf(u<<1|1,l,r,k);
	pushup(u);
}
inline void add(int u,int x,int v){
	if(tr[u].l==tr[u].r){
		tr[u].sum0=(v==0);
		tr[u].sum1=(v==1);
		return ;
	}
	pushdown(u);
	int mid=tr[u].l+tr[u].r>>1;
	if(x<=mid) add(u<<1,x,v);
	else add(u<<1|1,x,v);
	pushup(u);
}
inline int query(int u,int l,int r){
	if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum1;
	pushdown(u);
	int mid=tr[u].l+tr[u].r>>1,res=0;
	if(l<=mid) res+=query(u<<1,l,r);
	if(r>mid) res+=query(u<<1|1,l,r);
	return res;
}
int main(){
	scanf("%d%s",&n,a+1);
	build(1,1,n);
	for(int i=1;i<=n;++i){
		if(i>1){
			mdf(1,1,i-1,a[i]^48);
		}
		add(1,i,a[i]^48);
		ans+=query(1,1,i);
	}
	printf("%lld",ans);
	return 0;
}