线段树模板

发布时间 2023-10-15 20:44:47作者: du463

线段树理解起来不难,主要是书写起来比较麻烦
这里学的是董晓老师的线段树模板

#include<bits/stdc++.h>
using namespace std;
#define lc p<<1
#define rc p<<1|1
#define N 500005
int n,w[N];
struct node
{
    int l,r,sum,add;//add用于懒标记

}tr[N*4];
//建树,深搜递归的过程
void build(int p,int l,int r){
    tr[p]={l,r,w[l]};
    if(l==r){//是最后的叶子节点就返回
        return ;
    } 
    int m=(l+r)>>1;//不是叶子节点就要继续裂开,往下分
    build(lc,l,m);
    build(rc,m+1,r);
    tr[p].sum=tr[lc].sum+tr[rc].sum;

}
//点的修改
void update(int p,int x,int k){//点修改,从根节点进入,递归找到子节点[x,x]
    if(tr[p].l==x&&tr[p].r==x){//叶子节点直接修改
        tr[p].sum+=k;
        return ;
    }
    int m=(tr[p].l+tr[p].r)>>1;//非叶子节点就需要继续往下分,直到到叶子节点
    if(x<=m){
        update(lc,x,k);//在左子树上就进入左子树
    }
    if(x>m){
        update(rc,x,k);//在右子树上就进入右子树,只会进入左右中的一个

    }
    tr[p].sum=tr[lc].sum+tr[rc].sum;

}

//区间查询
int query(int p,int x,int y){//区间查询,从根节点出发
    if(x<=tr[p].l&&tr[p].r<=y){
        return tr[p].sum;
        //如果是这个区间里的就将数据返回,因为肯定是需要加上这段区间
        //覆盖就返回
    }
    int m=(tr[p].l+tr[p].r)>>1;//不覆盖就继续向下裂开
    int sum=0;
    if(x<=m){
        sum+=query(lc,x,y);
    }
    if(y>m){
        sum+=query(rc,x,y);

    }
    return sum; 
}
//区间修改
//懒标记
void pushup(int p){
    tr[p].sum=tr[lc].sum+tr[rc].sum;

}
void pushdown(int p){
    if(tr[p].add){
        tr[lc].sum+=tr[p].add*(tr[lc].r-tr[rc].l+1);//因为是区间每个数都要加一个add
        tr[rc].sum+=tr[p].add*(tr[rc].r-tr[rc].l+1);
        tr[lc].add+=tr[p].add;
        tr[rc].add+=tr[p].add;
        tr[p].add=0;//因为已经将我们做的懒人标记传给下面了,所以之前打的标记就要取消
    }
}
void update1(int p,int x,int y,int k){
    if(x<=tr[p].l&&tr[p].r<=y){//覆盖则修改
        tr[p].sum+=(tr[p].r-tr[p].l+1)*k;
        tr[p].add+=k;
        return ;

    }
    //不覆盖则裂开
    int m=(tr[p].l+tr[p].r)>>1;
    pushdown(p);//向下更新
    if(x<=m){
        update1(lc,x,y,k);

    }
    if(y>m){
        update1(rc,x,y,k);
    }
    pushup(p);//向上更新
}

int main(){
    int n;
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
    }
    build(1,1,n);//建树
    //之后就是关于一系列的查询与修改

}