【题解】Tree MST

发布时间 2023-04-15 08:44:15作者: T-water

题面

给定一棵 \(n\) 个节点的树,现有有一张完全图,两点 \(x,y\) 之间的边长为 \(w_x+w_y+dis_{x,y}\),其中 \(dis\) 表示树上两点的距离。

求完全图的最小生成树。

\(n \leq 2 \times 10^5\)

题解

???神秘
借鉴支配对的思想,点分治后将树中点权替换为\(dep_i+w_i\),选择点权最小的一个和其他点连边,总共\(n\log n\)条边,跑最小生成树,总复杂度\(n \log^2 n\)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int rd(){
	int f=1,j=0;
	char w=getchar();
	while(!isdigit(w)){
		if(w=='-')f=-1;
		w=getchar();
	}
	while(isdigit(w)){
		j=j*10+w-'0';
		w=getchar();
	}
	return f*j;
}
const int N=200010;
int head[N],to[N*2],fro[N*2],len[N*2],tail;
int n,w[N],wk[N],dis[N],siz[N],son[N];
inline void adlin(int x,int y,int z){
	to[++tail]=y,fro[tail]=head[x],head[x]=tail,len[tail]=z;
	return ;
}
int Siz,Rt,Sum;
struct Lin{
	int u,v,len;
}link[N*40];
int cnt;
int lin[N],tot,p[N];
void getrt(int u,int fa){
	siz[u]=1;
	for(int k=head[u];k;k=fro[k]){
		int x=to[k];
		if(x==fa||wk[x])continue;
		getrt(x,u);
		siz[u]+=siz[x];
		son[u]=(siz[x]>siz[son[u]])?x:son[u];
	}
	int ma=max(siz[son[u]],Siz-siz[u]);
	if(ma<Sum)Sum=ma,Rt=u;
	return ;
}
void dfs(int u,int fa,int dep){
	siz[u]=1;
	p[u]=w[u]+dep;
	lin[++tot]=u;
	for(int k=head[u];k;k=fro[k]){
		int x=to[k];
		if(x==fa||wk[x])continue;
		dfs(x,u,dep+len[k]),siz[u]+=siz[x];
	}
	return ;
}
void init(int u,int v){
	Sum=1e18,Siz=v;
	getrt(u,0);
	int rt=Rt;
	tot=0,dfs(rt,0,0);
	sort(lin+1,lin+1+tot,[&](int a,int b){return p[a]<p[b];});
	for(int i=2;i<=tot;i++)link[++cnt]=(Lin){lin[1],lin[i],p[lin[1]]+p[lin[i]]};
	wk[rt]=true;
	for(int k=head[rt];k;k=fro[k]){
		int x=to[k];
		if(wk[x])continue;
		init(x,siz[x]);
	}
	return ;
}
int bel[N];
int getfa(int x){
	return (bel[x]==x)?x:bel[x]=getfa(bel[x]);
}
long long ans;
signed main(){
	n=rd();
	for(int i=1;i<=n;i++)w[i]=rd();
	for(int i=1;i<n;i++){
		int x=rd(),y=rd(),z=rd();
		adlin(x,y,z),adlin(y,x,z);
	}
	init(1,n);
	for(int i=1;i<=n;i++)bel[i]=i;
	sort(link+1,link+1+cnt,[&](Lin a,Lin b){return a.len<b.len;});
	for(int i=1;i<=cnt;i++){
		int u=link[i].u,v=link[i].v,k=link[i].len;
//		cout<<u<<"-"<<v<<":"<<w<<"\n";
		u=getfa(u),v=getfa(v);
		if(u!=v)bel[u]=v,ans+=k;
	}
	printf("%lld\n",ans);
	return 0;
}