【学习笔记】(27) 整体 DP

发布时间 2023-09-19 21:49:31作者: Aurora-JC

1.算法简介

整体 DP 就是用线段树合并维护 DP。

有一些问题,通常见于二维的DP,有一维记录当前x的信息,但是这一维过大无法开下,O(nm) 也无法通过。

但是如果发现,对于 x,在第二维的一些区间内,取值都是相同的,并且这样的区间是有限个,就可以批量处理。

所以我们就可以用线段树来维护 DP。

对于序列的问题,可以直接扫过去,修改某些位置的点,或者线段树合并。

对于树上的问题,线段树合并。

2.例题

Ⅰ. P4577 [FJOI2018] 领导集团问题

\(f_{i,j}\) 表示以 \(i\) 为根的子树 \(min_w≥j\) 的答案,分两种情况讨论:

  1. \(f_{i,j} += f{son_i,j}\)

  2. \(f_{i,j} = max(f_{i,j},f_{i,w_i} + 1) (j \le w_i)\) 在合并完儿子后进行此操作。

对于转移 1,显然直接线段树合并即可,对于转移 2 ,发现其实是区间 +,且\(f_{i,j}\) 是单调递减的,所以我们只要维护区间最小值 \(mn\),然后找到那段区间即可,需要标记永久化(不用也行)。

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 2e5 + 67;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int n, cnt;
int a[N], b[N], rt[N];
int ls[N << 5], rs[N << 5], lz[N << 5], mn[N << 5];
vector<int> e[N];

void pushup(int x){
	mn[x] = min(mn[ls[x]], mn[rs[x]]) + lz[x];
}

void modify(int &x, int l, int r, int L, int R){
	if(!x) x = ++cnt;
	if(L <= l && r <= R) return ++lz[x], ++mn[x], void();
	int mid = (l + r) >> 1;
	if(L <= mid) modify(ls[x], l, mid, L, R);
	if(R > mid) modify(rs[x], mid + 1, r, L, R);
	pushup(x);
}
 
int query1(int x, int l, int r, int p){
	if(l == r && !x) return lz[x];
	int mid = (l + r) >> 1;
	if(p <= mid) return lz[x] + query1(ls[x], l, mid, p);
	else return lz[x] + query1(rs[x], mid + 1, r, p);
}

int query2(int x, int l, int r, int val){
	if(l == r) return l;
	int mid = (l + r) >> 1; val -= lz[x];
	if(mn[ls[x]] <= val) return query2(ls[x], l, mid, val);
	else return query2(rs[x], mid + 1, r, val);
}

int merge(int x, int y){
	if(!x || !y) return x + y;
	ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]);
	lz[x] += lz[y], pushup(x);
	return x;
}

void dfs(int x){
	for(auto y : e[x]) dfs(y), rt[x] = merge(rt[x], rt[y]);
	int tmp = query1(rt[x], 1, n, a[x]);
	modify(rt[x], 1, n, query2(rt[x], 1, n, tmp), a[x]);
} 

bool _v;

int main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
	
	n = read();
	for(int i = 1; i <= n; ++i) b[i] = a[i] = read();
	sort(b + 1, b + 1 + n);
	for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b;
	for(int i = 2, f; i <= n; ++i) f = read(), e[f].pb(i);
	dfs(1); printf("%d\n", query1(rt[1], 1, n, 1));
	return 0;
}

Ⅱ.CF490F Treeland Tour

\(f/g_{i,j}\) 表示以 \(j\) 为结尾的LIS 和 LDS。线段树维护即可。

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 6e3 + 67, LIM = 1e6;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int n, ans, cnt;
int a[N], rt[N];
int ls[N << 5], rs[N << 5], pre[N << 5], suf[N << 5];
vector<int> e[N];

void modify(int &x, int l, int r, int p, int v, int *val){
	if(!x) x = ++cnt; val[x] = max(val[x], v);
	if(l == r) return ;
	int mid = (l + r) >> 1;
	if(p <= mid) modify(ls[x], l, mid, p, v, val);
	else modify(rs[x], mid + 1, r, p, v, val);
}

int merge(int x, int y){
	if(!x || !y) return x + y;
	pre[x] = max(pre[x], pre[y]), suf[x] = max(suf[x], suf[y]);
	ans = max(ans, max(pre[ls[x]] + suf[rs[y]], suf[rs[x]] + pre[ls[y]]));
	ls[x] = merge(ls[x], ls[y]), rs[x] = merge(rs[x], rs[y]);
	return x;
}

int query(int x, int l, int r, int L, int R, int *val){
	if(!x || L > R) return 0;
	if(L <= l && r <= R) return val[x];
	int mid = (l + r) >> 1, ans = 0;
	if(L <= mid) ans = max(ans, query(ls[x], l, mid, L, R, val));
	if(R > mid) ans = max(ans, query(rs[x], mid + 1, r, L, R, val));
	return ans;
}

void dfs(int x, int fa){
	int ns = 0, np = 0;
	for(auto y : e[x]){
		if(y == fa) continue;
		dfs(y, x);
		int tp = query(rt[y], 1, LIM, 1, a[x] - 1, pre);
		int ts = query(rt[y], 1, LIM, a[x] + 1, LIM, suf);
		ans = max(ans, max(np + ts, ns + tp) + 1);
		ns = max(ns, ts), np = max(np, tp);
		rt[x] = merge(rt[x], rt[y]); 
	}
	modify(rt[x], 1, LIM, a[x], np + 1, pre);
	modify(rt[x], 1, LIM, a[x], ns + 1, suf);
}

