线段树

发布时间 2023-05-01 00:49:56作者: 邪童

线段树又称区间树, 是一种基于分治思想的二叉树结构, 每个节点代表一段区间


线段树的每个节点代表一个区间

对于每个内部节点 [l,r] , 它的左儿子是 [l,mid] , 右儿子是 [mid+1,r]

用一维数组存整棵树

\[对于编号为x的节点 \begin{cases} 父节点: [\dfrac{x}{2}] \quad\quad\quad\quad\, x\gg1 \\[1.5ex] 左儿子: 2x \quad\quad\quad\quad\,\,\,\, x\ll1\\[1.5ex] 右儿子: 2x+1 \quad\quad\quad x\ll1|1\\[1.5ex] \end{cases} \]

对于一个长度为 n 的区间, 需要建立大小为 4n 的数组维护

每个节点表示该节点表示区间的某种属性


//线段树操作模板(以维护区间最大值为例)

//结构体存储整棵线段树
struct Node
{
    int l,r;
    int v;	//区间[l,r]的最大值
}tr[N*4]; //空间大小开区间长度四倍

//建树操作
void build (int u,int l,int r)	//构建节点u,其维护的区间是[l,r]
{
    tr[u]={l,r};
    if(l==r)return;	//已经是叶子节点
    int mid=l+r>>1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);	//递归构建左右子区间
}

//push_up操作,用子节点信息来更新父节点信息(以维护区间最大值为例)
void push_up (int u)
{
    tr[u].v=max(tr[u<<1].v,tr[u<<1|1].v);
}

//query操作,用来查询某一段区间内的信息(以最大值为例)
int query (int u,int l,int r)	//从节点u开始查询,[l,r]表示需要查询的目标区间
{
    if(tr[u].l>=l&&tr[u].r<=r)return tr[u].v;
    //说明当前节点维护的区间已经被查询区间完全包含,不需要继续向下递归
    
    int res=-0x3f3f3f3f;
    int mid=tr[u].l+tr[u].r>>1;
    if(l<=mid)res=max(res,query(u<<1,l,r));	//递归左子区间
    if(r>mid)res=max(res,query(u<<1|1,l,r));	//递归右子区间
    return res;
}

//modify操作,用来修改某一叶子节点并更新其所有父节点
void modify (int u,int x,int v)	//从节点u开始递归查找,将第x个点的值修改为v
{
    if(tr[u].l==x&&tr[u].r==x)tr[u].v=v;
    else
    {
        int mid=tr[u].l+tr[u].r>>1;
        if(x<=mid)modify(u<<1,x,v);
        else modify(u<<1|1,x,v);
        push_up(u);
    }
}

懒标记

在进行区间修改时, 可以先需要修改的区间打上标记, 等查询到该区间时再将标记传递

借助懒标记, 可以节省大量的时间, 将区间修改的时间复杂度降至 O(logn)

//带懒标记的线段树模板(以区间修改,区间和查询为例)

struct Node
{
    int l,r;
    long long sum,add;	//add即为懒标记,储存子节点需要加上的数(不包括父节点)
}tr[N*4];

void pushup (int u)
{
    tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}

void pushdown (int u)	//向下传递懒标记
{
    auto &root=tr[u],&left=tr[u<<1],&right=tr[u<<1|1];
    if(root.add)
    {
        left.add+=root.add;
        left.sum+=(long long)(left.r-left.l+1)*root.add;
        right.add+=root.add;
        right.sum+=(long long)(right.r-right.l+1)*root.add;
        root.add=0;
    }
}

void build (int u,int l,int r)
{
    if(l==r)tr[u]={l,r,w[r],0};
    else
    {
        tr[u]={l,r};
        int mid=l+r>>1;
        build(u<<1,l,mid),build(u<<1|1,mid+1,r);
        pushup(u);
    }
}

void modify (int u,int l,int r,int d)	//给[l,r]内每个数加上d
{
    if(tr[u].l>=l&&tr[u].r<=r)
    {
        tr[u].sum+=(long long)(tr[u].r-tr[u].l+1)*d;
        tr[u].add+=d;
    }
    else	//左右子区间要分裂
    {
        pushdown(u);
        int mid=tr[u].l+tr[u].r>>1;
        if(l<=mid)modify(u<<1,l,r,d);
        if(r>mid)modify(u<<1|1,l,r,d);
        pushup(u);
    }
}

long long query (int u,int l,int r)	//查询区间[l,r]的和
{
    if(tr[u].l>=l&&tr[u].r<=r)return tr[u].sum;
    
    pushdown(u);
    int mid=tr[u].l+tr[u].r>>1;
    long long sum=0;
    if(l<=mid)sum=query(u<<1,l,r);
    if(r>mid)sum+=query(u<<1|1,l,r);
    return sum;
}