[WC2018] 通道题解

发布时间 2023-12-27 10:38:01作者: hubingshan

先考虑只有两颗树要咋做,柿子先变成 \(dep_x+dep_y-2\times dep_{lca}+dist_2(x,y)\)
我们可以新建节点 \(x'\rightarrow x\),边权为 \(dep_x\),这样上面的式子可以看作枚举 \(lca\) 后,选出一个端点在不同子树中的直径,可以直接 \(DP\) 合并直径

再考虑多了一颗树,我们可以进行边分治,枚举断边之后,柿子变成 \(d_x+d_y+val+dep_{2x}+dep_{2y}-2\times dep_{2lca}+dist_{3}(x,y)\)
同理有了这个式子之后我们可以在虚树上 \(DP\) 合并直径,端点必须为不同子树的且颜色不同的点

code

#include<bits/stdc++.h>
using namespace std;
#define N 400005
#define int long long
int top,n,ans=-1e18;
int q[N],val[N],col[N];
namespace TREE{
	int k,tot;
	int h[N],id[N],lc[N][20],dep[N],lg[N],deep[N];
	vector<int> G[N];
	struct AB{
		int a,b,c,n;
	}d[N*2];
	void cun(int x,int y,int z){
		d[++k]={x,y,z,h[x]},h[x]=k;
	}
	void dfs(int x,int fa){
		lc[++tot][0]=x,id[x]=tot,deep[x]=deep[fa]+1;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa) continue;
			dep[y]=dep[x]+d[i].c;
			dfs(y,x);lc[++tot][0]=x;
		}
	}
	int mn(int x,int y){
		return deep[x]<deep[y]?x:y;
	}
	int Lca(int x,int y){
		if(id[x]>id[y]) swap(x,y);
		int p=lg[id[y]-id[x]+1];
		return mn(lc[id[x]][p],lc[id[y]-(1<<p)+1][p]);
	}
	int dis(int x,int y){
		if(!x||!y) return -1e18;
		int lca=Lca(x,y);
		return dep[x]+dep[y]-2*dep[lca];
	}
	void init(){
		for(int i=1,x,y,z;i<n;i++){
			scanf("%lld%lld%lld",&x,&y,&z);
			cun(x,y,z),cun(y,x,z);
		}
		dfs(1,0);
		for(int i=2;i<=tot;i++) lg[i]=lg[i/2]+1;
		for(int j=1;j<=19;j++){
			for(int i=1;i+(1<<j)-1<=tot;i++) lc[i][j]=mn(lc[i][j-1],lc[i+(1<<(j-1))][j-1]);
		}
	}
}
struct ZJ{
	int a,b,len;
};
int merge(ZJ x,ZJ y){
	int ans=-1e18;
	ans=max(ans,val[x.a]+val[y.a]+TREE::dis(x.a,y.a));
	ans=max(ans,val[x.a]+val[y.b]+TREE::dis(x.a,y.b));
	ans=max(ans,val[x.b]+val[y.a]+TREE::dis(x.b,y.a));
	ans=max(ans,val[x.b]+val[y.b]+TREE::dis(x.b,y.b));
	return ans;
}
ZJ mk_zj(int x,int y){
	ZJ z={x,y,val[x]+val[y]+TREE::dis(x,y)};
	return z;
}
ZJ mx(ZJ x,ZJ y){
	return x.len>y.len?x:y;
}
ZJ bing(ZJ x,ZJ y){
	ZJ z=mx(x,y);
	z=mx(z,mk_zj(x.a,y.a));z=mx(z,mk_zj(x.a,y.b));z=mx(z,mk_zj(x.b,y.a));z=mx(z,mk_zj(x.b,y.b));
	return z;
}
namespace XS{
	int k,res,t,tot;
	int h[N],sta[N],id[N],lc[N][20],dep[N],lg[N],vis[N],deep[N];
	ZJ f[N][2];
	vector<int> G[N];
	struct AB{
		int a,b,c,n;
	}d[N*2];
	void cun(int x,int y,int z){
		d[++k]={x,y,z,h[x]},h[x]=k;
	}
	void dfs(int x,int fa){
		lc[++tot][0]=x,id[x]=tot,deep[x]=deep[fa]+1;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa) continue;
			dep[y]=dep[x]+d[i].c;
			dfs(y,x);lc[++tot][0]=x;
		}
	}
	int mn(int x,int y){
		return deep[x]<deep[y]?x:y;
	}
	void init(){
		for(int i=1,x,y,z;i<n;i++){
			scanf("%lld%lld%lld",&x,&y,&z);
			cun(x,y,z),cun(y,x,z);
		}
		dfs(1,0);
		for(int i=2;i<=tot;i++) lg[i]=lg[i/2]+1;
		for(int j=1;j<=19;j++){
			for(int i=1;i+(1<<j)-1<=tot;i++) lc[i][j]=mn(lc[i][j-1],lc[i+(1<<(j-1))][j-1]);
		}
	}
	int Lca(int x,int y){
		if(id[x]>id[y]) swap(x,y);
		int p=lg[id[y]-id[x]+1];
		return mn(lc[id[x]][p],lc[id[y]-(1<<p)+1][p]);
	}
	void ins(int x){
		if(!vis[x]) G[x].clear(),vis[x]=1;
		if(!t){
			sta[++t]=x;
			return;
		}
		int lca=Lca(x,sta[t]);
		if(!vis[lca]) G[lca].clear(),col[lca]=-1,vis[lca]=1;
		if(lca!=sta[t]){
			while(t>1&&id[lca]<id[sta[t-1]]) G[sta[t-1]].push_back(sta[t]),t--;
			if(id[lca]!=id[sta[t-1]]) G[lca].push_back(sta[t]),sta[t]=lca;
			else G[lca].push_back(sta[t]),t--;
		}
		sta[++t]=x;
	}
	void DP(int x){
		f[x][0]=f[x][1]={0,0,(int)-1e18};
		if(col[x]!=-1) f[x][col[x]]={x,x,val[x]};
		for(auto y:G[x]){
			DP(y);
			res=max(res,max(merge(f[x][0],f[y][1]),merge(f[x][1],f[y][0]))-2*dep[x]);
			f[x][0]=bing(f[x][0],f[y][0]),f[x][1]=bing(f[x][1],f[y][1]);
		}
		vis[x]=0;
	}
	bool cmp(int x,int y){
		return id[x]<id[y];
	}
	int work(){
		res=-1e18,t=0;
		sort(q+1,q+1+top,cmp);
		if(q[1]!=1) G[1].clear(),col[1]=-1,vis[1]=1,sta[++t]=1;
		for(int i=1;i<=top;i++) ins(q[i]);
		while(t>1) G[sta[t-1]].push_back(sta[t]),t--;
		DP(1);
		return res;
	}
}
namespace BFZ{
	int k=1,ssz,rt,tot;
	int h[N],dep[N],sz[N],vis[N];
	vector<pair<int,int> > G[N];
	struct AB{
		int a,b,c,n;
	}d[N*2];
	void cun(int x,int y,int z){
		d[++k]={x,y,z,h[x]},h[x]=k;
	}
	void rebuild(int x,int fa){
		int tmp=0,lst=0;
		for(auto p:G[x]){
			int y=p.first,z=p.second;
			if(y==fa) continue;
			tmp++;
			if(tmp==1) cun(x,y,z),cun(y,x,z),lst=x;
			else if(tmp==(int)G[x].size()-(x!=1)) cun(lst,y,z),cun(y,lst,z);
			else tot++,cun(lst,tot,0),cun(tot,lst,0),cun(tot,y,z),cun(y,tot,z),lst=tot;
		}
		for(auto p:G[x]){
			int y=p.first;
			if(y==fa) continue;
			rebuild(y,x);
		}
	}
	void init(){
		for(int i=1,x,y,z;i<n;i++){
			scanf("%lld%lld%lld",&x,&y,&z);
			G[x].push_back({y,z}),G[y].push_back({x,z});
		}tot=n;
		rebuild(1,0);
	}
	void gt_rt(int x,int fa,int siz){
		sz[x]=1;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa||vis[i>>1]) continue;
			gt_rt(y,x,siz);sz[x]+=sz[y];
			int msz=max(siz-sz[y],sz[y]);
			if(msz<ssz) ssz=msz,rt=i;
		}
	}
	void dfs(int x,int fa,int s,int op){
		if(x<=n) val[x]=s+XS::dep[x],col[x]=op,q[++top]=x;
		for(int i=h[x];i;i=d[i].n){
			int y=d[i].b;
			if(y==fa||vis[i>>1]) continue;
			dfs(y,x,s+d[i].c,op);
		}
	}
	void solve(int x,int siz){
		ssz=1e9,gt_rt(x,0,siz);
		if(ssz==1e9) return ;
		int i=rt;top=0;vis[i>>1]=1;
		dfs(d[i].a,0,0,0),dfs(d[i].b,0,0,1);
		ans=max(ans,XS::work()+d[i].c);int sum=sz[d[i].b];
		solve(d[i].a,siz-sum),solve(d[i].b,sum);
	}
}
signed main(){
	scanf("%lld",&n);
	BFZ::init();
	XS::init();
	TREE::init();
	BFZ::solve(1,BFZ::tot);
	printf("%lld\n",ans); 
	return 0;
}