最近公共祖先

发布时间 2023-04-12 21:42:42作者: kkidd
#include<iostream>
#include<vector>
#include<string.h>
using namespace std;

const int N=5e5+10,M=N*2;

typedef pair<int,int> PII;

int n,m,root;
int p[N];
int ans[N];
int h[N],e[M],ne[M],idx;
vector<PII> query[N];
bool st[N];

void add(int a,int b)
{
    e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}

int find(int x)
{
    if(x!=p[x]) p[x]=find(p[x]);
    return p[x];
}

void dfs(int u,int fa)
{
    st[u]=true;
    
    for(int i=h[u];~i;i=ne[i])
    {
        int j=e[i];
        if(j==fa) continue;
        dfs(j,u);
        p[j]=u;
    }
    
    for(auto x:query[u])
    {
        int v=x.first,i=x.second;
        if(st[v]) ans[i]=find(v);
    }
}

int main()
{
    scanf("%d%d%d",&n,&m,&root);
    
    memset(h,-1,sizeof h);
    
    for(int i=0;i<n-1;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        add(a,b),add(b,a);
    }
    
    for(int i=1;i<=n;i++) p[i]=i;
    
    for(int i=1;i<=m;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        query[a].push_back({b,i});
        query[b].push_back({a,i});
    }
    
    dfs(root,-1);
    
    for(int i=1;i<=m;i++)
    printf("%d\n",ans[i]);
    
    return 0;
}