bool _v;

int main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
	
	n = read();
	for(int i = 1; i <= n; ++i) a[i] = read();
	for(int i = 1; i < n; ++i){
		int u = read(), v = read();
		e[u].pb(v), e[v].pb(u);
	}
	dfs(1, 0);
	printf("%d\n", ans);
	return 0;
}

Ⅲ. P6773 [NOI2020] 命运

发现对于 若干点对 \((u,v)\),如果对于当前 \(x\) 节点满足 \(v\)\(x\) 的子树内,且 \(u\)\(x\) 的祖先,那么我们只要满足 \(u\) 节点最深的点对即可。

所以我们可以设 \(f[u][i]\) 表示 表示以 \(u\) 为根的子树内,下端点在子树内并且没有被满足的限制中上端点的最深深度为 \(i\) 的方案数。进行分类讨论。

  • \((u,v)\) 的权值 为 \(1\)\(f[u][i] = \sum\limits_{j = 0}^{dep[u]} f[u][i] \times f[v][j]\)

  • \((u,v)\) 的权值 为 \(0\)\(f[u][i] = \sum\limits_{j = 0}^{i} f[u][i] \times f[v][j] + \sum\limits_{j = 0}^{i - 1} f[u][j] \times f[v][i]\)

\(g[u][i] = \sum\limits_{j = 0}^{i} f[u][j]\)

那么转移式就可以写成

\[f[u][i] = f[u][i] \times (g[v][dep[u]] + g[v][i]) + g[u][i - 1] \times f[v][i] \]

线段树合并即可。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 5e5 + 67, mod = 998244353;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}

bool _u;

int n, m, nod;
int dep[N], ls[N << 5], rs[N << 5], rt[N];
ll laz[N << 5], sum[N << 5];

void add(ll &x, ll y){x += y; if(x >= mod) x -= mod;}
void mul(ll &x, ll y){x = x * y % mod;}

struct Edge{
	int tot, to[N << 1], nxt[N << 1], hd[N];
	void add(int u, int v){to[++tot] = v, nxt[tot] = hd[u], hd[u] = tot;}
}e, g;

void build(int &u, int l, int r, int p){
	u = ++nod; sum[u] = laz[u] = 1;
	if(l == r) return ;
	int mid = (l + r) >> 1;
	if(p <= mid) build(ls[u], l, mid, p);
	else build(rs[u], mid + 1, r, p); 
}

void pushup(int u){sum[u] = (sum[ls[u]] + sum[rs[u]]) % mod;}
void pushdown(int u){
	if(laz[u] != 1){
		mul(sum[ls[u]], laz[u]), mul(sum[rs[u]], laz[u]);
		mul(laz[ls[u]], laz[u]), mul(laz[rs[u]], laz[u]);
	} laz[u] = 1;
}

ll query(int u, int l, int r, int p){
	if(!u || r <= p) return sum[u];
	int mid = (l + r) >> 1; ll ans = 0; pushdown(u);
	if(p > mid) ans = query(rs[u], mid + 1, r, p);
	return add(ans, query(ls[u], l, mid, p)), ans;
}

int merge(int x, int y, int l, int r, ll &s1, ll &s2){ //s1 表示 g[v][dep[u]] + g[v][i], s2 表示 g[u][i - 1] 
	if(!x && !y) return 0;
	if(!y) return add(s2, sum[x]), mul(laz[x], s1), mul(sum[x], s1), x;
	if(!x) return add(s1, sum[y]), mul(laz[y], s2), mul(sum[y], s2), y;
	if(l == r){
		ll tmp = sum[x];
		add(s1, sum[y]), mul(sum[x], s1);
		add(sum[x], sum[y] * s2 % mod), add(s2, tmp);
		//注意先后顺序 
		return x;
	}
	pushdown(x), pushdown(y);
	int mid = (l + r) >> 1;
	ls[x] = merge(ls[x], ls[y], l, mid, s1, s2);
	rs[x] = merge(rs[x], rs[y], mid + 1, r, s1, s2);
	return pushup(x), x;
}

void dfs(int x, int fa){
	int mxd = 0; dep[x] = dep[fa] + 1;
	for(int i = g.hd[x]; i; i = g.nxt[i]) mxd = max(mxd, dep[g.to[i]]); //找到最深的点 
	build(rt[x], 0, n, mxd);
	for(int i = e.hd[x]; i; i = e.nxt[i]){
		int y = e.to[i]; if(y == fa) continue;
		dfs(y, x);
		ll zx = query(rt[y], 0, n, dep[x]), zxq = 0;
		rt[x] = merge(rt[x], rt[y], 0, n, zx, zxq);
	}
}

bool _v;

int main(){
	cerr << abs(&_u - &_v) / 1048576.0 << " MB\n";
	
	
//	freopen("destiny4.in", "r", stdin);
	n = read();
	for(int i = 1; i < n; ++i){
		int u = read(), v = read();
		e.add(u, v), e.add(v, u);
	}
	m = read();
	for(int i = 1; i <= m; ++i){
		int u = read(), v = read();
		g.add(v, u);
	}
	dfs(1, 0);
	printf("%lld\n", query(rt[1], 0, n, 0));
	return 0;
}

参考:https://www.cnblogs.com/alex-wei/p/DP_optimization_method_II.html、https://www.cnblogs.com/Miracevin/p/10942795.html