【CF1797F】Li Hua and Path

发布时间 2023-07-02 15:05:26作者: Custlo

于 2023.5.10 更新 : 更正了两处笔误。

考虑如下定义:

\(A\) 表示满足第一种路径的 \((u,v)\) 集合。

\(B\) 表示满足第二种路径的 \((u,v)\) 集合。

\(C\) 表示满足前两种路径的 \((u,v)\) 集合。

然后答案显然就是 \(|A| + |B| - 2|C|\)。先求出这一类的答案,再解决动态挂叶子的问题。

考虑建出重构树大根小根各两颗。这样就可以解决关于两类限制的计数。

$A,B $:直接是子树内多少个点,这是显然的。

\(C\):考虑在大根子树里数有多少个有多少个小根树上的父亲,这样构成一条直上直下的链。

但是直接这样做是大约两只 \(\log\) 的,很劣。不如转换计数的主体,考虑枚举子树内的儿子,然后数根节点到自己路径上有多少个自己小根树上的儿子,这个是简单的,单点插入 dfs 序,区间查询,这个可以用树状数组解决。

最后考虑动态挂叶子。由于 \(u'\) 编号的性质,显然的是,加入一个叶子,他就会成为一个大根树的根节点,成为一个小根树的叶子节点,然后直接考虑计算。

  • \(A'\):显然是没有的。

  • \(B'\):除了 \(u'\) 的所有节点 \(v\) 都可以产生一对 \((u, v)\) 的贡献.

  • \(C'\):数一下自己重构树上有多少个祖先就好了。

那么增加的答案就是 \(|B'| - |C'|\) 就好了。

综上该问题被我们用 \(O(n\log n)\) 的时间解决。

注释暂且不写了。

#include <bits/stdc++.h>
#define pb push_back
#define int long long

using namespace std;

const int N = 4e5 + 5, MOD = 998244353;

int n, m, ans, dfc;
int dep[N], fa[N], sz[N], st[N], ed[N];
int v[N];

vector <int> mx[N], mn[N], e[N];

void add (int u, int k) { for ( ; u <= n; u += u & -u) v[u] += k; }
int query (int u) {
	int res = 0;
	for ( ; u; u -= u & -u) res += v[u];
	return res;
}
void init () { for (int i = 1; i <= n; i ++) fa[i] = i; }
int find (int u) { if (u != fa[u]) return fa[u] = find(fa[u]); else return u; }
void merge (int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return ;
	fa[v] = fa[u]; 
}
void dfs1 (int u) {
	sz[u] = 1, st[u] = ++ dfc;
	for (int v : mx[u]) dfs1(v), sz[u] += sz[v];
	ans += sz[u] - 1, ed[u] = dfc;
}
void dfs2 (int u) {
	sz[u] = 1;
	for (int v : mn[u]) dep[v] = dep[u] + 1, dfs2(v), sz[u] += sz[v];
	ans += sz[u] - 1;
}
void dfs3 (int u) { 
	ans -= 2 * (query(ed[u]) - query(st[u] - 1));
	add(st[u], 1);
	for (int v : mn[u]) dfs3(v);
	add(st[u], -1);
}	// 考虑算多少个祖先在自己的 max子树里, 1 ~ v 
signed main () {
	// freopen(".in", "r", stdin);
	// freopen(".out", "w", stdout);
	ios :: sync_with_stdio(0);
	cin >> n, init();
	for (int i = 1, u, v; i < n; i ++) cin >> u >> v, e[u].pb(v), e[v].pb(u);
	for (int u = 1; u <= n; u ++) 
		for (int v : e[u]) if (v < u && find(u) != find(v)) 
			mx[u].pb(find(v)), merge(u, v); // 此时 v 属于的祖先一定是严格劣于 u 的, fv -> u 
	init(), dfs1(n);
	for (int u = n; u; u --) 
		for (int v : e[u]) if (v > u && find(u) != find(v)) 
			mn[u].pb(find(v)), merge(u, v);	
	dfs2(1), dfs3(1);
	cout << ans << endl, cin >> m;
	for (int i = 1, u; i <= m; i ++) {
		cin >> u, dep[n + i] = dep[u] + 1;
//		cout << u << endl;
		ans += n + i - dep[n + i] - 1;
		cout << ans << endl;
	}
	return 0;
}