【数据结构】P4338 [ZJOI2018] 历史 题解

发布时间 2023-12-27 10:02:37作者: Pengzt

P4338

先考虑怎么安排崛起的先后顺序最优。

但是发现好像没有一个很好的顺序去进行崛起,并且由于 \(a_i\) 的值域会很大,所以即使知道顺序应该也会难以进行维护。

转换一下方向,正难则反。考虑每个点的贡献,但是颜色不同时只会算一次,所以要钦定是哪一个点造成的贡献。令当前考虑的点为 \(u\),发现可以在不影响 \(u\) 的祖先的的贡献的情况下对 \(u\) 子树内的点的相对操作顺序进行改变。所以 \(u\) 点所产生的贡献是很容易计算的:若 \(u\) 以及 \(u\) 的子树内的所有点的 \(a\) 值都没有超过 \(sza_u\) 的一半,则贡献为 \(sza_u-1\)。其中 \(sza_u\)\(u\)\(u\) 的子树的 \(a\) 值的和,减一是因为第一次不会产生贡献。否则就是 \(2(sza_u-mxa_u)\),因为最大的 \(a\) 无法都产生贡献,注意这里的 \(mxa_u\) 需要和 \(a_u\)\(\max\)

这时候就可以有 \(30\) 分了。

放一份暴力代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 4e5 + 10;
int n, m;
int sz[N];
ll ans, a[N], sza[N];
vi e[N];
void dfs(int u, int ff) {
	sz[u] = 1, sza[u] = a[u];
	ll mx = a[u];
	for(int v : e[u]) {
		if(v == ff) continue;
		dfs(v, u);
		sz[u] += sz[v];
		sza[u] += sza[v];
		if(sza[v] > mx) mx = sza[v];
	}
	if(2 * mx > sza[u]) ans += 2 * (sza[u] - mx);
	else ans += sza[u] - 1;
}
void work() {
	ans = 0;
	dfs(1, 0);
	cout << ans << "\n";
}
bool Med;
int main() {
	fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
//	freopen("history4.in", "r", stdin);
//	freopen("history.out", "w", stdout);
	ios :: sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n >> m;
	for(int i = 1; i <= n; ++i) cin >> a[i];
	for(int i = 1; i < n; ++i) {
		int u, v;
		cin >> u >> v;
		e[u].eb(v);
		e[v].eb(u);
	}
	work();
	for(int i = 1; i <= m; ++i) {
		int x, w;
		cin >> x >> w;
		a[x] += w;
		work();
	}
	cerr << TIME << "ms\n";
	return 0;
}

然后考虑对这个暴力进行优化。

对于 \(u\to son\) 的边,若 \(sza_{son}>\dfrac{sza_u}{2}\),则称这条边为实边,否则为虚边。

考虑对 \(v\in\{\text{path}(1,u)\}\) 进行区间加后虚实边会有什么变化。发现一个很美妙的地方,就是整条路径上至多有 \(\log\sum a_i\) 条虚边。然后因为对于一条 \(fa\to x\) 的实边,\(sza_x\)\(sza_{fa}\) 同时增加,\(fa\) 的带权的重儿子显然还是 \(x\)。所以可能发生变化的只有虚边。

于是每次操作至多会修改 \(\mathcal{O}(\log\sum a_i)\) 条边,即需要支持单点修改、查询区间中为 \(1\) 的数,线段树即可。

然后可能会略有卡常,可以将维护 \(sza\) 的数组换为 BIT。

