Beautiful Pair

发布时间 2023-09-13 19:57:03作者: cqbzwwh

Beautiful Pair

简要题意:

给出一个长度为 \(n\) 的序列。要求它的子串中,满足左右端点之积小于等于子串中的最大值的个数。

思路

肯定要求出不同的 \([l,r]\) 中的最大值。显然一个一个枚举区间会超时。所以考虑当 \(a_i\) 作为最大值的时候,会产生哪些区间,这些区间满足条件(左端点右端点之积小于等于 \(a_i\))的有多少。

第一步是找出当 \(a_i\) 作为最大值时,有哪些区间。这里有两种方法

  1. 用笛卡尔树维护一个大根堆。若以 \(i\) 号点为根的子树的范围为 \([l,r]\),那么左端点\(\in [l,u]\),右端点\(\in [u,r]\) 的区间的最大值就是 \(a[i]\).

  2. 从左到右扫描一遍,找出左边第一个大于 \(a_i\) 的值的位置,设为\(lpos_i\),从右往左扫描一遍,找出右边第一个大于 \(a_i\) 的值的位置,设为\(rpos_i\).可以发现,左端点\(\in [lpos_i+1,u]\),右端点\(\in [u,rpos_i-1]\) 的区间的最大值就是 \(a[i]\).其本质与笛卡尔树是一样的。

显然用\(a_i\) 为最大值的最长的区间表示出所有的\(a_i\) 为最大值的区间

然后用cdq分治的思想。定义一个函数 \(f(l,r)\) 来计算左端点,右端点均在\([l,r]\) 中的合法(符合题意)区间数。

设以 \(a_i\) 为最大值的时候,最长的区间为 \([l,r]\),当前要计算的就是 \(f(l,r)\)

对于左端点在 \([l,i-1]\) 的区间,它们的最大值不关 \(a_i\) 的事,所以直接加上。\(f(l,r) \pm f(l,i-1)\), 右边同理,\(f(l,r) \pm f(i+1,r)\)

重点在于如何计算左端点 \(\in [l,i]\),右端点 \(\in [i,r]\) 的区间。

枚举左边一半,对每一个左边的数 \(x\),若取它为左端点,右端点 \(y\) 必须 \(y \leq \lfloor \frac{a[i]}{x} \rfloor\)

要计算前缀和,用树状数组实现。

但现在时间复杂度仍然不对,考虑用dsu on tree,并且每次只枚举左右区间中较短的那一部分。

时间复杂度为 \(O(n \log^2n)\)

\(Code\):

#include<bits/stdc++.h>
using namespace std;
int n,len;
const int MAXN=1e5+5;
int a[MAXN],b[MAXN];
int st[MAXN],tp;
int rs[MAXN],ls[MAXN];
int siz[MAXN];
void build(){
	for(int i=1;i<=n;i++){
		int k=tp;
		while(k&&a[st[k]]<a[i])	k--;
		if(k)	rs[st[k]]=i;
		if(k<tp)	ls[i]=st[k+1];
		tp=k;
		st[++tp]=i;
	}
}
void pre(int u){
	if(u==0)	return;
	pre(ls[u]),pre(rs[u]);
	siz[u]=siz[ls[u]]+siz[rs[u]]+1;
}
typedef long long ll;
ll c[MAXN];
int lb(int x){
	return x&(-x);
}
void add(int x,int val){
	for(int i=x;i<=n;i+=lb(i))	c[i]+=1ll*val;
}
ll get_sum(int x){
	ll sum=0;
	for(int i=x;i>=1;i-=lb(i))	sum+=c[i];
	return sum;
}
int get(int x){//小于等于 
	return upper_bound(b+1,b+1+len,x)-b-1;
}
ll dfs(int u,int l,int r,bool f){
	if(u==0)	return 0;
	if(l==r){
		if(f)	add(get(a[u]),1);
		return (a[u]==1);
	}
	ll sum=0;
	if(siz[ls[u]]>siz[rs[u]]){
		sum+=dfs(rs[u],u+1,r,0);
		sum+=dfs(ls[u],l,u-1,1);
		add(get(a[u]),1);
		for(int i=u;i<=r;i++){
			sum+=get_sum(get(a[u]/a[i]));
		}
		for(int i=u+1;i<=r;i++)	add(get(a[i]),1);
	}else{
		sum+=dfs(ls[u],l,u-1,0);
		sum+=dfs(rs[u],u+1,r,1);
		add(get(a[u]),1);
//		printf("%d %d\n",get_sum(4),get_sum(1));
		for(int i=l;i<=u;i++){
			sum+=get_sum(get(a[u]/a[i]));
//			printf("%d %d\n",i,get_sum(get(a[u]/a[i])));
		}
		for(int i=l;i<=u-1;i++)	add(get(a[i]),1);
	}
//	printf("	B%d %d %d %d\n",u,l,r,sum);
	if(!f){
		for(int i=l;i<=r;i++)	add(get(a[i]),-1);
	}
//	printf("%d\n",get_sum(2)-get_sum(1));
	return sum;
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
		b[i]=a[i];
	}
	sort(b+1,b+1+n);
	len=unique(b+1,b+1+n)-b-1;
	build();
	pre(st[1]);
	printf("%lld",dfs(st[1],1,n,0));
	return 0;
}