最近公共祖先 倍增算法

发布时间 2023-05-01 19:25:40作者: eternal_visionary

求最近公共祖先(Lowest Common Ancestor,LCA)
例题:洛谷P3379 【模板】最近公共祖先(LCA)
https://www.luogu.com.cn/problem/P3379

基本思路就是先用倍增把两点升到同一深度,然后倍增来找最近公共祖先。
其中fa数组是关键

#include<iostream>
#include<vector>
#define forup(i,l,r) for(int i=l;i<=r;i++)
#define fordown(i,l,r) for(int i=r;i>=l;i--)
using namespace std;
const int N =5e5+5;
int dep[N],fa[N][20];//fa[u][i]指从u向上跳2^i个结点到达的祖先结点 
//2^20大约为一百万,具体为1048576 
vector<int> child[N];//结点所连的结点(这里包括了父节点) 
int read()
{
	int n=0;
	char m=getchar();
	while(m<'0'||m>'9') m=getchar();
	while(m>='0'&&m<='9'){
		n=(n<<1)+(n<<3)+(m^'0');
		m=getchar();
	}
	return n;
}
void dfs(int father,int u)//作用是初始化dep和fa数组 
{
	dep[u]=dep[father]+1;
	fa[u][0]=father;
	forup(i,1,19)//跳2^i个相当于是先跳2^(i-1)到fa[u][i-1],然后再往上跳2^(i-1)到fa[u][i] 
	{
		fa[u][i]=fa[fa[u][i-1]][i-1];
	}
	for(int v:child[u])//对子节点进行访问 
	{
		if(v!=father) dfs(u,v);
	}
}
int lca(int u,int v)
{
	if(u==v) return u;//特判 
	if(dep[u]<dep[v]) swap(u,v);
	fordown(i,0,19)
	{
		if(dep[fa[u][i]]>=dep[v])//dep的差=2^x1+2^x2+...+2^xn,故一定可以让u跳到和v同层的地方 
		{
			u=fa[u][i]; 
		}
	}
	if(u==v) return u;//特判二度 
	fordown(i,0,19)
	{
		if(fa[u][i]!=fa[v][i])//让u,v跳到最近公共祖先的下一个,原理同上一个for 
		{
			u=fa[u][i]; v=fa[v][i];
		}
	}
	return fa[u][0];
}
int main()
{
	int n,m,root;
	n=read(),m=read(),root=read();
	forup(i,1,n-1)//建树
	{
		int u,v;
		u=read(),v=read();
		child[v].push_back(u);
		child[u].push_back(v);
	}
	dfs(0,root);//作用是初始化dep和fa数组 
	forup(i,1,m)//查询 
	{
		int u,v;
		u=read(),v=read();
		cout<<lca(u,v)<<endl;
	}
	return 0;
}

另外,当深度对应的结点数不那么明朗的时候比如不清楚n是2的多少次方,或者需要优化的时候,可以预处理一个lg数组来判断需要最多跳多少层,注意lg里面是深度或深度差
递推公式为lg[i]=lg[i-1]+((1<<(lg[i-1]+1))==i)

#include<iostream>
#include<vector>
#define forup(i,l,r) for(int i=l;i<=r;i++)
#define fordown(i,l,r) for(int i=r;i>=l;i--)
using namespace std;
const int N =5e5+5;
int dep[N],fa[N][20];//fa[u][i]指从u向上跳2^i个结点到达的祖先结点 
//2^20大约为一百万,具体为1048576 
int lg[N];//lg为log2(n)
vector<int> child[N];//结点所连的结点(这里包括了父节点) 
int read()
{
	int n=0;
	char m=getchar();
	while(m<'0'||m>'9') m=getchar();
	while(m>='0'&&m<='9'){
		n=(n<<1)+(n<<3)+(m^'0');
		m=getchar();
	}
	return n;
}
void dfs(int father,int u)//作用是初始化dep和fa数组 
{
	dep[u]=dep[father]+1;
	fa[u][0]=father;
	forup(i,1,lg[dep[u]])//跳2^i个相当于是先跳2^(i-1)到fa[u][i-1],然后再往上跳2^(i-1)到fa[u][i] 
	{
		fa[u][i]=fa[fa[u][i-1]][i-1];
	}
	for(int v:child[u])//对子节点进行访问 
	{
		if(v!=father) dfs(u,v);
	}
}
int lca(int u,int v)
{
	if(u==v) return u;//特判 
	if(dep[u]<dep[v]) swap(u,v);
	while(dep[u]>dep[v]) {
		u=fa[u][lg[dep[u]-dep[v]]];//2^lg始终小于等于dep的差 
	}
	if(u==v) return u;//特判二度 
	fordown(i,0,lg[dep[u]])
	{
		if(fa[u][i]!=fa[v][i])//让u,v跳到最近公共祖先的下一个
		{
			u=fa[u][i]; v=fa[v][i];
		}
	}
	return fa[u][0];
}
int main()
{
	int n,m,root;
	n=read(),m=read(),root=read();
	forup(i,2,n)
	{
		lg[i]=lg[i-1]+((1<<(lg[i-1]+1))==i);//在此之前i都是处于[2^lg[i-1],2^lg[i])之间的数,即2^lg[i-1]*2>i 
	} 
	forup(i,1,n-1)
	{
		int u,v;
		u=read(),v=read();
		child[v].push_back(u);
		child[u].push_back(v);
	}
	dfs(0,root);//作用是初始化dep和fa数组 
	forup(i,1,m)//查询 
	{
		int u,v;
		u=read(),v=read();
		cout<<lca(u,v)<<endl;
	}
	return 0;
}

至于为什么不直接用c++自带的log函数,因为log比20快但是比lg数组慢

参考:https://www.cnblogs.com/dx123/p/16320461.html
https://blog.csdn.net/weixin_45697774/article/details/105289810
次方表:https://blog.csdn.net/weixin_44827418/article/details/106355287