还有一个实现的小技巧,就是因为带权重儿子可能是 \(u\) 本身,所以连一个 \(i\to i+n\) 的边即可,就能把 \(i\) 的点权转到 \(i+n\) 上了,这里借鉴了 _Jxsts 的代码。

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector<int>
#define eb emplace_back
#define pii pair<int, ll>
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
bool Mbe;
mt19937_64 rng(35);
constexpr int N = 8e5 + 10;
int n, m, q;
int dfc, sz[N], hv[N], dfn[N], top[N], fa[N], dep[N], rdfn[N], leaf[N], hva[N];
ll ans, a[N], sza[N], val[N];
vi e[N];
void dfs1(int u, int ff) {
	sz[u] = 1, sza[u] = a[u], dep[u] = dep[ff] + 1, fa[u] = ff;
	for(int v : e[u]) {
		if(v == ff) continue;
		dfs1(v, u);
		sz[u] += sz[v];
		sza[u] += sza[v];
		if(sz[v] > sz[hv[u]]) hv[u] = v;
		if(sza[v] > sza[hva[u]]) hva[u] = v;
		++leaf[u];
	}
	if(leaf[u] <= 1) leaf[u] = 1;
	else leaf[u] = 0;
	if(!leaf[u]) {
		if(2 * sza[hva[u]] > sza[u]) val[u] = 2 * (sza[u] - sza[hva[u]]);
		else val[u] = sza[u] - 1;
		ans += val[u];
		// ans += val[u] = min(2 * (sza[u] - sza[hva[u]]), sza[u] - 1);
	}
}
void dfs2(int u, int f) {
	rdfn[dfn[u] = ++dfc] = u, top[u] = f;
	if(!hv[u]) return;
	dfs2(hv[u], f);
	for(int v : e[u]) {
		if(v == hv[u] || v == fa[u]) continue;
		dfs2(v, v);
	}
}
ll v[N];
void add(int x, ll y) {
	for(; x <= m; x += x & -x) v[x] += y;
}
ll ask(int x) {
	ll res = 0;
	for(; x; x -= x & -x) res += v[x];
	return res;
}
ll ask(int l, int r) {
	return ask(r) - ask(l - 1);
}
int sum[N << 2];
void build(int x, int L, int R) {
	if(L == R) {
		sum[x] = 2 * sza[rdfn[L]] <= sza[fa[rdfn[L]]];
		return;
	}
	int m = (L + R) >> 1;
	build(x << 1, L, m);
	build(x << 1 | 1, m + 1, R);
	sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
void modify(int x, int L, int R, int k, int v) {
	if(L == R) {
		sum[x] = v;
		return;
	}
	int m = (L + R) >> 1;
	if(k <= m) modify(x << 1, L, m, k, v);
	else modify(x << 1 | 1, m + 1, R, k, v);
	sum[x] = sum[x << 1] + sum[x << 1 | 1];
}
int query(int x, int L, int R, int l, int r) {
	if(!sum[x]) return -1;
	if(L == R) return rdfn[L];
	int m = (L + R) >> 1;
	if(r > m) {
		int v = query(x << 1 | 1, m + 1, R, l, r);
		if(~v || (m > r)) return v;
	}
	return l <= m ? query(x << 1, L, m, l, r) : -1;
}
void upd(int x, ll v) {
	int tmp = x;
	while(x) {
		add(dfn[top[x]], v), add(dfn[x] + 1, -v);
		x = fa[top[x]];
	}
	x = tmp;
	int r = dfn[x];
	while(x) {
		if(r < dfn[top[x]]) {
			x = fa[top[x]];
			r = dfn[x];
			continue;
		}
		int u = query(1, 1, m, dfn[top[x]], r);
		if(u == -1) {
			x = fa[top[x]];
			r = dfn[x];
			continue;
		}
		ans -= val[fa[u]];
		ll saf = ask(dfn[fa[u]]), sau = ask(dfn[u]), sahv = ask(dfn[hva[fa[u]]]);
		if(u == hva[fa[u]]) {
			if(2 * sahv > saf && 2 * (sahv - v) <= saf - v) modify(1, 1, m, dfn[u], 0);
			if(2 * sahv > saf) val[fa[u]] = 2 * (saf - sahv);
			else val[fa[u]] = saf - 1;
			ans += val[fa[u]];
		} else {
			if(2 * sahv > saf - v && 2 * sahv <= saf) modify(1, 1, m, dfn[hva[fa[u]]], 1);
			if(2 * sau > saf) modify(1, 1, m, dfn[u], 0);
			if(sau > sahv) {
				hva[fa[u]] = u;
				if(sau * 2 > saf) val[fa[u]] = 2 * (saf - sau);
				else val[fa[u]] = saf - 1;
			} else {
				if(sahv * 2 > saf) val[fa[u]] = 2 * (saf - sahv);
				else val[fa[u]] = saf - 1;
			}
			ans += val[fa[u]];
		}
		r = dfn[fa[u]];
	}
}
bool Med;
int main() {
	fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
//	freopen("history.in", "r", stdin);
//	freopen("history.out", "w", stdout);
	ios :: sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n >> q;
	m = 2 * n;
	for(int i = 1; i <= n; ++i) {
		cin >> a[i + n];
		e[i].eb(i + n);
		e[i + n].eb(i);
	}
	for(int i = 1; i < n; ++i) {
		int u, v;
		cin >> u >> v;
		e[u].eb(v);
		e[v].eb(u);
	}
	dfs1(1, 0);
	cout << ans << "\n";
	dfs2(1, 1);
	build(1, 1, m);
	for(int i = 1; i <= m; ++i) add(i, sza[rdfn[i]]), add(i + 1, -sza[rdfn[i]]);
//	q = 1; // ...65
	for(int i = 1; i <= q; ++i) {
		int x, w;
		cin >> x >> w;
		upd(x + n, w);
		cout << ans << "\n";
	}
	cerr << TIME << "ms\n";
	return 0;
}