【主席树】P8201 [传智杯 #4 决赛] [yLOI2021] 生活在树上(hard version)题解

发布时间 2023-10-07 13:32:50作者: Pengzt

P8201

简单题。

题中求的是 \(dis_{a, t} \oplus dis_{t, b} = k\) 是否存在,显然不好直接维护,考虑转化。

\(dist = dis_{a, t} \oplus dis_{t, b}\)\(val = \bigoplus\limits_{x\in \text{path}(a, b)} w_x\)

如果 \(t\)\(\text{path}(a, b)\) 上,则 \(dist = val \oplus a_t\)
如果 \(t\) 不在上面,其实就是从 \(a\)\(b\) 的简单路径上的某点走了一个来回,显然这里的 \(dist\) 值所构成的集合一定被前一种情况包含。

此时就变为了判断是否存在点 \(t\),满足在 \(a\)\(b\) 的路径上,且 \(val \oplus w_t = k\)。即找路径上是否存在某点。

由于是查询无修改的路径信息,考虑对每个结点建一棵权值线段树,维护的时该点到根节点的路径上每个点的权值出现的次数。显然此时可以直接动态开点。由于空间比较紧,需要离散化。

代码:

const int N = 5e5 + 10, W = 1e7 + 1;
int n, m, tot, num;
int a[N], b[N << 1], d[N], rt[N], val[N], fa[N][20], u[N], v[N], k[N], lca[N];
vector<int> e[N];
struct segt {
	int ls, rs, v;
} tr[N * 80];

void build(int u, int l, int r) {
	if (l == r) return;
	int mid = (l + r) >> 1;
	build(tr[u].ls = ++tot, l, mid), build(tr[u].rs = ++tot, mid + 1, r);
}
void pushup(int u) {
	tr[u].v = tr[tr[u].ls].v + tr[tr[u].rs].v;
}
void ins(int &u, int lstu, int l, int r, int k, int val) {
	if (!u) u = ++tot;
	if (l == r) {
		tr[u].v += val;
		return;
	}
	int mid = (l + r) >> 1;
	if (k <= mid) {
		tr[tr[u].ls = ++tot] = tr[tr[lstu].ls];
		tr[u].rs = tr[lstu].rs;
		ins(tr[u].ls, tr[lstu].ls, l, mid, k, val);
	} else {
		tr[tr[u].rs = ++tot] = tr[tr[lstu].rs];
		tr[u].ls = tr[lstu].ls;
		ins(tr[u].rs, tr[lstu].rs, mid + 1, r, k, val);
	}
	pushup(u);
}
int query(int u1, int u2, int u3, int u4, int l, int r, int k) {
	if (l == r) return tr[u1].v + tr[u2].v - tr[u3].v - tr[u4].v;
	int mid = (l + r) >> 1;
	if (k <= mid) return query(tr[u1].ls, tr[u2].ls, tr[u3].ls, tr[u4].ls, l, mid, k);
	return query(tr[u1].rs, tr[u2].rs, tr[u3].rs, tr[u4].rs, mid + 1, r, k);
}

void dfs1(int u, int f) {
	d[u] = d[f] + 1, fa[u][0] = f, val[u] = val[f] ^ a[u];
	for (int i = 1; i <= 19; i++) fa[u][i] = fa[fa[u][i - 1]][i - 1];
	for (int i = 0, v; i < (int)e[u].size(); i++)
		if ((v = e[u][i]) != f)
			dfs1(v, u);
}
void dfs2(int u, int f) {
	a[u] = lower_bound(b + 1, b + num + 1, a[u]) - b;
	ins(rt[u], rt[f], 1, num, a[u], 1);
	for (int i = 0, v; i < (int)e[u].size(); i++)
		if ((v = e[u][i]) != f)
			dfs2(v, u);
}
int getlca(int u, int v) {
	if (d[u] < d[v]) swap(u, v);
	for (int i = 19; i >= 0; i--)
		if (d[fa[u][i]] >= d[v])
			u = fa[u][i];
	if (u == v) return u;
	for (int i = 19; i >= 0; i--)
		if (fa[u][i] != fa[v][i])
			u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}

int main() {
	ios
	cin >> n >> m;
	for (int i = 1; i <= n; i++) cin >> a[i], b[++num] = a[i];
	for (int i = 1, u, v; i < n; i++) {
		cin >> u >> v;
		e[u].pb(v), e[v].pb(u);
	}
	dfs1(1, 0);
	for (int i = 1; i <= m; i++) {
		cin >> u[i] >> v[i] >> k[i]; lca[i] = getlca(u[i], v[i]);
		b[++num] = (val[u[i]] ^ val[v[i]] ^ a[lca[i]] ^ k[i]);
	}
	sort(b + 1, b + num + 1); num = unique(b + 1, b + num + 1) - b - 1;
	build(rt[0] = ++tot, 1, num);
	dfs2(1, 0);
	for (int i = 1; i <= m; i++) {
		int s = lower_bound(b + 1, b + num + 1, val[u[i]] ^ val[v[i]] ^ b[a[lca[i]]] ^ k[i]) - b;
		if (query(rt[u[i]], rt[v[i]], rt[lca[i]], rt[fa[lca[i]][0]], 1, num, s) > 0)
			cout << "Yes" << "\n";
		else cout << "No" << "\n";
	}
	return 0;
}