权值线段树 学习笔记

发布时间 2023-09-29 11:11:35作者: Miya555

8月集训学了权值线段树,当时没怎么加强训练。

国庆刚好开始有时间,巩固巩固。补上学习笔记。

首先介绍权值树。其本质是一个记录每个数出现次数的线段树,也就是由桶建成的树。

接下来介绍各种操作。

1.插入。

由于统计的是出现次数,从这个数往上依次加1即可。

void insert(int x,int l,int r,int k){
    //插入一个数k
    if(l==r){
        tr[x]++;
        return;
    }
    int mid=(l+r)/2;
    if(k<=mid) insert(x*2,l,mid,k);
    else insert(x*2+1,mid+1,r,k);
    pushup(x);
}

2.删除。

同上,依次往上减-1。

void del(int x,int l,int r,int k){
    //删除一个数k
    if(l==r){
        tr[x]--;
        return;
    }
    int mid=(l+r)/2;
    if(k<=mid) del(x*2,l,mid,k);
    else del(x*2+1,mid+1,r,k);
    pushup(x);
}

3.查询区间数的数量。同线段树的查询。

int query(int x,int l,int r,int ql,int qr){
    //查询ql,qr之间一共有多少个数
    if(l>=ql&&r<=qr) return tr[x];
    int mid=(l+r)/2,sum=0;
    if(ql<=mid) sum=query(x*2,l,mid,ql,qr);
    if(qr>mid) sum+=query(x*2+1,mid+1,r,ql,qr);
    return sum;
}

4.查询第K大。(单纯的权值线段树只能查询整个区间。区间第K大需要树套树)

int kth(int x,int l,int r,int k){
    if(l==r) return l;//查到了,返回即可
    int mid=(l+r)/2;
    if(k<=tr[x*2]) return kth(x*2,l,mid,k); 
    return kth(x*2+1,mid+1,r,k-tr[x*2]);
}

5.查询排名。往下递归查找,每次加上左子树的值(因为左子树都比这个数要小)

int rnk(int x,int l,int r,int k){
    if(l==r) return 1;
    int mid=(l+r)/2;
    if(k<=mid) return rnk(x*2,l,mid,k);
    return rnk(x*2+1,mid+1,r,k)+tr[x*2];
}

6.前驱后继

前驱实际上就是比n的排名小一位的数,也就是kth(rnk(x)-1)

后继就是n+1的排名位置的数,也就是kth(rnk(x+1))

 7.空间优化。

离散化,动态开点。

咕咕咕

 

 

 

例题。

1. P3369 【模板】普通平衡树

使用了动态开点。

//produced by miya555
//stupid mistakes: query貌似一直出问题
//ideas: 尝试用权值线段树实现平衡树,先试试。
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ls (o<<1)
#define rs (ls|1)
const int N=500000;

int b[N],a[N],val[N],st[N],n,tot;
void add(int o,int l,int r,int k,int pos)
{
    st[o]+=pos;
    if(l==r) return;
    int mid=(l+r)>>1;
    if(k<=mid) add(ls,l,mid,k,pos);
    else add(rs,mid+1,r,k,pos);
}

int query(int o,int l,int r,int k)
{
    if(l==r) return 1;
    int mid=(l+r)>>1;
    if(k<=mid) return query(ls,l,mid,k);
    else return st[ls]+query(rs,mid+1,r,k);
}

int find(int o,int l,int r,int k)
{
    if(l==r) return l;
    int mid=(l+r)>>1;
    if(st[ls]>=k) return find(ls,l,mid,k);
    else return find(rs,mid+1,r,k-st[ls]);
}

main()
{
    cin>>n;
    for(int i=1;i<=n;i++)
    {
        cin>>val[i];
        cin>>a[i];
        if(val[i]!=4)
            b[++tot]=a[i];
    }
    
    sort(b+1,b+tot+1);
    
    for(int i=1;i<=n;i++)
    {
        if(val[i]!=4)
            a[i]=lower_bound(b+1,b+tot+1,a[i])-b;
    }
    for(int i=1;i<=n;i++)
    {
        
        int op=val[i];
         if(op==1){
             add(1,1,tot,a[i],1);
             
         }else if(op==2) {
             add(1,1,tot,a[i],-1);
             
             
         }else if(op==3) {
             cout<<query(1,1,tot,a[i])<<endl;
             
             
         }else if(op==4) {
             cout<<b[find(1,1,tot,a[i])]<<endl;
             
             
         }else if(op==5) {
             cout<<b[find(1,1,tot,query(1,1,tot,a[i])-1)]<<endl;
             
         }else{
             cout<<b[find(1,1,tot,query(1,1,tot,a[i]+1))]<<endl;
         }
        
    }
    return 0;
}

 

2.  P2073  送花

//produced by miya555
//stupid mistakes:
//ideas:搓一搓权值线段树板子
#include <bits/stdc++.h>
using namespace std;
const int N=1e7+10;
struct{
    int c,w;
} tr[8*N];
inline void pushup(int x){
    tr[x].c=tr[x*2].c+tr[x*2+1].c;
    tr[x].w=tr[x*2].w+tr[x*2+1].w;
}
void insert(int x,int l,int r,int k,int p){

    if(l==r){
        if(tr[x].c==0){
            tr[x].c+=p;
            
        }
        return;
    }
    int mid=(l+r)/2;
    if(k<=mid) insert(x*2,l,mid,k,p);
    else insert(x*2+1,mid+1,r,k,p);
    pushup(x);
}
void update(int u,int l,int r,int x,int v)    
{
    if(l==r)
    {
        if(tr[u].c==x) return;        
        tr[u].c=x;                    
        tr[u].w=v;
    }
    else
    {
        int mid=l+r>>1;        
        if(x<=mid) update(u<<1,l,mid,x,v);
        else update(u<<1|1,mid+1,r,x,v);
        pushup(u);
    }
}
void remove(int u,int l,int r,int s)    
{
    if(l==r) tr[u]={0,0};            
    else
    {
        int mid=l+r>>1;
        if(s)                            
        {
            if(!tr[u<<1|1].c) remove(u<<1,l,mid,s);    
            else remove(u<<1|1,mid+1,r,s);            
        }
        else                            
        {
            if(!tr[u<<1].c) remove(u<<1|1,mid+1,r,s);    
            else remove(u<<1,l,mid,s);
        }
        pushup(u);                        
    }
}
int kth(int x,int l,int r,int k){
    if(l==r) return l;
    int mid=(l+r)/2;
    if(k<=tr[x*2].c) return kth(x*2,l,mid,k); 
    return kth(x*2+1,mid+1,r,k-tr[x*2].c);
}

int n,y;
int t;
int main(){
    while(1){
        int opt,x;
        cin>>opt;
        if(opt==-1) break;
        if(opt==1){
            cin>>x>>y;
            update(1,1,1e6,y,x);
            t++;
        }
        else if(tr[1].c==0) continue;
        else if(opt==2) remove(1,1,1e6,1);
        else remove(1,1,1e6,0);
    }
    printf("%d %d\n",tr[1].w,tr[1].c);
}