简要题意:
给出一个长度为 \(n\) 的序列。要求它的子串中,满足左右端点之积小于等于子串中的最大值的个数。
思路
肯定要求出不同的 \([l,r]\) 中的最大值。显然一个一个枚举区间会超时。所以考虑当 \(a_i\) 作为最大值的时候,会产生哪些区间,这些区间满足条件(左端点右端点之积小于等于 \(a_i\))的有多少。
第一步是找出当 \(a_i\) 作为最大值时,有哪些区间。这里有两种方法
-
用笛卡尔树维护一个大根堆。若以 \(i\) 号点为根的子树的范围为 \([l,r]\),那么左端点\(\in [l,u]\),右端点\(\in [u,r]\) 的区间的最大值就是 \(a[i]\).
-
从左到右扫描一遍,找出左边第一个大于 \(a_i\) 的值的位置,设为\(lpos_i\),从右往左扫描一遍,找出右边第一个大于 \(a_i\) 的值的位置,设为\(rpos_i\).可以发现,左端点\(\in [lpos_i+1,u]\),右端点\(\in [u,rpos_i-1]\) 的区间的最大值就是 \(a[i]\).其本质与笛卡尔树是一样的。
显然用以 \(a_i\) 为最大值的最长的区间表示出所有的以 \(a_i\) 为最大值的区间
然后用cdq分治的思想。定义一个函数 \(f(l,r)\) 来计算左端点,右端点均在\([l,r]\) 中的合法(符合题意)区间数。
设以 \(a_i\) 为最大值的时候,最长的区间为 \([l,r]\),当前要计算的就是 \(f(l,r)\)。
对于左端点在 \([l,i-1]\) 的区间,它们的最大值不关 \(a_i\) 的事,所以直接加上。\(f(l,r) \pm f(l,i-1)\), 右边同理,\(f(l,r) \pm f(i+1,r)\)
重点在于如何计算左端点 \(\in [l,i]\),右端点 \(\in [i,r]\) 的区间。
枚举左边一半,对每一个左边的数 \(x\),若取它为左端点,右端点 \(y\) 必须 \(y \leq \lfloor \frac{a[i]}{x} \rfloor\)
要计算前缀和,用树状数组实现。
但现在时间复杂度仍然不对,考虑用dsu on tree,并且每次只枚举左右区间中较短的那一部分。
时间复杂度为 \(O(n \log^2n)\)
\(Code\):
#include<bits/stdc++.h>
using namespace std;
int n,len;
const int MAXN=1e5+5;
int a[MAXN],b[MAXN];
int st[MAXN],tp;
int rs[MAXN],ls[MAXN];
int siz[MAXN];
void build(){
for(int i=1;i<=n;i++){
int k=tp;
while(k&&a[st[k]]<a[i]) k--;
if(k) rs[st[k]]=i;
if(k<tp) ls[i]=st[k+1];
tp=k;
st[++tp]=i;
}
}
void pre(int u){
if(u==0) return;
pre(ls[u]),pre(rs[u]);
siz[u]=siz[ls[u]]+siz[rs[u]]+1;
}
typedef long long ll;
ll c[MAXN];
int lb(int x){
return x&(-x);
}
void add(int x,int val){
for(int i=x;i<=n;i+=lb(i)) c[i]+=1ll*val;
}
ll get_sum(int x){
ll sum=0;
for(int i=x;i>=1;i-=lb(i)) sum+=c[i];
return sum;
}
int get(int x){//小于等于
return upper_bound(b+1,b+1+len,x)-b-1;
}
ll dfs(int u,int l,int r,bool f){
if(u==0) return 0;
if(l==r){
if(f) add(get(a[u]),1);
return (a[u]==1);
}
ll sum=0;
if(siz[ls[u]]>siz[rs[u]]){
sum+=dfs(rs[u],u+1,r,0);
sum+=dfs(ls[u],l,u-1,1);
add(get(a[u]),1);
for(int i=u;i<=r;i++){
sum+=get_sum(get(a[u]/a[i]));
}
for(int i=u+1;i<=r;i++) add(get(a[i]),1);
}else{
sum+=dfs(ls[u],l,u-1,0);
sum+=dfs(rs[u],u+1,r,1);
add(get(a[u]),1);
// printf("%d %d\n",get_sum(4),get_sum(1));
for(int i=l;i<=u;i++){
sum+=get_sum(get(a[u]/a[i]));
// printf("%d %d\n",i,get_sum(get(a[u]/a[i])));
}
for(int i=l;i<=u-1;i++) add(get(a[i]),1);
}
// printf(" B%d %d %d %d\n",u,l,r,sum);
if(!f){
for(int i=l;i<=r;i++) add(get(a[i]),-1);
}
// printf("%d\n",get_sum(2)-get_sum(1));
return sum;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
b[i]=a[i];
}
sort(b+1,b+1+n);
len=unique(b+1,b+1+n)-b-1;
build();
pre(st[1]);
printf("%lld",dfs(st[1],1,n,0));
return 0;
}