[CEOI2017] Mousetrap

发布时间 2023-06-01 21:07:12作者: Smallbasic

100黑祭。

首先以终点为根。

先考虑简单一点的情况:如果起点终点相邻,那么方案一定是让老鼠先走到一个叶子节点,然后断掉该节点到根路径上其它的分支。于是我们令 \(f_i\) 表示从 \(i\) 开始走到 \(i\) 子树里的一个叶节点再返回所需的最小代价,每次dp从儿子里的次大值转移即可。

考虑不相邻的情况,老鼠的路径还可能是先向根走一段再向下走。考虑直接计算答案是不好算的,但是可以二分答案。假设当前的答案为 \(k\),从下往上每局老鼠向根走到了哪一个节点。注意到我们只需要检查答案是否在 \(k\) 以内,而老鼠一旦下拐到某个儿子,我们是可以用之前的 \(dp\) 数组以及当前节点到根的链上点的度数和 \(O(1)\) 算出所需步数的。所以我们每到一个点都会花费步数断掉所需步数大于剩余步数的儿子节点。不合法当且仅当在判定的时候用完了步数或者来不及封堵儿子。

#include <bits/stdc++.h>

using namespace std;

inline int read() {
	register int s = 0, f = 1; register char ch = getchar();
	while (!isdigit(ch)) f = (ch == '-' ? -1 : 1), ch = getchar();
	while (isdigit(ch)) s = (s * 10) + (ch & 15), ch = getchar();
	return s * f;
}

const int N = 1e6 + 5;

struct edge {
	int head, to, nxt;
} ed[N << 2];

int en = 0, f[N];

inline void addedge(int from, int to) {
	ed[++en].to = to; ed[en].nxt = ed[from].head; ed[from].head = en;
}

int S, T, n;

const int inf = -1e9;

bool vis[N];
int fa[N], sum[N], d[N], stk[N], top = 0;

inline void getfa(int now, int p) {
	fa[now] = p;
	for (int i = ed[now].head; i; i = ed[i].nxt) {
		int v = ed[i].to;
		if (v == p) continue;
		getfa(v, now);
	}
}

inline void dfs(int now, int fa) {
	int mx = -1e9, mx2 = -1e9, cnt = 0;
	for (int i = ed[now].head; i; i = ed[i].nxt) {
		int v = ed[i].to;
		if (v == fa) continue;
		dfs(v, now);
		if (f[v] >= mx) mx2 = mx, mx = f[v];
		else if (f[v] > mx2) mx2 = f[v];
		++cnt;
	}
	if (!cnt) f[now] = 0;
	else if (cnt == 1) f[now] = 1;
	else f[now] = mx2 + cnt;
}

inline bool check(int mid) {
	int res = 0;
	for (int i = 1; i < top; ++i) {
		int x = stk[i], d = 0;
		for (int j = ed[x].head; j; j = ed[j].nxt) {
			int v = ed[j].to;
			d += !((v == stk[i - 1]) || (v == stk[i + 1]) || (sum[i] + f[v] - (x != S) < mid));
		} res += d; mid -= d;
		if (mid < 0 || res > i) return 0;
	} return 1;
}

int main() {
	n = read(); T = read(); S = read();
	for (int i = 1, u, v; i < n; ++i) {
		u = read(); v = read();
		addedge(u, v); addedge(v, u);
		++d[u]; ++d[v];
	} dfs(T, 0); getfa(T, 0); f[T] = 0;
	for (int i = S; i; i = fa[i]) stk[++top] = i;
	for (int i = top - 1; i; --i) sum[i] = sum[i + 1] + d[stk[i]] - 1 - (stk[i] != T);
	int l = 0, r = 1, mid, res = 1e9;
	while (l <= r) {
		mid = l + r >> 1;
		if (check(mid)) r = (res = mid) - 1;
		else l = mid + 1;
	} printf("%d\n", res);
	return 0;
}