[题解] CF1051F The Shortest Statement

发布时间 2023-11-14 19:55:10作者: definieren

The Shortest Statement

给一张 \(n\) 个点 \(m\) 条边的无向连通图,保证 \(m - n \le 20\)\(q\) 次询问求两个点间的最短路。
\(n, m, q \le 10^5\)

由于边数只比点数多 20,所以如果我们建出这张图的一棵生成树,那么非树边至多有 21 条。

那么现在两点之间的最短路就转化成了不经过非树边的和经过非树边的最短路取 min。不经过非树边的就是树上两点之间的路径,经过非树边的就是枚举每一条非树边的端点,强制经过它,用 \(dis_{u, k} + dis_{k, v}\) 来更新答案。这部分可以直接 Dijkstra 预处理。

时间复杂度 \(O((n - m) \log m + q \log n)\)

constexpr int MAXN = 1e5 + 9, MAXLG = 17, MAXK = 50;
constexpr ll INF = 1e18;
int n, m, q, fa[MAXLG][MAXN], dep[MAXN], k;
ll dis[MAXN], dist[MAXK][MAXN];
vpii G[MAXN];
bool vis[MAXN], mark[MAXN];

void dfs(int u, int ft) {
	dep[u] = dep[fa[0][u] = ft] + 1, vis[u] = true;
	for (int i = 1; i <= 16; i ++)
		fa[i][u] = fa[i - 1][fa[i - 1][u]];
	for (auto [v, w] : G[u]) {
		if (v == ft) continue;
		if (vis[v]) { mark[u] = mark[v] = true; continue; }
		dis[v] = dis[u] + w, dfs(v, u);
	}
	return;
}
int Get_Lca(int u, int v) {
	if (dep[u] < dep[v]) swap(u, v);
	for (int i = 16; ~i; i --)
		if (dep[fa[i][u]] >= dep[v])
			u = fa[i][u];
	if (u == v) return u;
	for (int i = 16; ~i; i --)
		if (fa[i][u] != fa[i][v])
			u = fa[i][u], v = fa[i][v];
	return fa[0][u];
}
ll Get_Dis(int u, int v) {
	int lca = Get_Lca(u, v);
	return dis[u] + dis[v] - 2 * dis[lca];
}

void Dijkstra(int s) {
	priority_queue<pli, vector<pli>, greater<pli> > q;
	for (int i = 1; i <= n; i ++) dist[k][i] = INF, vis[i] = false;
	q.emplace(dist[k][s] = 0, s);
	while (q.size()) {
		int u = q.top().sec; q.pop();
		if (vis[u]) continue; vis[u] = true;
		for (auto [v, w] : G[u])
			if (dist[k][u] + w < dist[k][v]) {
				dist[k][v] = dist[k][u] + w;
				if (!vis[v]) q.emplace(dist[k][v], v);
			}
	}
	return;
}

void slv() {
	n = Read<int>(), m = Read<int>();
	for (int i = 1; i <= m; i ++) {
		int u = Read<int>(), v = Read<int>(), w = Read<int>();
		G[u].emplace_back(v, w), G[v].emplace_back(u, w);
	}
	dfs(1, 0);
	for (int i = 1; i <= n; i ++)
		if (mark[i]) k ++, Dijkstra(i);
	q = Read<int>();
	while (q --) {
		int u = Read<int>(), v = Read<int>();
		ll ans = Get_Dis(u, v);
		for (int i = 1; i <= k; i ++)
			cmin(ans, dist[i][u] + dist[i][v]);
		Write(ans, '\n');
	}
	return;
}