【线段树合并、虚树】P5327 [ZJOI2019] 语言

发布时间 2023-09-15 21:37:54作者: Lucis0N

终于 1k AC 了家人,感动吧。

贺了很久,很累。

前置题目:P3320 [SDOI2015] 寻宝游戏

虚树的边权和:

\[\sum dep_{a_x} - \sum_{x < n} dep_{a_x, a_{x + 1}} - dep_{a_{1}, a_{n}} \]

考虑转化贡献,求过该点的链的并,最后再除以二即可。

那么我们可以考虑维护以该点的子树的所有关键点以及其对子形成的虚树,那么求虚树的边权和即为所求。

然后我们考虑线段树合并维护这个问题,考虑在 \(dfn\) 上建权值线段树,然后我们每次维护左边右边的点,然后 pushup 时拼起来即可。

考虑树上差分,然后我们就可以做到塞点。注意如果要 \(O(n \log n)\) 可以使用 euler 序求 LCA。用 st 表维护区间的 \(dep\) 最小值即可。

#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l; i <= r; i ++)
#define per(i, r, l) for (int i = r; i >= l; i --) 
using namespace std;

const int _ = 1e5 + 5, __ = _ * 20;

int n, m, dfc, cnt;
int tr[_], dfn[_], eul[_], dep[_];
int lg[_], st[_][20];
long long ans;

vector <int> e[_], ad[_], de[_];

void dfs (int x, int fa) {
	dep[x] = dep[fa] + 1,
	tr[++ dfc] = x, dfn[x] = dfc;
	eul[x] = ++ cnt, st[cnt][0] = x;
	for (int y : e[x]) 
		if (y != fa) dfs(y, x), st[++ cnt][0] = x;
}
int getLCA (int x, int y) {
	int l = eul[x], r = eul[y], k;
	if (l > r) swap(l, r);
	k = lg[r - l + 1];
	x = st[l][k], y = st[r - (1 << k) + 1][k];
	return dep[x] < dep[y] ? x : y;
}
int calc (int x, int y) {
	if (!x || !y) return 0;
	return dep[y] - dep[getLCA(x, y)];
}

int rt[_];
int tot, lc[__], rc[__], s[__], w[__], lp[__], rp[__];
void pushup (int x) {
	s[x] = s[lc[x]] + s[rc[x]] + calc(rp[lc[x]], lp[rc[x]]),
	lp[x] = lp[lc[x]] ? lp[lc[x]] : lp[rc[x]],
	rp[x] = rp[rc[x]] ? rp[rc[x]] : rp[lc[x]];
	return ;
}
void modify (int & x, int l, int r, int v, int k) {
	if (! x) x = ++ tot;
	if (l == r) {
		w[x] += k;
		if (w[x]) lp[x] = rp[x] = tr[v];
		else lp[x] = rp[x] = 0;
		return ;
	}
	int mid = (l + r) >> 1;
	v <= mid ? modify(lc[x], l, mid, v, k) : modify(rc[x], mid + 1, r, v, k);
	pushup(x); 
}
int merge (int x, int y, int l, int r) {
	if (!x || !y) return x | y;
	if (l == r) {
		w[x] += w[y];
		if (w[x]) lp[x] = rp[x] = tr[l];
		else lp[x] = rp[x] = 0;
		return x;
	}
	int mid = (l + r) >> 1;
	lc[x] = merge(lc[x], lc[y], l, mid), 
	rc[x] = merge(rc[x], rc[y], mid + 1, r);
	return pushup(x), x;
}

void dfs2 (int x, int fa) {
	for (int y : e[x]) {
		if (y == fa) continue ;
		dfs2(y, x),
		rt[x] = merge(rt[x], rt[y], 1, n); 
	}
	for (int p : ad[x]) modify(rt[x], 1, n, p, 1);
	ans += s[rt[x]] + dep[lp[rt[x]]] - dep[getLCA(lp[rt[x]], rp[rt[x]])];
//	cout << lp[rt[x]] << " " << rp[rt[x]] << endl;
	for (int p : de[x]) modify(rt[x], 1, n, p, -2);
}

int main () {
	cin >> n >> m;
	rep(i, 1, n - 1) {
		int x, y;
		scanf("%d%d", & x, & y);
		e[x].push_back(y), e[y].push_back(x);
	}
	
	dfs(1, 0);
	rep(i, 2, cnt) lg[i] = lg[i >> 1] + 1;
	for (int k = 1; (1 << k) <= cnt; k ++) 
		rep(i, 1, cnt - (1 << k) + 1) {
			int x = st[i][k - 1], y = st[i + (1 << k - 1)][k - 1];
			st[i][k] = dep[x] < dep[y] ? x : y;
		}
		
	rep(i, 1, m) {
		int x, y, lca;
		scanf("%d%d", & x, & y);
		lca = getLCA(x, y);
		ad[x].push_back(dfn[x]), ad[x].push_back(dfn[y]),
		ad[y].push_back(dfn[x]), ad[y].push_back(dfn[y]);
		de[lca].push_back(dfn[x]), de[lca].push_back(dfn[y]);
	}
	dfs2(1, 0);
	cout << ans / 2ll;
	return 0;
}