「解题报告」CF983E NN country

发布时间 2023-04-14 16:14:11作者: APJifengc

水点简单数据结构题!

考虑从两个点开始往上跳,每次肯定尽可能跳到最浅的点。两个点跳到再跳一步就能到达 lca 的位置的时候,此时再看看有没有路径连接这两个点,如果有那么一步就可以跳到,否则就要跳到 lca 再跳一步,两步跳到。跳的过程显然可以用倍增处理。

然后我们考虑处理出每个点能跳到的最浅的点。假如现在处理 \(u\) 点,有一条路径 \(x-y\) 满足 \(x\)\(u\) 子树内,那么 \(u\) 能通过这条路径跳到的最浅的点为 \(\mathrm{lca}(u, y)\)。那么我们相当于要求深度最小的 lca,我们只需要找出 dfn 序最小的与 dfn 序最大的两个点即可。然后做完了。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
int n, m, q;
vector<int> e[MAXN];
int dfn[MAXN], ed[MAXN], idf[MAXN], fa[MAXN][22], dep[MAXN], dcnt;
void dfs(int u, int pre) {
    dfn[u] = ++dcnt, idf[dcnt] = u, fa[u][0] = pre, dep[u] = dep[pre] + 1;
    for (int i = 1; i <= 20; i++)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    for (int v : e[u]) {
        dfs(v, u);
    }
    ed[u] = dcnt;
}
int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 20; i >= 0; i--) if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
    if (u == v) return u;
    for (int i = 20; i >= 0; i--) if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}
struct SegmentTree {
    struct Node {
        int lc, rc;
        int sum;
    } t[MAXN * 48];
    int tot;
    void insert(int d, int &p, int l = 1, int r = n) {
        if (!p) p = ++tot, t[p].sum++;
        else t[++tot] = t[p], p = tot, t[p].sum++;
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (d <= mid) insert(d, t[p].lc, l, mid);
        else insert(d, t[p].rc, mid + 1, r);
    }
    int query(int a, int b, int p, int l = 1, int r = n) {
        if (!p) return 0;
        if (a <= l && r <= b) return t[p].sum;
        int mid = (l + r) >> 1;
        if (b <= mid) return query(a, b, t[p].lc, l, mid);
        if (a > mid) return query(a, b, t[p].rc, mid + 1, r);
        return query(a, b, t[p].lc, l, mid) + query(a, b, t[p].rc, mid + 1, r);
    }
} st;
int root[MAXN];
int mnd[MAXN], mxd[MAXN];
int x[MAXN], y[MAXN];
vector<int> t[MAXN];
void dfs2(int u, int pre) {
    for (int v : e[u]) {
        dfs2(v, u);
        mxd[u] = max(mxd[u], mxd[v]);
        mnd[u] = min(mnd[u], mnd[v]);
    }
}
int f[MAXN][22];
int main() {
    scanf("%d", &n);
    for (int i = 2; i <= n; i++) {
        int p; scanf("%d", &p);
        e[p].push_back(i);
    }
    for (int i = 1; i <= n; i++) {
        mnd[i] = n + 1, mxd[i] = 0;
    }
    dfs(1, 0);
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) {
        int a, b; scanf("%d%d", &a, &b);
        x[i] = a, y[i] = b;
        t[dfn[a]].push_back(dfn[b]), t[dfn[b]].push_back(dfn[a]);
        mxd[a] = max(mxd[a], dfn[b]), mnd[a] = min(mnd[a], dfn[b]);
        mxd[b] = max(mxd[b], dfn[a]), mnd[b] = min(mnd[b], dfn[a]);
    }
    dfs2(1, 0);
    for (int i = 1; i <= n; i++) {
        root[i] = root[i - 1];
        for (int j : t[i]) {
            st.insert(j, root[i]);
        }
    }
    for (int u = 1; u <= n; u++) {
        f[u][0] = u;
        if (mxd[u] != 0) {
            int v = idf[mxd[u]];
            int l = lca(u, v);
            if (dep[l] < dep[f[u][0]]) f[u][0] = l;
        }
        if (mnd[u] != n + 1) {
            int v = idf[mnd[u]];
            int l = lca(u, v);
            if (dep[l] < dep[f[u][0]]) f[u][0] = l;
        }
    }
    for (int j = 1; j <= 20; j++) {
        for (int i = 1; i <= n; i++) {
            f[i][j] = f[f[i][j - 1]][j - 1];
        }
    }
    scanf("%d", &q);
    for (int i = 1; i <= q; i++) {
        int u, v; scanf("%d%d", &u, &v);
        int l = lca(u, v);
        int ans = 0;
        if (l == u || l == v) {
            if (l == u) swap(u, v);
            for (int i = 20; i >= 0; i--) if (dep[f[u][i]] > dep[v]) u = f[u][i], ans += 1 << i;
            if (dep[f[u][0]] <= dep[v]) printf("%d\n", ans + 1);
            else printf("-1\n");
        } else {
            for (int i = 20; i >= 0; i--) if (dep[f[u][i]] > dep[l]) u = f[u][i], ans += 1 << i;
            for (int i = 20; i >= 0; i--) if (dep[f[v][i]] > dep[l]) v = f[v][i], ans += 1 << i;
            if (dep[f[u][0]] > dep[l]) printf("-1\n");
            else if (dep[f[v][0]] > dep[l]) printf("-1\n");
            else {
                int c = st.query(dfn[u], ed[u], root[ed[v]]) - st.query(dfn[u], ed[u], root[dfn[v] - 1]);
                if (c) ans += 1;
                else ans += 2;
                printf("%d\n", ans);
            }
        }
    }
    return 0;
}