AtCoder Grand Contest 008 F Black Radius

发布时间 2023-03-23 12:59:36作者: zltzlt

洛谷传送门

AtCoder 传送门

神题!!!!111

考虑如何不重不漏地计数。先考虑全为 1 的情况,令 \(f(u,d)\) 为与 \(u\) 的距离 \(\le d\) 的点集。

首先单独算全集,那么对于不是全集的集合就会有一些比较好的性质。

考虑若有若干个 \(f(u,d)\) 同构,那 只在 \(d\) 最小的时候计数

那么 \(f(u,d)\) 需要满足不能覆盖全集,且不存在与 \(u\) 相邻的点 \(v\),使得 \(f(u,d) = f(v,d-1)\)(由于 \(d\) 最小的约束)。

考虑若存在后者时发生了什么。把 \(v\) 这棵子树抠掉之后,剩下的点与 \(u\) 距离 \(\le d - 2\)

\(f_u\) 为以 \(u\) 为根的子树最大深度,\(g_u\) 为以 \(u\) 为根的子树次大深度(不存在则为 \(0\)),\(d_u\)\(f(u,d)\) 最大能取到的 \(d\),则等价于 \(d_u < \min(f_u,g_u+2)\)

换根求出 \(f_u,g_u\) 即可。于是我们就做完了全为 1 的情况。

现在有一些点是 0。但是我们发现不能完全不考虑它们,因为我们发现有些 1 点的 \(d_u\) 上界过于严苛导致有些情况没有考虑到,那我们将这些情况放到 0 点计算。

发现 0 点的上界仍然可以取到,但是下界并非 \(0\)。设任意 0 点为 \(u\),则未算到的情况满足 1 点所在子树中全被覆盖,并且还可能覆盖了别的子树。设 \(h_u\) 为以 \(u\) 为根的存在 1 点的子树的最大深度,则对于 0 点,\(h_u \le d_u < \min(f_u,g_u+2)\)

此时的 \(h_u\) 仍然可以换根求出,于是我们就以 \(O(n)\) 的时空复杂度做完了。

code
// Problem: F - Black Radius
// Contest: AtCoder - AtCoder Grand Contest 008
// URL: https://atcoder.jp/contests/agc008/tasks/agc008_f
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const int inf = 0x3f3f3f3f;

int n, head[maxn], len, a[maxn], sz[maxn], f[maxn], g[maxn], h[maxn];
ll ans = 1;

struct edge {
	int to, next;
} edges[maxn << 1];

void add_edge(int u, int v) {
	edges[++len].to = v;
	edges[len].next = head[u];
	head[u] = len;
}

void dfs(int u, int fa) {
	if (a[u]) {
		sz[u] = 1;
	} else {
		h[u] = inf;
	}
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to;
		if (v == fa) {
			continue;
		}
		dfs(v, u);
		sz[u] += sz[v];
		int val = f[v] + 1;
		if (val > f[u]) {
			g[u] = f[u];
			f[u] = val;
		} else if (val > g[u]) {
			g[u] = val;
		}
		if (sz[v]) {
			h[u] = min(h[u], f[v] + 1);
		}
	}
}

void dfs2(int u, int fa) {
	ans += max(0, min(f[u], g[u] + 2) - h[u]);
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to;
		if (v == fa) {
			continue;
		}
		int val = (f[u] == f[v] + 1) ? g[u] + 1 : f[u] + 1;
		if (val > f[v]) {
			g[v] = f[v];
			f[v] = val;
		} else if (val > g[v]) {
			g[v] = val;
		}
		if (sz[v] != sz[1]) {
			h[v] = min(h[v], val);
		}
		dfs2(v, u);
	}
}

void solve() {
	scanf("%d", &n);
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		add_edge(u, v);
		add_edge(v, u);
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%1d", &a[i]);
	}
	dfs(1, -1);
	dfs2(1, -1);
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}