树上动态DP状态设计及实现细节

发布时间 2023-11-12 12:54:43作者: mRXxy0o0

状态设计

由于具有更改操作,我们希望更改后会变的东西可以简单的通过线段树上单点修改来维护。

对于一般的常数层转移 DP,这一点较好处理。但是对于树上 DP,就需要结合重儿子进行设计另一个 \(g\) 数组,表示不含重儿子的 DP 值,就可以结合树剖快速计算。

这道,各点有不同代价,可覆盖子树所有叶节点,求覆盖子树所有叶节点的最小代价。

最先列的方程及转移矩阵:

\[f_u=\sum\min(f_v,a_v) \]

\[\begin{pmatrix} f_{son} & a_{son} \end{pmatrix} \begin{pmatrix} g_u & \infty \\ g_u & a_u-a_{son} \end{pmatrix} = \begin{pmatrix} f_{u} & a_{u} \end{pmatrix} \]

这是肯定不行的,因为改 \(a\) 值的时候会影响父亲,很麻烦。可以改变定义来规避这一点。

实现细节

  • 维护乘积的方向:横向量&逆向;列向量&正向

  • upd可以通过改g数组而不是传参

  • 本来不应使用叶子节点的转移矩阵,但是发现这与dp初值是一样的,可以减少码量

  • change的过程:
    算第一个g
    循环,当前upd,算下一个g
    最后一个upd

板子

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N=1e5+10,M=2,inf=1e9;
int n,m,a[N],h[N],ne[N<<1],e[N<<1],idx,f[N][2],g[N][2],ed[N];
int dep[N],fa[N],size[N],son[N],top[N],num[N],cnt,be[N];
inline void add(int u,int v){
	ne[++idx]=h[u],h[u]=idx,e[idx]=v;
}
inline void dfs(int u,int fath){
	dep[u]=dep[fath]+1,fa[u]=fath,size[u]=1;
	f[u][1]=g[u][1]=a[u];
	for(int i=h[u],v;i;i=ne[i]){
		v=e[i];
		if(v!=fath){
			dfs(v,u);
			size[u]+=size[v];
			if(size[v]>size[son[u]]) son[u]=v,ed[u]=ed[v];
		}
	}
	for(int i=h[u],v;i;i=ne[i]){
		v=e[i];
		if(v!=fath){
			f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
			if(v!=son[u]) g[u][0]+=max(f[v][0],f[v][1]),g[u][1]+=f[v][0];
		}
	}
	if(!son[u]) ed[u]=u;
}
inline void dfs1(int u,int gd){
	num[u]=++cnt,be[cnt]=u,top[u]=gd;
	if(son[u]) dfs1(son[u],gd);
	for(int i=h[u],v;i;i=ne[i]){
		v=e[i];
		if(v!=fa[u]&&v!=son[u]) dfs1(v,v);
	}
}
struct mar{
	int d[M][M];
};
inline mar operator*(mar a,mar b){
	mar c;
	c.d[0][0]=max(a.d[0][0]+b.d[0][0],a.d[0][1]+b.d[1][0]);
	c.d[0][1]=max(a.d[0][0]+b.d[0][1],a.d[0][1]+b.d[1][1]);
	c.d[1][0]=max(a.d[1][0]+b.d[0][0],a.d[1][1]+b.d[1][0]);
	c.d[1][1]=max(a.d[1][0]+b.d[0][1],a.d[1][1]+b.d[1][1]);
	return c;
}
struct node{
	int l,r;
	mar s;
}tr[N<<2];
inline void pushup(int u){
	tr[u].s=tr[u<<1|1].s*tr[u<<1].s;
}
inline void build(int u,int l,int r){
	tr[u]=/**/{l,r};
	if(l==r){
		tr[u].s=/**/{{{g[be[l]][0],g[be[l]][1]},{g[be[l]][0],-inf}}};
		return ;
	}
	int mid=l+r>>1;
	build(u<<1,l,mid),build(u<<1|1,mid+1,r);
	pushup(u);
}
inline void mdf(int u,int x){
	if(tr[u].l==tr[u].r){
		tr[u].s=/**/{{{g[be[x]][0],g[be[x]][1]},{g[be[x]][0],-inf}}};
		return ;
	}
	int mid=tr[u].l+tr[u].r>>1;
	if(x<=mid) mdf(u<<1,x);
	else mdf(u<<1|1,x);
	pushup(u);
}
inline mar query(int u,int l,int r){
	if(l<=tr[u].l&&tr[u].r<=r) return tr[u].s;
	int mid=tr[u].l+tr[u].r>>1;
	if(r<=mid) return query(u<<1,l,r);
	if(l>mid) return query(u<<1|1,l,r);
	return query(u<<1|1,l,r)*query(u<<1,l,r);
}
inline void change(int u,int x){
	g[u][1]+=x-a[u];
	mar la,nw;
	while(top[u]!=1){
		la=query(1,num[top[u]],num[ed[u]]);
		mdf(1,num[u]);
		nw=query(1,num[top[u]],num[ed[u]]);
		u=fa[top[u]];
		g[u][0]+=max(nw.d[0][0],nw.d[0][1])-max(la.d[0][0],la.d[0][1]);
		g[u][1]+=nw.d[0][0]-la.d[0][0];
	}
	mdf(1,num[u]);
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i) scanf("%d",&a[i]);
	for(int i=1,x,y;i<n;++i){
		scanf("%d%d",&x,&y);
		add(x,y),add(y,x);
	}
	dfs(1,0);
	dfs1(1,1);
	build(1,1,n);
	while(m--){
		int x,y;
		scanf("%d%d",&x,&y);
		change(x,y);
		a[x]=y;
		mar res=query(1,1,num[ed[1]]);
		printf("%d\n",max(res.d[0][0],res.d[0][1]));
	}
	return 0;
}