AT_cf17_final_j Tree MST 题解

发布时间 2023-12-10 15:57:41作者: Creeper_l

题意:给定一颗 \(n\) 个点的树,点 \(i\) 有权值 \(a_{i}\),边有边权。现在有另外一个完全图,两点之间的边权为树上两点之间的距离加上树上两点的点权,求这张完全图的最小生成树。

首先有一个很显然的暴力,把完全图中每两点之间的边权算出来,然后跑一边最小生成树,时间复杂度 \(O(n^{2} \log (n^{2}))\)

考虑如何优化。发现有很多路径是不必要的,因为它们一定劣于其它路径,这些路径我们就不用加到完全图中去了。那么可以用点分治来筛选路径。

假设当前重心为 \(u\),我们可以把路径分为两种:

  • 一个端点是 \(u\) 的路径。

  • 经过 \(u\) 但是端点不在 \(u\) 的路径。

对于第一种路径,我们可以直接将它加入边集,因为总边数不超过 \(O(n \log n)\) 条。对于第二种,考虑如何选出最优的。假设两个点为 \(x\)\(y\),那么可以把边权分为两个部分:\(x \to u\)\(u \to y\),即 \((a_{x}+dis(u,x))+(a_{y}+dis(u,y))\)。发现这个式子的前一半和后一半的形式是一样的,所以要让边权最小,只需要选一个 \((a_{x}+dis(u,x))\) 最小的 \(x\) 点,再连向其它所有的 \(y\) 点即可。

总时间复杂度 \(O(n \log^{2} n)\)

本题的 Trick:求最小生成树遇到边数很多时,可以先把边权小的边拿出来,删除一些没用的边,然后再做最小生成树。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf 0x3f
#define inf_db 127
#define ls id << 1
#define rs id << 1 | 1
#define re register
#define endl '\n'
typedef pair <int,int> pii;
const int MAXN = 8e6 + 10;
int n,a[MAXN],head[MAXN],x,y,z,fa[MAXN],mn,id;
int pt,tot = 0,root,cnt,ans = 0;
bool vis[MAXN];
struct Node{int u,v,w,nxt;}e[MAXN << 1];
struct Edge{int u,v,w;}E[MAXN << 1];
struct F{int u,dis,idx;}p[MAXN << 1];
inline void Add(int u,int v,int w){e[++cnt] = {u,v,w,head[u]};head[u] = cnt;} 
inline int Get_size(int u,int father)
{
	if(vis[u] == true) return 0;
	int sum = 1;
	for(int i = head[u]; ~ i;i = e[i].nxt)
		if(e[i].v != father) sum += Get_size(e[i].v,u);
	return sum;
}
inline int Get_wc(int u,int father,int tot,int &wc)
{
	if(vis[u] == true) return 0;
	int sum = 1,mx = 0;
	for(int i = head[u]; ~ i;i = e[i].nxt)
	{
		int now = e[i].v;
		if(now == father) continue;
		int tmp = Get_wc(now,u,tot,wc);
		sum += tmp,mx = max(mx,tmp); 
	}
	if(max(mx,tot - sum) <= tot / 2) wc = u;
	return sum;
} 
inline void dfs(int u,int father,int dist,int r)
{
	if(vis[u] == true) return;
	int val = (a[u] + dist);
	p[++pt] = F{u,dist,r};
	if(val < mn) mn = val,id = u,root = r;
	for(int i = head[u]; ~ i;i = e[i].nxt)
	{
		int now = e[i].v;
		if(now == father) continue;
		dfs(now,u,dist + e[i].w,r);
	}
}
inline void solve(int u)
{
	if(vis[u] == true) return;
	Get_wc(u,0,Get_size(u,0),u),vis[u] = true;
	pt = 0,mn = 1e18,root = 0;
	for(int i = head[u]; ~ i;i = e[i].nxt)
	{
		int now = e[i].v;
		dfs(now,u,e[i].w,now);
	}
	for(int i = 1;i <= pt;i++)
		if(p[i].idx != root) E[++tot] = {id,p[i].u,mn + p[i].dis + a[p[i].u]};
	for(int i = 1;i <= pt;i++) E[++tot] = {u,p[i].u,p[i].dis + a[u] + a[p[i].u]};
	for(int i = head[u]; ~ i;i = e[i].nxt) solve(e[i].v); 
}
inline bool cmp(Edge x,Edge y){return x.w < y.w;}
inline int Find(int x)
{
	if(x == fa[x]) return x;
	return fa[x] = Find(fa[x]);
}
signed main()
{
	memset(head,-1,sizeof head);
	cin >> n;
	for(int i = 1;i <= n;i++) scanf("%lld",&a[i]);
	for(int i = 1;i < n;i++) scanf("%lld%lld%lld",&x,&y,&z),Add(x,y,z),Add(y,x,z);
	solve(1);
	sort(E + 1,E + tot + 1,cmp);
	for(int i = 1;i <= n;i++) fa[i] = i;
	for(int i = 1;i <= tot;i++)
		if(Find(E[i].u) != Find(E[i].v))
		{
			fa[Find(E[i].u)] = Find(E[i].v);
			ans += E[i].w;
		}
	cout << ans;
    return 0;
}