CodeForces 1856E2 PermuTree (hard version)

发布时间 2023-08-06 08:33:11作者: zltzlt

洛谷传送门

CF 传送门

考虑局部贪心,假设我们现在在 \(u\),我们希望 \(u\) 不同子树中的 \((v, w), a_v < a_u < a_w\) 的对数尽量多。

我们实际上只关心子树内 \(a_u\) 的相对大小关系,不关心它们具体是什么。如果 \(u\) 只有两个儿子 \(v, w\),我们可以让 \(v\) 子树内的 \(a\) 全部小于 \(w\) 子树内的 \(a\),这样 \(u\) 作为 \(\text{LCA}\) 的贡献是 \(sz_v \times sz_w\),是最大的。

那么对于 \(u\) 有多个儿子的情况,推广可知相当于把 \(u\) 的儿子分成 \(S, T\) 两个集合,最大化 \(\sum\limits_{v \in S} sz_v \times \sum\limits_{v \in T} sz_v\)。考虑做一个 \(sz_v\) 的 01 背包,若能把 \(sz_v\) 分成大小为 \(x\) 的集合,\(u\) 对答案的贡献是 \(x \times (sz_u - 1 - x)\)。取这个的最大值即可。

01 背包暴力做即可,根据树形背包的那套理论,每个点对只会在 \(\text{LCA}\) 处被统计,所以时间复杂度 \(O(n^2)\),可以通过 E1。

对于 E2,我们肯定不能再暴力 01 背包了。发现我我们背包的复杂度跟 \(sz_v\) 有关。联想到 dsu on tree,轻子树的大小之和为 \(O(n \log n)\)。于是我们考虑将 \(u\)\(sz\) 最大的两个儿子拎出来,剩下的儿子做一个背包,然后再枚举那两个儿子选不选。

至于如何做背包,我们把 \(sz_v\) 相同的物品看做一种有多个的物品,做单调队列优化多重背包即可。因为去掉两个最大子树后,\(sz_v\) 之和为 \(n \log n\),所以不同的 \(sz_v\)\(O(\sqrt{n \log n})\) 种。

所以这么算下来复杂度其实是 \(O(n \sqrt{n \log n} \log n)\),但是它过了???

code
// Problem: E2. PermuTree (hard version)
// Contest: Codeforces - Codeforces Round 890 (Div. 2) supported by Constructor Institute
// URL: https://codeforces.com/contest/1856/problem/E2
// Memory Limit: 512 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int read() {
    char c = getchar();
    int x = 0;
    for (; !isdigit(c); c = getchar()) ;
    for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    return x;
}

const int maxn = 1000100;

int n, sz[maxn];
bool f[maxn], g[maxn];
ll ans;
vector<int> G[maxn];

void dfs(int u) {
	sz[u] = 1;
	vector<int> vc;
	for (int v : G[u]) {
		dfs(v);
		sz[u] += sz[v];
		vc.pb(sz[v]);
	}
	int m = (int)vc.size();
	if (m <= 2) {
		ll mx = 0;
		for (int S = 0; S < (1 << m); ++S) {
			int s = 0;
			for (int i = 0; i < m; ++i) {
				if (S & (1 << i)) {
					s += vc[i];
				}
			}
			mx = max(mx, 1LL * s * (sz[u] - 1 - s));
		}
		ans += mx;
		return;
	}
	sort(vc.begin(), vc.end(), greater<int>());
	int s = 0;
	for (int i = 2; i < m; ++i) {
		s += vc[i];
	}
	s /= 2;
	for (int i = 0; i <= s; ++i) {
		f[i] = 0;
	}
	f[0] = 1;
	for (int l = 2, r = 2; l < m; l = (++r)) {
		while (r + 1 < m && vc[r + 1] == vc[l]) {
			++r;
		}
		for (int i = 0; i <= s; ++i) {
			g[i] = f[i];
			f[i] = 0;
		}
		int c = r - l + 1, v = vc[l];
		for (int i = 0; i < v; ++i) {
			int cnt = 0, t = i;
			for (int j = i, k = 0; j <= s; j += v, ++k) {
				cnt += g[j];
				if (k > c) {
					cnt -= g[t];
					t += v;
				}
				f[j] = (cnt ? 1 : 0);
			}
		}
	}
	ll mx = 0;
	for (int i = 0; i <= s; ++i) {
		if (!f[i]) {
			continue;
		}
		for (int S = 0; S < 4; ++S) {
			int k = ((S & 1) ? vc[0] : 0) + ((S & 2) ? vc[1] : 0) + i;
			mx = max(mx, 1LL * k * (sz[u] - 1 - k));
		}
	}
	ans += mx;
}

void solve() {
	n = read();
	for (int i = 2, p; i <= n; ++i) {
		p = read();
		G[p].pb(i);
	}
	dfs(1);
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}