[QOJ4815] Flower's Land

发布时间 2023-10-09 11:56:14作者: 灰鲭鲨

简要题意:给出一个 \(n\) 个点的树,对某个点 \(i\) 求包含某一个点的大小为 \(k\) 的权值最大的连通块,一个连通块的权值是其所有点的权值之和。
\(n\le 40000,k\le \min(3000,n)\)

这个树上背包很难直接解决,考虑一种变体的树形背包:点分治。

点分治后,设分治中心为 \(rt\),那么如果要选上一个点 \(x\)\(rt\)\(x\) 这条链上的所有点都要选上。

跑出点分治连通块的一个 dfs 序,设 dfs 序为 \(x\) 的点为 \(dfn_x\),那么一个点不选,他的子树都不能选,直接跳到 \(x+sz_x\),否则就跳到 \(x+1\).

要求包括某个点,可以对 dfs 序从前往后跑一次,从后往前跑一次,现在只要在 \(x\) 这个地方,枚举 \(k\),把前面的和后面的背包合并取个最大值就可以了。

复杂度 \(nklog n\),不好过。

但是当连通块大小小于 \(k\) 时,就可以不继续点分治了,复杂度 \(O(nklog\frac nk)\)

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=40005,K=3005,INF=1e9;
int n,k,hd[N],mn,f[N][K],g[N][K],a[N],sz[N],dfn[N],ans[N],rt,res,vs[N],e_num,idx,ret;
struct edge{
	int v,nxt;
}e[N<<1];
void add_edge(int u,int v)
{
	e[++e_num]=(edge){v,hd[u]};
	hd[u]=e_num;
}
void dfs(int x,int y)
{
	sz[x]=1,dfn[++idx]=x;;
	for(int i=hd[x];i;i=e[i].nxt)
		if(e[i].v^y&&!vs[e[i].v])
			dfs(e[i].v,x),sz[x]+=sz[e[i].v];
}
void getsz(int x,int y,int n)
{
	int ret=0;
	sz[x]=1;
	for(int i=hd[x];i;i=e[i].nxt)
		if(e[i].v^y&&!vs[e[i].v])
			getsz(e[i].v,x,n),ret=max(ret,sz[e[i].v]),sz[x]+=sz[e[i].v];
	ret=max(ret,n-sz[x]);
	if(ret<mn)
		mn=ret,rt=x;
}
int findrt(int x,int n)
{
	mn=INF;
	getsz(x,0,n);
	return rt;
}
void solve(int x)
{
	dfs(x,idx=0);
	res+=sz[x]*k;
	for(int i=1;i<=sz[x]+1;i++)
		memset(f[i],-0x3f,sizeof(f[i])),memset(g[i],-0x3f,sizeof(g[i]));
	for(int i=0;i<=k;i++)
		f[1][i]=g[sz[x]+1][i]=0;
	for(int i=1;i<=sz[x];i++)
	{
		for(int j=0;j<k;j++)
			f[i+1][j+1]=max(f[i+1][j+1],f[i][j]+a[dfn[i]]);
		if(i^1)
			for(int j=0;j<=k;j++)
				f[i+sz[dfn[i]]][j]=max(f[i][j],f[i+sz[dfn[i]]][j]);
	}
	for(int i=sz[x];i;i--)
	{
		for(int j=1;j<=k;j++)
			g[i][j]=g[i+1][j-1]+a[dfn[i]];
		for(int j=0;j<=k;j++)
			ans[dfn[i]]=max(ans[dfn[i]],f[i][j]+g[i][k-j]);
		if(i^1)
		{
			for(int j=0;j<=k;j++)
				g[i][j]=max(g[i][j],g[i+sz[dfn[i]]][j]);
		}
	}
	vs[x]=1;
	for(int i=hd[x];i;i=e[i].nxt)
		if(sz[e[i].v]>=k&&!vs[e[i].v])
			solve(findrt(e[i].v,sz[e[i].v]));
}
int main()
{
	scanf("%d%d",&n,&k);
	for(int i=1;i<=n;i++)
		scanf("%d",a+i);
	for(int u,v,i=1;i<n;i++)
		scanf("%d%d",&u,&v),add_edge(u,v),add_edge(v,u);
	solve(findrt(1,n));
	for(int i=1;i<=n;i++)
		printf("%d ",ans[i]);
}