[题解] P6773 [NOI2020] 命运

发布时间 2023-11-10 09:41:57作者: definieren

P6773 [NOI2020] 命运

给你一棵 \(n\) 个节点的树,要给每条边染成 \(0\)\(1\)
\(m\) 个限制 \((u, v)\) 满足 \(u\)\(v\) 祖先,表示 \(u\)\(v\) 的路径中至少有一条边被染成了 1。
求方案数。
\(n, m \le 5 \times 10^5\)

首先会想到容斥,但是很没前途,所以考虑直接 dp。

一个自然的想法是设 \(f_{u, 0/1}\) 表示 限制的上端点在 \(u\) 的子树内限制全都满足,下端点在 \(u\) 的子树内但上端点不在的限制满足 / 不满足 的方案数。

但这样无法转移,死在对于 下端点在 \(u\) 的子树内但上端点不在的限制 的状态设计过于简略。

考虑到在下端点相同的情况下,如果上端点深度深的限制满足了,那么上端点浅的限制一定会被满足,同样如果上端点深的限制没被满足,上端点浅的也不会被满足。所以对于 下端点在 \(u\) 的子树内但上端点不在的限制 的满足情况,一定存在一个深度 \(dep\),满足深度小于 \(dep\) 的点全不满足,大于等于的全满足。所以可以用 最深的没被满足的限制的上端点 来刻画 下端点在 \(u\) 的子树内但上端点不在的限制 的满足情况。

所以设计状态为 \(f_{u, i}\) 表示 限制的上端点在 \(u\) 的子树内限制全都满足,下端点在 \(u\) 的子树内但上端点不在的限制 最深的没满足的上端点的深度为 \(i\) 的方案数。

