浅谈关于LCA

发布时间 2023-10-06 09:06:37作者: carp_oier

prologue

本身只会 tarjan 和 倍增法求LCA 的,但在发现有一种神奇的\(O(1)\) 查询 lca 的方法,时间优化很明显。

main body

倍增法

先讨论倍增法,倍增法求 lca 是一种很常见普遍的方法,这里直接放代码了,其本身的内核就是让较低点每次都跳 $ 2 ^ k $ 步,如果跳的比另一个高了,就不跳那么高,跳 \(2 ^ {k-1}\) 步,这就用对数级的复杂度求出来了 LCA。

code

比较基础,建议直接背过。

inline void bfs()
{
    memset(dep, 0x3f, sizeof dep);

    ll hh = 0, tt = -1;

    q[++ tt] = 1;
    
    dep[0] = 0, dep[1] = 1;

    while(hh <= tt)
    {
        ll u = q[hh ++ ];

        for(rl i = h[u]; ~i; i = ne[i])
        {
            ll v = e[i];
            if(dep[v] > dep[u] + 1)
            {
                dep[v] = dep[u] + 1;
                fa[v][0] = u;
                q[++ tt] = v;
                for(rl k=1; k <= 20; ++ k)
                    fa[v][k] = fa[fa[v][k - 1]][k - 1];
            }
        }
    }
}

inline ll lca(ll a, ll b)
{
    if(dep[a] < dep[b]) swap(a, b);

    for(rl k=20; k >= 0; -- k)
        if(dep[fa[a][k]] >= dep[b])
            a = fa[a][k];

    if(a == b) return a;

    for(rl k=20; k >= 0; -- k)
        if(fa[a][k] != fa[b][k])
            a = fa[a][k], b = fa[b][k];

    return fa[a][0];
}

tarjan求LCA

这是一种离线做法(将所有的询问存下来,然后再一一输出)。

这种做法我自我感觉不太好用,但是因为这种做法的时间复杂度是 \(O(n + m)\) 的,所以说有的人用起来很香(个人不喜欢,主要是没怎么敲过/

#include <bits/stdc++.h>
using namespace std;
#define ll int 
#define rl register ll

const ll N = 20010, M = 2 * N;

ll n, m;

ll tot, ne[M], e[M], h[N], w[M];

ll p[N], dis[N], st[N];

ll res[N];

vector<pair<int, int>> query[N];

inline void add(ll a, ll b, ll c)
{
    ne[++tot] = h[a], h[a] = tot, e[tot] = b, w[tot] = c;
}

inline void dfs(ll u, ll fa)
{
    for(rl i=h[u]; ~i; i = ne[i])
    {
        ll v = e[i];
        if(v == fa) continue;
        dis[v] = dis[u] + w[i];
        dfs(v, u);
    }
}

inline ll find(ll x)
{
    if(p[x] == x) return x;
    else return p[x] = find(p[x]);
}

inline void tarjan(ll u)
{
    st[u] = 1;
    for(rl i=h[u]; ~i; i = ne[i])
    {
        ll v = e[i];
        if(!st[v])
        {
            tarjan(v);
            p[v] = u;
        }
    }
    
    for(auto item : query[u])
    {
        ll y = item.first, id = item.second;
        if(st[y] == 2)
        {
            ll anc = find(y);
            res[id] = dis[y] + dis[u] - 2 * dis[anc];
        }
    }
    
    st[u] = 2;
}

int main()
{
    cin >> n >> m;
    
    memset(h, -1, sizeof h);
    
    for(rl i=1; i <= n - 1; ++ i)
    {
        ll a, b, c;
        cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    
    for(rl i=1; i <= m; ++ i)
    {
        ll a, b;
        cin >> a >> b;
        if(a != b)
        {
            query[a].push_back({b, i});
            query[b].push_back({a, i});
        }
    }
    
    for(rl i=1; i <= n; ++ i) p[i] = i; 
    
    dfs(1, -1);
    tarjan(1);
    
    for(rl i=1; i <= m; ++ i)
        cout << res[i] << endl;
    return 0;
}

O(1) 复杂度求LCA

我们首先求出来树上每个点的 \(dfs\) 序列。

我们考虑 \(u\)\(v\) 之间有什么,令 \(lca(u, v) = x, dfn_u < dfn_v\)。那么 \(x\)\(v\) 路径上的一点,一定是 \(u \to v\) 路径上一点,并且这个点是 \(u \to v\) 深度最小的,这个点的父亲节点就是 \(x\)

再考虑一种特殊情况,当\(u\) 就是 \(v\) 的祖先的时候,深度最小得点就变成了\(u\), 为了规避这种情况,我们就选择在 \([dfn_u + 1 \to dfn_v]\) 上来查询。

区间深度最小可以使用 ST表 来维护。

下面是代码,也建议背过。

inline ll get(ll a, ll b) { return dep[a] < dep[b] ? a : b; }

inline void dfs(ll u, ll fa)
{
    dfn[u] = ++ idx, st[0][idx] = fa, dep[u] = dep[fa] + 1;
    for(rl i=h[u]; ~i; i = ne[i])
    {
        ll v = e[i];
        if(v == fa) continue;
        dfs(v, u);
    }
}

inline void init()
{
    dfs(1, -1);
    for(rl i=1; (1 << i) <= n; ++ i)
        for(rl j=1; j <= n - (1 << i) + 1; ++ j)
            st[i][j] = get(st[i-1][j], st[i-1][j + (1 << i - 1)]);
}

inline ll lca(ll a, ll b)
{
    if(a == b) return a;
    a = dfn[a], b = dfn[b];
    if(a > b) swap(a, b);
    ll l = __lg(b - a); // 本人亲测用__lg(b - a),和log2(b - a) 都可以,这两个得区别可以上网搜。
    return get(st[l][a + 1], st[l][b - (1 << l) + 1]);
}