[CF1794E] Labeling the Tree with Distances 题解

发布时间 2023-08-26 13:12:12作者: MoyouSayuki

[CF1794E] Labeling the Tree with Distances 题解

题目描述

给你一个树,边权为 \(1\)。给定 \(n-1\) 个数,你需要将这些数分配到 \(n-1\) 个节点上。

一个点 \(x\) 是好的,当且仅当存在一种分配方案,所有被分配数的点到 \(x\) 的最短路径长度等于其被分配的数。

求所有好点。

思路

从特殊情况出发,如果有 \(n\) 个数的话怎么做。

可以发现,如果我们以 \(u\) 为根,所有点的深度的集合与这 \(n\) 个数的集合相等的话,那么 \(u\) 就是一个好点,如果一个一个点比较的话很耗时,可以使用哈希来描述一个集合,这是我的哈希函数:

\[h_u = c_0\times base^1+c_1\times base^2+\dots+c_{n-1}\times base^{n} \]

其中 \(c_i\) 表示深度为 \(i\) 的点的个数。

如果 \(h_u = tar\)\(tar\) 是目标集合,那么 \(u\) 是好点。

这个时候如果我们每次求一遍 \(h_i\),时间复杂度是 \(O(n^2)\) 的,但是这种与深度相关的 DP 很明显可以换根,所以时间复杂度降为 \(O(n)\)

但是这只是一个特殊情况,如果少了一个数怎么办,其实也很好理解,如果 \(tar\) 中少了一个数,它的哈希值就会减少 \(base^k\),我们只需要判断 \(h_u - tar\) 是否是一个 \(base\) 的次幂即可。

时间复杂度:\(O(n)\)

单哈希会被卡,双哈希有时候也会寄,可以用三哈希。

#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <unordered_map>
#define int long long
using namespace std;

const int N = 2e5 + 10;

int n, a[N], p[N], f[N], dep[N], ans[N], base, mod, f2[N], tar, b[N];
unordered_map<int, bool> h;
vector<int> g[N];
void dfs(int u, int fa) {
    f[u] = base;
    for(auto v : g[u]) {
        if(v == fa) continue;
        dfs(v, u);
        f[u] = (f[u] + f[v] * base % mod) % mod;
    }
}
void dfs2(int u, int fa) {
    if(h.count((f2[u] % mod - tar + mod) % mod)) ans[u] ++;
    for(auto v : g[u]) {
        if(v == fa) continue;
        f2[v] = f[v] + base * ((f2[u] - base * f[v] % mod + mod) % mod) % mod;
        f2[v] %= mod;
        dfs2(v, u);
    }
}

void work() {
    tar = 0, p[0] = 1;
    h.clear();
    h[1] = 1;
    for(int i = 1; i <= n; i ++) p[i] = p[i - 1] * base % mod, h[p[i]] = 1;
    for(int i = 0; i < n; i ++) tar = (tar + b[i] * p[i + 1] % mod) % mod;
    dfs(1, 0), f2[1] = f[1], dfs2(1, 0);
}

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n;
    for(int i = 1; i < n; i ++) cin >> a[i], b[a[i]] ++;
    for(int i = 1, a, b; i < n; i ++) {
        cin >> a >> b;
        g[a].push_back(b), g[b].push_back(a);
    }
    base = 19260817, mod = 998244353, work();
    base = 19491001, mod = 1011451423, work();
    base = 19421221, mod = 1e9 + 7, work();
    int cnt = 0;
    for(int i = 1; i <= n; i ++) if(ans[i] == 3) cnt ++;
    cout << cnt << '\n';
    for(int i = 1; i <= n; i ++) if(ans[i] == 3) cout << i << ' ';
    return 0;
}