转移考虑合并 \(u\) 和它的一个儿子 \(v\),分类讨论一下:

  • \((u, v)\) 染成了 \(1\):这时下端点在 \(v\) 子树内的限制必然全部满足,转移是 \(f'_{u, i} \leftarrow \sum_{j = 0}^{dep_u} f_{u, i} \times f_{v, j}\)
  • \((u, v)\) 染成了 \(0\):这意味着 \(u\) 的 最深的没满足的限制的上端点的深度 和 \(v\) 的 最深的没满足的限制的上端点的深度 的最大值必须为 \(i\),转移是 \(f'_{u, i} \leftarrow \sum_{j = 0}^{i} f_{u, i} \times f_{v, j} + \sum_{j = 0}^{i - 1} f_{u, j} \times f_{v, i}\)。后面的式子的上表是 \(i - 1\) 是因为 \(j = i\) 时会算两遍。

合起来转移就是:

\[f'_{u, i} \leftarrow \sum_{j = 0}^{dep_u} f_{u, i} \times f_{v, j} + \sum_{j = 0}^i f_{u, i} \times f_{v, j} + \sum_{j = 0}^{i - 1} f_{u, j} \times f_{v, i} \]

这个看着就很前缀和,所以记 \(g_{u, i} = \sum_{j = 0}^i f_{u, i}\)

然后转移式子就变成了:

\[f'_{u, i} \leftarrow f_{u, i} \times (g_{v, dep_u} + g_{v, i}) + g_{u, i - 1} \times f_{v, i} \]

这个直接转移是 \(O(n^2)\) 的,还要继续优化。

可以考虑拉到线段树上进行线段树合并。\(g_{v, dep_u}\) 可以先一次查询查出来,然后就是再线段树合并的过程中维护 \(sum_1 = g_{v, dep} + g_{v, i}\)\(sum_2 = g_{u, i - 1}\),在叶子结点对应地乘起来就行。

时间复杂度 \(O(n \log n)\)

实现还是比较简单的。

constexpr int MAXN = 5e5 + 9;
int n, m, dep[MAXN], rt[MAXN];
vector<int> G[MAXN], stk, up[MAXN];

struct Node {
	int ls, rs;
	int sum, mul;
	
	Node(): ls(0), rs(0), sum(0), mul(1) { return; }
	Node(int _ls, int _rs, int _su, int _mu):
		ls(_ls), rs(_rs), sum(_su), mul(_mu) { return; }
} sgt[MAXN * 20];

int New_Node() {
	static int tot = 0; int id = 0;
	if (stk.empty()) id = ++ tot;
	else id = stk.back(), stk.pop_back();
	sgt[id] = Node(); return id;
}
void Push_Tag(int p, int k) {
	cmul(sgt[p].sum, k), cmul(sgt[p].mul, k);
	return;
}
void Push_Down(int p) {
	if (sgt[p].mul ^ 1) {
		Push_Tag(sgt[p].ls, sgt[p].mul);
		Push_Tag(sgt[p].rs, sgt[p].mul);
		sgt[p].mul = 1;
	}
	return;
}
void Push_Up(int p) {
	sgt[p].sum = add(sgt[sgt[p].ls].sum, sgt[sgt[p].rs].sum);
	return;
}
void Update(int &p, int pos, int k, int L = 0, int R = n) {
	if (!p) p = New_Node();
	if (L == R) { sgt[p].sum += k; return; }
	Push_Down(p); int Mid = L + R >> 1;
	if (pos <= Mid) Update(sgt[p].ls, pos, k, L, Mid);
	else Update(sgt[p].rs, pos, k, Mid + 1, R);
	Push_Up(p); return;
}
int Query(int p, int l, int r, int L = 0, int R = n) {
	if (l <= L && R <= r) return sgt[p].sum;
	Push_Down(p); int Mid = L + R >> 1;
	if (r <= Mid) return Query(sgt[p].ls, l, r, L, Mid);
	if (Mid < l) return Query(sgt[p].rs, l, r, Mid + 1, R);
	return add(Query(sgt[p].ls, l, r, L, Mid), Query(sgt[p].rs, l, r, Mid + 1, R));
}
int Merge(int p, int q, int &sv, int &su, int L = 0, int R = n) {
	if (!p && !q) return 0;
	if (!p || !q) {
		if (!p) {
			cadd(sv, sgt[q].sum), Push_Tag(q, su);
			return q;
		} else {
			cadd(su, sgt[p].sum), Push_Tag(p, sv);
			return p;
		}
	}
	if (L == R) {
		int sumu = sgt[p].sum, sumv = sgt[q].sum;
		cadd(sv, sumv), cmul(sgt[p].sum, sv);
		cmul(sgt[q].sum, su), cadd(su, sumu);
		cadd(sgt[p].sum, sgt[q].sum); return p;
	}
	Push_Down(p), Push_Down(q);
	int Mid = L + R >> 1;
	sgt[p].ls = Merge(sgt[p].ls, sgt[q].ls, sv, su, L, Mid);
	sgt[p].rs = Merge(sgt[p].rs, sgt[q].rs, sv, su, Mid + 1, R);
	Push_Up(p); return p;
}

void dfs(int u, int fa) {
	dep[u] = dep[fa] + 1; int mxd = 0;
	for (auto v : up[u]) cmax(mxd, dep[v]);
	Update(rt[u], mxd, 1);
	for (auto v : G[u]) {
		if (v == fa) continue; dfs(v, u);
		int sv = Query(rt[v], 0, dep[u]), su = 0;
		rt[u] = Merge(rt[u], rt[v], sv, su);
	}
	return;
}

void slv() {
	n = Read<int>();
	for (int i = 1; i <= n - 1; i ++) {
		int u = Read<int>(), v = Read<int>();
		G[u].emplace_back(v), G[v].emplace_back(u);
	}
	m = Read<int>();
	for (int i = 1; i <= m; i ++) {
		int u = Read<int>(), v = Read<int>();
		up[v].emplace_back(u);
	}
	dfs(1, 0);
	Write(Query(rt[1], 0, 0), '\n');
	return;
}