线段树 trick 汇总

发布时间 2023-10-13 21:09:28作者: wangxuzhou

区间最大子段和

模板题(luogu.P4513)

思路

可以发现,求最大子段和的过程可以分解为许多状态,状态 \([l,r]\) 表示区间 \([l,r]\) 的各项参数,如最大子段和。每个状态 \([l,r]\) 可以由 \([l,\frac{l+r}{2}]\)\([\frac{l+r}{2}+1,r]\) 递推而来。

具体来说,用线段树维护动态区间最大子段和,需要维护几个参数:

  • \(sum(l,r)\) 表示区间 \([l,r]\) 的所有元素之和。

  • \(maxl(l,r)\) 表示区间 \([l,r]\)\(l\) 为左端点的最大子段和。

  • \(maxr(l,r)\) 表示区间 \([l,r]\)\(r\) 为右端点的最大子段和。

  • \(maxn(l,r)\) 表示区间 \([l,r]\) 的最大子段和。

状态转移:

\[sum(l,r)=sum(l,\frac{l+r}{2})+sum(\frac{l+r}{2}+1,r) \]

\[maxl(l,r)=\max\Big(maxl(l,\frac{l+r}{2}),sum(l,\frac{l+r}{2})+maxl(\frac{l+r}{2}+1,r)\Big) \]

\[maxr(l,r)=\max\Big(maxr(\frac{l+r}{2}+1,r),sum(\frac{l+r}{2}+1,r)+maxr(l,\frac{l+r}{2})\Big) \]

\[maxn(l,r)=\max\bigg(\max\Big(maxn(l,\frac{l+r}{2}),maxn(\frac{l+r}{2}+1,r)\Big),maxr(l,\frac{l+r}{2})+maxl(\frac{l+r}{2}+1,r)\bigg) \]

在线段树上,设节点 \(p\) 表示区间 \([l,r]\),则其左子节点表示区间 \([l,\frac{l+r}{2}]\),其右子节点表示区间 \([\frac{l+r}{2}+1,r]\),所以上面这些公式能够很方便的计算,具体详见下面代码中的 \(\operatorname{pushup}\) 函数。

代码

#include <bits/stdc++.h>
using namespace std;
int n, m, a[1000005];
struct node {
    long long l, r, sum, maxl, maxr, maxn;
} t[1000005 << 2];
void pushup(int p) {
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
    t[p].maxl = max(t[p << 1].maxl, t[p << 1].sum + t[p << 1 | 1].maxl);
    t[p].maxr = max(t[p << 1 | 1].maxr, t[p << 1 | 1].sum + t[p << 1].maxr);
    t[p].maxn = max(max(t[p << 1].maxn, t[p << 1 | 1].maxn), t[p << 1].maxr + t[p << 1 | 1].maxl);
}
void build(int p, int l, int r) {
    t[p].l = l, t[p].r = r;

    if (l == r) {
        t[p].sum = t[p].maxl = t[p].maxr = t[p].maxn = a[l];
        return;
    }

    int mid = l + r >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    pushup(p);
}
void add(int p, int x, int k) {
    if (t[p].l == t[p].r) {
        t[p].sum = t[p].maxl = t[p].maxr = t[p].maxn = k;
        return;
    }

    int mid = t[p].l + t[p].r >> 1;

    if (x <= mid)
        add(p << 1, x, k);
    else
        add(p << 1 | 1, x, k);

    pushup(p);
}
node query(int p, int l, int r) {
    if (l <= t[p].l && t[p].r <= r)
        return t[p];

    int mid = t[p].l + t[p].r >> 1;
    node ret;

    if (r <= mid)
        ret = query(p << 1, l, r);
    else if (l > mid)
        ret = query(p << 1 | 1, l, r);
    else {
        node ls = query(p << 1, l, r), rs = query(p << 1 | 1, l, r);
        ret.sum = ls.sum + rs.sum;
        ret.maxl = max(ls.maxl, ls.sum + rs.maxl);
        ret.maxr = max(rs.maxr, rs.sum + ls.maxr);
        ret.maxn = max(max(ls.maxn, rs.maxn), ls.maxr + rs.maxl);
    }

    return ret;
}
int main() {
    cin >> n >> m;

    for (int i = 1; i <= n; i++)
        cin >> a[i];

    build(1, 1, n);

    while (m--) {
        int op, x, y;
        cin >> op >> x >> y;

        if (op == 1) {
            if (x > y)
                swap(x, y);

            cout << query(1, x, y).maxn << '\n';
        } else
            add(1, x, y);
    }

    return 0;
}