CF1585F. Non-equal Neighbours

发布时间 2023-05-29 21:21:59作者: xx019

三倍经验:CF1591F. Non-equal NeighboursARC115E - LEQ and NEQ

提供一种力大砖飞的数据结构 \(O(n\log n)\) 做法,非常好写/好调,去掉数据结构部分只有 1k。

定义 \(f_{i,j}\) 表示前 \(i\) 个数,最后一个为 \(j\) 的方案数。显然第 1 维可以压掉,写成 \(f_j\) 的形式。

然后这个东西可以前缀和做到 \(O(\sum a)\)。更具体地说,对于前 \(i-1\) 个数,定义 \(s=\sum\limits_{k=1}^{a_{i-1}}f_k\),加上第 \(i\) 个数之后有 \(f_k=s-f_k\)。这个东西似乎不能优化了。

但是,我们可以发现,对于很多连续的 \(f_k\),他们的值是一样的:对于 \(a_i\ge a_{i-1}\),由于原来从 \(a_{i-1}+1\)\(a_i\) 的这些位置都没有值,所以相当于在最后插入了值为 \(s\) 的一段;对于 \(a_i<a_{i-1}\),相当于舍弃后面一部分 dp 值。

当然,每次剩余的那些段都会把值从 \(v_i\) 变成 \(s-v_i\),但这并不影响。

每次最多加入一段,所以最多 \(n\) 段;每段最多加一次删一次,故时间复杂度 \(O(n\log n)\)。那个 \(\log n\) 是用线段树维护每段具体值的时间。

code:

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int inf=1e18,mod=998244353;
inline int read(){
	int x=0,f=1;char ch=getchar();
	while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
	while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
	return x*f;
}
int a[500005];
struct Node{
	int l,r;
}q[500005];
struct segtree{
	#define ls p<<1
	#define rs p<<1|1
	#define lson l,mid,ls
	#define rson mid+1,r,rs
	struct Node{
		int s,add,mul;
	}c[2000005];
	void pushup(int p){
		c[p].s=(c[ls].s+c[rs].s+mod)%mod;
	}
	void pushdown(int l,int r,int p){
		if(c[p].mul!=1){
			c[ls].s=(c[ls].s*c[p].mul%mod+mod)%mod;
			c[rs].s=(c[rs].s*c[p].mul%mod+mod)%mod;
			c[ls].add=(c[ls].add*c[p].mul%mod+mod)%mod;
			c[rs].add=(c[rs].add*c[p].mul%mod+mod)%mod;
			c[ls].mul=(c[ls].mul*c[p].mul%mod+mod)%mod;
			c[rs].mul=(c[rs].mul*c[p].mul%mod+mod)%mod;
			c[p].mul=1;
		}
		if(c[p].add!=0){
			int siz=r-l+1,ln=siz-(siz>>1),rn=siz>>1;
			c[ls].s=(c[ls].s+ln*c[p].add%mod+mod)%mod;
			c[rs].s=(c[rs].s+rn*c[p].add%mod+mod)%mod;
			c[ls].add=(c[ls].add+c[p].add+mod)%mod;
			c[rs].add=(c[rs].add+c[p].add+mod)%mod;
			c[p].add=0;			
		}
	}
	void build(int l,int r,int p){
		c[p].add=0;
		c[p].mul=1;
		if(l==r){
			c[p].s=0;
			return;
		}
		int mid=(l+r)>>1;
		build(lson);
		build(rson);
		pushup(p);
	}
	void mul(int l,int r,int p,int L,int R,int k){
		if(L>R)return;
		if(L<=l&&r<=R){
			c[p].s=(c[p].s*k%mod+mod)%mod;
			c[p].add=(c[p].add*k%mod+mod)%mod;
			c[p].mul=(c[p].mul*k%mod+mod)%mod;
			return;
		}
		int mid=(l+r)>>1;pushdown(l,r,p);
		if(L<=mid)mul(lson,L,R,k);
		if(R>mid)mul(rson,L,R,k);
		pushup(p);
	}
	void add(int l,int r,int p,int L,int R,int k){
		if(L>R)return;
		if(L<=l&&r<=R){
			c[p].s=(c[p].s+(r-l+1)*k%mod+mod)%mod;
			c[p].add=(c[p].add+k+mod)%mod;
			return;
		}
		int mid=(l+r)>>1;pushdown(l,r,p);
		if(L<=mid)add(lson,L,R,k);
		if(R>mid)add(rson,L,R,k);
		pushup(p);
	}
	int query(int l,int r,int p,int L,int R){
		if(L>R)return 0;
		if(L<=l&&r<=R)return c[p].s;
		int mid=(l+r)>>1,res=0;pushdown(l,r,p);
		if(L<=mid)res=(res+query(lson,L,R)+mod)%mod;
		if(R>mid)res=(res+query(rson,L,R)+mod)%mod;
		return res;
	}
	#undef ls 
	#undef rs
	#undef lson
	#undef rson
}Tr;
void solve(){
	int n=read(),L=1,R=0,sum=0;
	for(int i=1;i<=n;i++)a[i]=read();
	Tr.build(1,n,1);
	q[++R]=(Node){1,a[1]},sum=(sum+a[1])%mod;
	Tr.mul(1,n,1,R,R,0);Tr.add(1,n,1,R,R,1);
	for(int i=2;i<=n;i++){
		if(a[i]>=a[i-1]){
			Tr.mul(1,n,1,L,R,-1);Tr.add(1,n,1,L,R,sum);
			q[++R]=(Node){a[i-1]+1,a[i]};
			Tr.mul(1,n,1,R,R,0);Tr.add(1,n,1,R,R,sum);
			sum=(sum*a[i]%mod-sum+mod)%mod;
		}
		else{
			int nsum=sum;
			while(L<=R&&q[R].l>a[i])nsum=(nsum-Tr.query(1,n,1,R,R)*(q[R].r-q[R].l+1)%mod+mod)%mod,R--;
			if(L<=R&&q[R].r>a[i])nsum=(nsum-Tr.query(1,n,1,R,R)*(q[R].r-a[i])%mod+mod)%mod,q[R].r=a[i];
			Tr.mul(1,n,1,L,R,-1);Tr.add(1,n,1,L,R,sum);
			sum=(sum*a[i]%mod-nsum+mod)%mod;	
		}
	}
	printf("%lld\n",sum);
}
signed main(){
	int T=1;
	while(T--)solve();
	return 0;
}