P9745 「KDOI-06-S」树上异或 题解

发布时间 2023-10-16 17:24:13作者: MoyouSayuki

P9745 「KDOI-06-S」树上异或 题解

\(x_i = 0\)

这题一看就不是很可做,先考虑部分分。

对于一条链的情况,我们可以枚举上一个断边的位置,然后转移。

一看数据范围,估计和值域有关,所以考虑 \(x_i = 1\) 的部分分,如果全部点权都是 1,那么一种方案只有 0 和 1 两种取值,考虑这个状态设计:\(f_{i, 0/1}\) 表示在以 \(i\) 为根的子树里面,\(i\) 的连通块异或值是 0/1,且除了 \(i\) 所在连通块其它连通块都是合法的 的连通块权值乘积的和。

然后转移就枚举每一个儿子选或不选即可,例如 \(f_{i,0}\) 可以这么转移,有点树上背包的味道:

\[f_{i,0} =\sum_{(i, v) \in E} f_{i,0}\times f_{v, 0} + f_{i, 1}\times f_{v,1} + f_{i,0}\times f_{v, 1} \]

分别表示选 \(v\),并且 \(v\) 的连通块异或和也是 \(0\);选 \(v\) 两个 \(1\) 异或起来是 \(0\);不选 \(v\),需要保证 \(v\) 子树内连通块合法。

思路

考虑把上述思路扩展到 \(x_i \ne 1\) 的情况。

\(f_{i, j, 0/1}\) 表示示在以 \(i\) 为根的子树里面,\(i\) 的连通块的第 \(j\) 位异或值是 0/1,且除了 \(i\) 所在连通块其它连通块都是合法的 的连通块权值乘积的和,再用一个 \(g_{i}\) 表示子树 \(i\) 内的连通块权值乘积的和。

\[f_{u,i,0} = \sum_{(i, v)\in E} f_{u, i, 0}\times g_v+f_{u, i, 0}\times f_{v, i, 0} + f_{v, i, 1}\times f_{u, i, 1} \]

最后用方案数乘上 \(2^{i}\) 就是这一位的贡献了,\(g_i\) 就可以通过 \(f\) 转移:

\[g_{u} = \sum_{i = 0}^{60}f_{u, i, 1}\times 2^i \]

代码

实现的时候要注意用临时变量存一下 \(f_{u, i, 0/1}\),防止环形转移。

另外由于每一条边 \((u,v )\) 都满足 \(u < v\),所以可以直接枚举边不用 DFS。

时间复杂度:\(O(n\log V)\)

// Problem: P9745 「KDOI-06-S」树上异或
// Contest: Luogu
// Author: Moyou
// Copyright (c) 2023 Moyou All rights reserved.
// Date: 2023-10-15 21:59:52

#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
// #define int long long
using namespace std;
const int N = 5e5 + 10, mod = 998244353;

int f[N][61][2], g[N], n;
long long a[N];
vector<int> G[N];

signed main() {
    n = read();
    for(int i = 1; i <= n; i ++) a[i] = read();
    for(int i = 2; i <= n; i ++) G[read()].push_back(i);
    for(int u = n; u; u --) {
        for(int i = 0; i < 60; i ++) f[u][i][a[u] >> i & 1] = 1;
        for(auto v : G[u]) {
            for(int i = 0; i < 60; i ++) {
                int a = f[u][i][0], b = f[u][i][1];
                f[u][i][0] = (1ll * f[u][i][0] * (g[v] + f[v][i][0]) + 1ll * f[v][i][1] * b) % mod;
                f[u][i][1] = (1ll * f[u][i][1] * (g[v] + f[v][i][0]) + 1ll * f[v][i][1] * a) % mod;
            }
        }
        for(int i = 0, p = 1; i < 60; i ++, p = p * 2 % mod) {
            g[u] = (1ll * g[u] + 1ll * p * f[u][i][1]) % mod;
        }
    }
    printf("%d\n", g[1]);
    return 0;
}