hot tea.
一次删点操作的影响太大了,考虑添加虚点以减小影响(相同的套路在 CF1882E2 Two Permutations (Hard Version) 也出现过)。
具体而言,我们把第 \(i\) 条边 \((u, v)\) 变成 \((u, n + i), (v, n + i)\)。称编号 \(\le n\) 的点为黑点,编号 \(> n\) 的点为白点。
因为最后只剩 \(n\) 一个单点,所以不妨以 \(n\) 为根。
那么删掉一个黑点 \(u\) 就可以看成,把 \(u\) 的所有儿子(白点)合并到 \(fa_u\),也就是把 \(u\) 的所有儿子的所有子树都接到 \(fa_u\) 下面。
然后经过若干次删点操作后两个黑点有边等价于存在一个白点连接这两个黑点。因为操作后原本的这棵树还是一棵树,所以这样的白点最多只有一个。
先考虑一次怎么算答案。对于一个三元组 \((a, b, c)\) 把它看成 \((a, x, b, y, c)\),其中 \(x\) 和 \(y\) 分别为 \(a, b\) 和 \(b, c\) 之间的白点。
分类讨论一下。
- 若 \(x = y\),\(x\) 的相邻点可以任选 \(3\) 个。若设 \(f_u\) 为白点 \(u\) 的儿子数量,那么这个对答案的贡献即为 \((f_x + 1) \times f_x \times (f_x - 1)\)。
- 若 \(x \ne y\) 且 \(x, y\) 是 \(b\) 的两个儿子,我们在 \(b\) 处统计贡献。若设 \(g_u\) 为黑点 \(u\) 的二级儿子数量,贡献即为 \(g_b^2 - \sum\limits_{u \in son_b} f_u^2\)。观察到每个 \(u\) 会恰好被算一遍 \(- f_u^2\),于是可以把这个扔到上一种情况的贡献。
- 若 \(x \ne y\) 且 \(x, y\) 分别是 \(b\) 的一个儿子和父亲,不妨设 \(x\) 为 \(b\) 的父亲。我们在 \(x\) 处统计答案。若设 \(h_u\) 为白点 \(u\) 的三级儿子数量,贡献即为 \(2 \times f_x \times h_x\),其中 \(f_x\) 为选择 \(a\) 的方案。
综上,我们需要维护:
\[\sum\limits_{i = 1}^n g_i^2 - \sum\limits_{i = n + 1}^{2n - 1} f_i (f_i + 1) (f_i - 1) - f_i^2 - 2 f_i h_i
\]
一次对 \(u\) 的删除操作,只会影响 \(u\) 的不超过三级的祖先。暴力更新这些点的 \(f, g, h\) 即可。还要去除 \(u\) 和 \(u\) 的儿子原来的贡献。
使用并查集维护白点 \(u\) 最终被合并到了哪个点。时间复杂度取决于并查集。
code
// Problem: P9194 [USACO23OPEN] Triples of Cows P
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P9194
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 400100;
ll n, a[maxn], f[maxn], g[maxn], h[maxn], fa[maxn];
vector<int> G[maxn];
int find(int x) {
return a[x] == x ? x : a[x] = find(a[x]);
}
void dfs(int u, int fa) {
for (int v : G[u]) {
if (v == fa) {
continue;
}
::fa[v] = u;
dfs(v, u);
if (u <= n) {
g[u] += f[v];
} else {
++f[u];
h[u] += g[v];
}
}
}
void solve() {
scanf("%lld", &n);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(n + i);
G[n + i].pb(u);
G[v].pb(n + i);
G[n + i].pb(v);
}
dfs(n, -1);
ll ans = 0;
for (int i = 1; i <= n; ++i) {
ans += g[i] * g[i];
}
for (int i = n + 1; i < n * 2; ++i) {
ans += (f[i] + 1) * f[i] * (f[i] - 1) - f[i] * f[i] + f[i] * h[i] * 2;
a[i] = i;
}
printf("%lld\n", ans);
for (int u = 1; u < n; ++u) {
ans -= g[u] * g[u];
int x = find(fa[u]);
int y = fa[x];
int z = find(fa[y]);
// printf("%d %d %d %d\n", u, x, y, z);
ans -= (f[x] + 1) * f[x] * (f[x] - 1) - f[x] * f[x] + f[x] * h[x] * 2;
ans -= g[y] * g[y];
if (z) {
ans -= (f[z] + 1) * f[z] * (f[z] - 1) - f[z] * f[z] + f[z] * h[z] * 2;
}
--f[x];
--g[y];
if (z) {
--h[z];
}
for (int v : G[u]) {
if (v == fa[u]) {
continue;
}
a[v] = x;
f[x] += f[v];
h[x] -= f[v];
h[x] += h[v];
g[y] += f[v];
if (z) {
h[z] += f[v];
}
ans -= (f[v] + 1) * f[v] * (f[v] - 1) - f[v] * f[v] + f[v] * h[v] * 2;
}
ans += (f[x] + 1) * f[x] * (f[x] - 1) - f[x] * f[x] + f[x] * h[x] * 2;
ans += g[y] * g[y];
if (z) {
ans += (f[z] + 1) * f[z] * (f[z] - 1) - f[z] * f[z] + f[z] * h[z] * 2;
}
// for (int i = 1; i <= n * 2 - 1; ++i) {
// printf("%lld %lld %lld\n", f[i], g[i], h[i]);
// }
printf("%lld\n", ans);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}