[COCI2015-2016#2] VUDU 题解
题意
给一个长度为 \(N\) 的整数序列 \({a}\),对于其一共 \(\frac{N\times (N+1)}{2}\) 个的非空区间,求有多少个区间的平均数大于等于 \(p\)。
暴力做法
显然,可以直接枚举区间左端点 \(j\) 与区间右端点 \(i\),检查区间平均值是否大于等于 \(p\) 即可,用代码实现即为:
for(int i=1;i<=n;i++){
for(int j=1;j<=i;j++){
if(sum[i]-sum[j-1]>=p*(i-j+1)){//乘法避免丢精度
ans++;
}
}
}
用前缀和数组 \(sum\) 优化,时间复杂度为 \(O(n^2)\)。显然,对于 \(2 \times 10^5\) 的数据,该程序会超时。
正解
只要对暴力算法的不等式稍加改变,我们就能得到正解。
对于代码中的不等式,我们可以将不等式转化为:\(sum_i - p \times i - p \geq sum_{j-1} - p \times j\)。
再结合 \(j \leq i\),就显然可以看出公式可以用树状数组优化。
对于每一个 \(i\),我们都将 \(sum_{i-1} - p \times i\) 插入进树状数组中,再将树状数组中小于等于 \(sum_i - p \times i - p\) 的元素个数查询并累加,就可以得到最终答案。
时间复杂度为 \(O(n \log n)\)。
注意:由于不等式中的左右两项可能很大,所以在操作前需要先对其离散化。
代码
#include<stdio.h>
#include<algorithm>
typedef long long ll;
const ll N=2e6+5;
int n,cnt;
ll p,sum[N],a[N],b[N],g[N],c[N];
int lowbit(int x){return x&~x+1;}
void add(int x,ll v){while(x<=cnt) c[x]+=v,x+=lowbit(x);}
ll query(int x){
ll ans=0;
while(x) ans+=c[x],x-=lowbit(x);
return ans;
}
int main(){
scanf("%d",&n);
ll x;
for(int i=1;i<=n;i++){
scanf("%lld",&x);
sum[i]=sum[i-1]+x;
}
scanf("%lld",&p);
for(int i=1;i<=n;i++){
a[i]=sum[i]-p-p*i;//不等式左边项
b[i]=sum[i-1]-p*i;//不等式右边项
g[++cnt]=a[i],g[++cnt]=b[i];
}
std::sort(g+1,g+cnt+1);//离散化排序
cnt=std::unique(g+1,g+cnt+1)-g-1;//离散化去重
ll ans=0;
for(int i=1;i<=n;i++){
int x=std::lower_bound(g+1,g+cnt+1,a[i])-g;
int y=std::lower_bound(g+1,g+cnt+1,b[i])-g;//离散化
add(y,1);
ans+=query(x);
}
printf("%lld",ans);
return 0;
}