[CTSC2018]暴力写挂题解

发布时间 2023-12-26 22:13:20作者: hubingshan

我们先将柿子变成 \(\frac{1}{2}(dis_{x,y}+dep_{x}+dep_{y})-dep'_{lca'}\)

考虑边分治,枚举断边,我们将一个点在第二棵树上的点权看成是 \(v_x=d_x+dep_x\),答案就为 \(v_x+v_y+dep'_{lca'}\)

对于每次边分治将分治联通块内所有点在第二棵树上的建出虚树,同时将分治联通块以分治中心边为界限分成两部分,将一部分的点标为黑点,将另一部分的点标为白点。

那么对于虚树中的一个点,以它为的最大答案就是在它的两个不同子树中分别选出一个黑点和一个白点使这两个点的点权和最大。

我们在虚树上进行树形DP,每个点维护这个点在虚树上的子树中黑点的最大点权及白点的最大点权就行了。

code

#include<bits/stdc++.h>
using namespace std;
#define N 1000005
#define int long long
int n,top,ans=-1e18;
int q[N],val[N],col[N];
namespace XS{
	int k,res,t,tot;
	int h[N],sta[N],id[N],lc[N][20],dep[N],lg[N],vis[N],f[N][2],deep[N];
	vector<int> G[N];
	struct AB{
		int a,b,c,n;
	}d[N*2];
	void cun(int x,int y,int z){
		d[++k]={x,y,z,h[x]},h[x]=k;
	}
	void dfs(int x,int fa){
		lc[++tot][0]=x,id[x]=tot,deep[x]=deep[fa]+1;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa) continue;
			dep[y]=dep[x]+d[i].c;
			dfs(y,x),lc[++tot][0]=x;
		}
	}
	int mn(int x,int y){
		return deep[x]<deep[y]?x:y;
	}
	void init(){
		for(int i=1,x,y,z;i<n;i++){
			scanf("%lld%lld%lld",&x,&y,&z);
			cun(x,y,z),cun(y,x,z);
		}
		dfs(1,0);
		for(int i=2;i<=tot;i++) lg[i]=lg[i/2]+1;
		for(int j=1;j<=19;j++){
			for(int i=1;i+(1<<j)-1<=tot;i++) lc[i][j]=mn(lc[i][j-1],lc[i+(1<<(j-1))][j-1]);
		}
	}
	int Lca(int x,int y){
		if(id[x]>id[y]) swap(x,y);
		int p=lg[id[y]-id[x]+1];
		return mn(lc[id[x]][p],lc[id[y]-(1<<p)+1][p]);
	}
	void ins(int x){
		vis[x]=1,G[x].clear();
		if(!t){
			sta[++t]=x;
			return ;
		}
		int lca=Lca(x,sta[t]);
		if(!vis[lca]) col[lca]=-1,G[lca].clear(),vis[lca]=1;
		if(lca!=sta[t]){
			while(t>1&&id[lca]<id[sta[t-1]]) G[sta[t-1]].push_back(sta[t]),t--;
			if(id[lca]!=id[sta[t-1]]) G[lca].push_back(sta[t]),sta[t]=lca;
			else G[lca].push_back(sta[t]),t--;
		}
		sta[++t]=x;
	}
	void DP(int x){
		f[x][0]=f[x][1]=-1e18;
		for(auto y:G[x]){
			DP(y);
			res=max(res,max(f[x][0]+f[y][1],f[x][1]+f[y][0])-2*dep[x]);
			f[x][0]=max(f[x][0],f[y][0]),f[x][1]=max(f[x][1],f[y][1]);
		}
		if(col[x]!=-1) res=max(res,val[x]+f[x][col[x]^1]-2*dep[x]);
		f[x][0]=f[x][1]=-1e18;
		if(col[x]!=-1) f[x][col[x]]=val[x];
		for(int i=G[x].size()-1;~i;i--){
			int y=G[x][i];
			res=max(res,max(f[x][0]+f[y][1],f[x][1]+f[y][0])-2*dep[x]);
			f[x][0]=max(f[x][0],f[y][0]),f[x][1]=max(f[x][1],f[y][1]);
		}
		vis[x]=0;
	}
	bool cmp(int x,int y){
		return id[x]<id[y];
	}
	int work(){
		res=-1e18,t=0;
		sort(q+1,q+1+top,cmp);
		if(q[1]!=1) col[1]=-1,sta[++t]=1,G[1].clear();
		for(int i=1;i<=top;i++) ins(q[i]);
		while(t>1) G[sta[t-1]].push_back(sta[t]),t--;
		DP(1);
		return res;
	}
}
namespace BFZ{
	int k=1,ssz,rt,tot;
	int h[N],dep[N],sz[N],vis[N];
	vector<pair<int,int> > G[N];
	struct AB{
		int a,b,c,n;
	}d[N*2];
	void cun(int x,int y,int z){
		d[++k]={x,y,z,h[x]},h[x]=k;
	}
	void rebuild(int x,int fa){
		int tmp=0,lst=0;
		for(auto p:G[x]){
			int y=p.first,z=p.second;
			if(y==fa) continue;
			tmp++;
			if(tmp==1) cun(x,y,z),cun(y,x,z),lst=x;
			else if(tmp==(int)G[x].size()-(x!=1)) cun(lst,y,z),cun(y,lst,z);
			else tot++,cun(lst,tot,0),cun(tot,lst,0),lst=tot,cun(tot,y,z),cun(y,tot,z);
		}
		for(auto p:G[x]){
			if(p.first==fa) continue;
			rebuild(p.first,x);
		}
	}
	void dfs(int x,int fa){
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b,z=d[i].c;
			if(y==fa) continue;
			dep[y]=dep[x]+z,dfs(y,x);
		}
	}
	void init(){
		tot=n;
		for(int i=1,x,y,z;i<n;i++){
			scanf("%lld%lld%lld",&x,&y,&z);
			G[x].push_back({y,z}),G[y].push_back({x,z});
		}
		rebuild(1,0);
		dfs(1,0);
	}
	void gt_rt(int x,int fa,int siz){
		sz[x]=1;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa||vis[i>>1]) continue;
			gt_rt(y,x,siz),sz[x]+=sz[y];
			int mx_sz=max(siz-sz[y],sz[y]);
			if(mx_sz<ssz) ssz=mx_sz,rt=i;
		}
	}
	void dfs2(int x,int fa,int s,int op){
		if(x<=n) val[x]=s+dep[x],col[x]=op,q[++top]=x;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa||vis[i>>1]) continue;
			dfs2(y,x,s+d[i].c,op);
		}
	}
	void solve(int x,int siz){
		ssz=1e9,gt_rt(x,0,siz);
		if(ssz==1e9) return ;
		int i=rt;
		vis[i>>1]=1,top=0;
		dfs2(d[i].a,0,0,0);
		dfs2(d[i].b,0,0,1);
		ans=max(ans,d[i].c+XS::work());
		int sum=sz[d[i].b];
		solve(d[i].a,siz-sum),solve(d[i].b,sum);
	}
}
signed main(){
	scanf("%lld",&n);
	BFZ::init();
	XS::init();
	BFZ::solve(1,BFZ::tot);
	assert(ans%2==0);
	ans/=2;
	for(int i=1;i<=n;i++) ans=max(ans,BFZ::dep[i]-XS::dep[i]);
	printf("%lld\n",ans);
	return 0;
}