CODE FESTIVAL 2017 Final J Tree MST

发布时间 2023-08-03 23:13:45作者: zltzlt

洛谷传送门

AtCoder 传送门

求完全图的最小生成树,立刻想到 Boruvka。

于是剩下的任务是,对于每个点 \(y\),找到当前和它不在同一连通块的点 \(y\)\(F(x, y) = w_y + dis_{x, y}\) 的最小值。

如果没有 \(x, y\) 所在连通块不同的限制,可以很轻易地换根 dp 完成。先自下而上求出 \(y\) 在子树内的 \(F(x, y)\) 最小值,再自上而下求出 \(y\) 在子树外 \(F(x, y)\) 最小值。

加上了这个限制,我们除了求每个 \(x\)\(F(x, y)\) 最小值和它对应的 \(y\),还要求次小值和它对应的 \(y\)。需要注意我们强制规定最小值和次小值对应的 \(y\) 当前所在连通块不同。这样如果 \(x\) 跟最小值的 \(y\) 在同一连通块,就可以让次小值递补。

时间复杂度 \(O(n \log n)\)

code
// Problem: J - Tree MST
// Contest: AtCoder - CODE FESTIVAL 2017 Final
// URL: https://atcoder.jp/contests/cf17-final/tasks/cf17_final_j
// Memory Limit: 256 MB
// Time Limit: 5000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;

ll n, a[maxn], head[maxn], len, fa[maxn];
pii f[maxn][2], g[maxn][2], b[maxn];

struct edge {
	int to, dis, next;
} edges[maxn << 1];

inline void add_edge(int u, int v, int d) {
	edges[++len].to = v;
	edges[len].dis = d;
	edges[len].next = head[u];
	head[u] = len;
}

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline bool merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
		return 1;
	} else {
		return 0;
	}
}

inline void upd(pii a, pii &x, pii &y) {
	if (a < x) {
		if (find(a.scd) != find(x.scd)) {
			y = x;
		}
		x = a;
	} else if (a < y) {
		if (find(a.scd) != find(x.scd)) {
			y = a;
		}
	}
}

void dfs(int u, int fa) {
	f[u][0] = make_pair(a[u], u);
	f[u][1] = make_pair(1e18, -1);
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to, d = edges[i].dis;
		if (v == fa) {
			continue;
		}
		dfs(v, u);
		pii p1 = f[v][0], p2 = f[v][1];
		p1.fst += d;
		p2.fst += d;
		upd(p1, f[u][0], f[u][1]);
		upd(p2, f[u][0], f[u][1]);
	}
}

void dfs2(int u, int fa, pii p1, pii p2) {
	g[u][0] = f[u][0];
	g[u][1] = f[u][1];
	upd(p1, g[u][0], g[u][1]);
	upd(p2, g[u][0], g[u][1]);
	vector<int> son, dis;
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to, d = edges[i].dis;
		if (v == fa) {
			continue;
		}
		son.pb(v);
		dis.pb(d);
	}
	if (son.empty()) {
		return;
	}
	int len = (int)son.size();
	vector< vector<pii> > pre(len, vector<pii>(2)), suf(len, vector<pii>(2));
	pre[0][0] = f[son[0]][0];
	pre[0][0].fst += dis[0];
	pre[0][1] = f[son[0]][1];
	pre[0][1].fst += dis[0];
	for (int i = 1; i < len; ++i) {
		pre[i][0] = pre[i - 1][0];
		pre[i][1] = pre[i - 1][1];
		int v = son[i], d = dis[i];
		pii t = f[v][0];
		t.fst += d;
		upd(t, pre[i][0], pre[i][1]);
		t = f[v][1];
		t.fst += d;
		upd(t, pre[i][0], pre[i][1]);
	}
	suf[len - 1][0] = f[son[len - 1]][0];
	suf[len - 1][0].fst += dis[len - 1];
	suf[len - 1][1] = f[son[len - 1]][1];
	suf[len - 1][1].fst += dis[len - 1];
	for (int i = len - 2; ~i; --i) {
		suf[i][0] = suf[i + 1][0];
		suf[i][1] = suf[i + 1][1];
		int v = son[i], d = dis[i];
		pii t = f[v][0];
		t.fst += d;
		upd(t, suf[i][0], suf[i][1]);
		t = f[v][1];
		t.fst += d;
		upd(t, suf[i][0], suf[i][1]);
	}
	for (int i = 0; i < len; ++i) {
		int v = son[i], d = dis[i];
		pii q1 = p1, q2 = p2, t = make_pair(d + a[u], u);
		q1.fst += d;
		q2.fst += d;
		upd(t, q1, q2);
		if (i) {
			t = pre[i - 1][0];
			t.fst += d;
			upd(t, q1, q2);
			t = pre[i - 1][1];
			t.fst += d;
			upd(t, q1, q2);
		}
		if (i + 1 < len) {
			t = suf[i + 1][0];
			t.fst += d;
			upd(t, q1, q2);
			t = suf[i + 1][1];
			t.fst += d;
			upd(t, q1, q2);
		}
		dfs2(v, u, q1, q2);
	}
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
		fa[i] = i;
	}
	for (int i = 1, u, v, d; i < n; ++i) {
		scanf("%d%d%d", &u, &v, &d);
		add_edge(u, v, d);
		add_edge(v, u, d);
	}
	ll ans = 0;
	while (1) {
		int cnt = 0;
		for (int i = 1; i <= n; ++i) {
			cnt += (fa[i] == i);
		}
		if (cnt == 1) {
			break;
		}
		dfs(1, -1);
		dfs2(1, -1, make_pair(1e18, -1), make_pair(1e18, -1));
		for (int i = 1; i <= n; ++i) {
			b[i] = make_pair(1e18, -1);
		}
		for (int i = 1; i <= n; ++i) {
			pii x = g[i][0], y = g[i][1];
			x.fst += a[i];
			y.fst += a[i];
			if (find(i) == find(x.scd)) {
				b[find(i)] = min(b[find(i)], y);
			} else {
				b[find(i)] = min(b[find(i)], x);
			}
		}
		for (int i = 1; i <= n; ++i) {
			if (fa[i] == i && merge(i, b[i].scd)) {
				ans += b[i].fst;
			}
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}