【线段树优化 dp】AT_dp_w Intervals 题解

发布时间 2023-11-21 08:21:34作者: Pengzt

AT_dp_w

先不看数据范围,考虑 dp。

\(f_i\) 表示前 \(i\) 个字符且强制第 \(i\) 个字符为 \(1\) 的最大分数。

\(f_i = \max(f_{j - 1} +\sum\limits_{r_k\ge i\ge l_k\ge j}a_k)\)

这个是一份 \(\mathcal{O}(n^2m)\) 暴力 dp 的代码片段:

ll ans = 0;
for(int i = 1; i <= n; ++i) {
	f[i] = -inf;
	for(int j = 1; j <= i; ++j) {
		ll val = 0;
		for(int k = 1; k <= m; ++k) if(a[k] >= j && a[k] <= i && b[k] >= i) val += c[k];
		f[i] = max(f[i], f[j - 1] + val);
	}
	ans = max(ans, f[i]);
}

考虑怎么去优化这个 dp。

发现这个 dp 式子每个位置 \(j\) 的贡献都是有一个 \(f_{j - 1}\) 的,然后后面带有 \(\sum\) 的式子在每次 \(i\) 变化后的变动次数是很小的,总次数就是区间总数 \(m\)。考虑将 \(f+\sum\) 插入线段树中,则每次只需要区间加和全局取 \(\max\),最后就只需要查询根节点的值了。

具体地,将所有区间按右端点排序,然后用一个指针在上面扫,每次对 \(r_j = i\) 的区间进行更新即可。时间复杂度 \(\mathcal{O}((n + m)\log n)\)

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 2e5 + 10;
constexpr ll inf = 1e18;
struct node {
	int l, r, v;
};
int n, m;
ll f[N];
node a[N];
struct segt {
	ll val[N << 2], laz[N << 2];
	void tag(int u, ll v) {
		laz[u] += v;
		val[u] += v;
	}
	void down(int u) {
		if(laz[u]) {
			tag(u << 1, laz[u]);
			tag(u << 1 | 1, laz[u]);
			laz[u] = 0;
		}
	}
	void modify(int u, int L, int R, int l, int r, ll v) {
		if(l <= L && R <= r) return tag(u, v);
		down(u);
		int m = (L + R) >> 1;
		if(l <= m) modify(u << 1, L, m, l, r, v);
		if(r > m) modify(u << 1 | 1, m + 1, R, l, r, v);
		val[u] = max(val[u << 1], val[u << 1 | 1]);
	}
} t;
bool Med;
int main() {
	fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
//	freopen("data.in", "r", stdin);
//	freopen("data.out", "w", stdout);
	ios :: sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i <= m; ++i) cin >> a[i].l >> a[i].r >> a[i].v;
	sort(a + 1, a + m + 1, [](node a, node b){
		return a.r < b.r;
	});
	for(int i = 1, j = 1; i <= n; ++i) {
		t.modify(1, 1, n, i, i, max(t.val[1], 0ll));
		for(; j <= m && a[j].r == i; ++j)
			t.modify(1, 1, n, a[j].l, a[j].r, a[j].v);
	}
	cout << max(0ll, t.val[1]) << "\n";
	cerr << TIME << "ms\n";
	return 0;
}