区间最大子段和
思路
可以发现,求最大子段和的过程可以分解为许多状态,状态 \([l,r]\) 表示区间 \([l,r]\) 的各项参数,如最大子段和。每个状态 \([l,r]\) 可以由 \([l,\frac{l+r}{2}]\) 和 \([\frac{l+r}{2}+1,r]\) 递推而来。
具体来说,用线段树维护动态区间最大子段和,需要维护几个参数:
-
\(sum(l,r)\) 表示区间 \([l,r]\) 的所有元素之和。
-
\(maxl(l,r)\) 表示区间 \([l,r]\) 以 \(l\) 为左端点的最大子段和。
-
\(maxr(l,r)\) 表示区间 \([l,r]\) 以 \(r\) 为右端点的最大子段和。
-
\(maxn(l,r)\) 表示区间 \([l,r]\) 的最大子段和。
状态转移:
\[sum(l,r)=sum(l,\frac{l+r}{2})+sum(\frac{l+r}{2}+1,r)
\]
\[maxl(l,r)=\max\Big(maxl(l,\frac{l+r}{2}),sum(l,\frac{l+r}{2})+maxl(\frac{l+r}{2}+1,r)\Big)
\]
\[maxr(l,r)=\max\Big(maxr(\frac{l+r}{2}+1,r),sum(\frac{l+r}{2}+1,r)+maxr(l,\frac{l+r}{2})\Big)
\]
\[maxn(l,r)=\max\bigg(\max\Big(maxn(l,\frac{l+r}{2}),maxn(\frac{l+r}{2}+1,r)\Big),maxr(l,\frac{l+r}{2})+maxl(\frac{l+r}{2}+1,r)\bigg)
\]
在线段树上,设节点 \(p\) 表示区间 \([l,r]\),则其左子节点表示区间 \([l,\frac{l+r}{2}]\),其右子节点表示区间 \([\frac{l+r}{2}+1,r]\),所以上面这些公式能够很方便的计算,具体详见下面代码中的 \(\operatorname{pushup}\) 函数。
代码
#include <bits/stdc++.h>
using namespace std;
int n, m, a[1000005];
struct node {
long long l, r, sum, maxl, maxr, maxn;
} t[1000005 << 2];
void pushup(int p) {
t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum;
t[p].maxl = max(t[p << 1].maxl, t[p << 1].sum + t[p << 1 | 1].maxl);
t[p].maxr = max(t[p << 1 | 1].maxr, t[p << 1 | 1].sum + t[p << 1].maxr);
t[p].maxn = max(max(t[p << 1].maxn, t[p << 1 | 1].maxn), t[p << 1].maxr + t[p << 1 | 1].maxl);
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if (l == r) {
t[p].sum = t[p].maxl = t[p].maxr = t[p].maxn = a[l];
return;
}
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}
void add(int p, int x, int k) {
if (t[p].l == t[p].r) {
t[p].sum = t[p].maxl = t[p].maxr = t[p].maxn = k;
return;
}
int mid = t[p].l + t[p].r >> 1;
if (x <= mid)
add(p << 1, x, k);
else
add(p << 1 | 1, x, k);
pushup(p);
}
node query(int p, int l, int r) {
if (l <= t[p].l && t[p].r <= r)
return t[p];
int mid = t[p].l + t[p].r >> 1;
node ret;
if (r <= mid)
ret = query(p << 1, l, r);
else if (l > mid)
ret = query(p << 1 | 1, l, r);
else {
node ls = query(p << 1, l, r), rs = query(p << 1 | 1, l, r);
ret.sum = ls.sum + rs.sum;
ret.maxl = max(ls.maxl, ls.sum + rs.maxl);
ret.maxr = max(rs.maxr, rs.sum + ls.maxr);
ret.maxn = max(max(ls.maxn, rs.maxn), ls.maxr + rs.maxl);
}
return ret;
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while (m--) {
int op, x, y;
cin >> op >> x >> y;
if (op == 1) {
if (x > y)
swap(x, y);
cout << query(1, x, y).maxn << '\n';
} else
add(1, x, y);
}
return 0;
}