CF519E A and B and Lecture Rooms

发布时间 2023-06-23 14:11:01作者: 空白菌

题目链接

题目

见链接。

题解

知识点:倍增,LCA,树型dp。

要找到距离两点 \(u,v\) 相同的点个数,我可以分类讨论:

  1. \(u,v\) 是同一个点,那么全部点都可以。
  2. \(u,v\) 处于相同深度,那么就是全部点减去 \(LCA(u,v)\)\(u,v\) 两点所在子树的全部点。
  3. \(u,v\) 不在相同深度,当 \(u,v\) 距离为奇数时无解,否则解为 \(u,v\) 路径中点为根的子树全部点减去中点的 \(u,v\) 中是中点子孙的点所在子树的全部点。

其中,查询过程用倍增很容易实现,子树大小用树型dp即可。

时间复杂度 \(O((n+m) \log n)\)

空间复杂度 \(O(n \log n)\)

代码

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

struct Graph {
    struct edge {
        int v, nxt;
    };
    int idx;
    vector<int> h;
    vector<edge> e;

    Graph(int n = 0, int m = 0) { init(n, m); }

    void init(int n, int m) {
        idx = 0;
        h.assign(n + 1, 0);
        e.assign(m + 1, {});
    }

    void add(int u, int v) {
        e[++idx] = { v,h[u] };
        h[u] = idx;
    }
};

const int N = 100007;
Graph g;

int dep[N], f[27][N], sz[N];
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    f[0][u] = fa;
    sz[u] = 1;
    for (int i = 1;i <= 18;i++)
        f[i][u] = f[i - 1][f[i - 1][u]];
    for (int i = g.h[u];i;i = g.e[i].nxt) {
        int v = g.e[i].v;
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

int LCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 20;i >= 0;i--) {
        if (dep[f[i][u]] >= dep[v])u = f[i][u];
        if (u == v) return u;
    }
    for (int i = 20;i >= 0;i--) {
        if (f[i][u] != f[i][v]) {
            u = f[i][u];
            v = f[i][v];
        }
    }
    return f[0][u];
}

int dist(int u, int v) { return dep[u] + dep[v] - 2 * dep[LCA(u, v)]; }

int get_ans(int u, int v) {
    if (u == v) return sz[1];
    if (dep[u] == dep[v]) {
        for (int i = 20;i >= 0;i--) {
            if (f[i][u] != f[i][v]) {
                u = f[i][u];
                v = f[i][v];
            }
        }
        return sz[1] - sz[u] - sz[v];
    }
    int dis = dist(u, v);
    if (dis & 1) return 0;
    if (dep[u] < dep[v]) swap(u, v);
    int d = dep[u] - dis / 2;
    for (int i = 20;i >= 0;i--)
        if (dep[f[i][u]] > d) u = f[i][u];
    return sz[f[0][u]] - sz[u];
}

int main() {
    std::ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n;
    cin >> n;
    g.init(n, n << 1);
    for (int i = 1;i <= n - 1;i++) {
        int u, v;
        cin >> u >> v;
        g.add(u, v);
        g.add(v, u);
    }

    dfs(1, 0);

    int m;
    cin >> m;
    while (m--) {
        int u, v;
        cin >> u >> v;
        cout << get_ans(u, v) << '\n';
    }
    return 0;
}