trie(字典树)学习笔记

发布时间 2023-11-05 17:07:11作者: rhineofts

trie(字典树)学习笔记

trie 可以在 \(O(nL)\) 的时间, \(O(n\left| \Sigma\right|L)\) 的空间完成插入,查找字符串。其中 \(L\) 为字符串长,\(\Sigma\) 为字符集

int trie[N][26], tot;
int tag[N];
void insert() {
    int n = str.size();
    int u = 0;
    for (int i = 0; i < n; i++) {
        int x = str[i] - 'a';
        if (!trie[u][x]) trie[u][x] = ++tot; 
        u = trie[u][x];
    }
    tag[u] = 1;
}
void query() {
    int n = str.size();
    int u = 0;
    for (int i = 0; i < n; i++) {
        int x = str[i] - 'a';
        if (!trie[u][x]) {
            break;
        }    
        u = trie[u][x];
    }
    if (tag[u] == 0) {
        cout << "WRONG\n";
    }
    else if (tag[u] == 1) {
        cout << "OK\n"; 
        tag[u] = 2;
    }
    else {
        cout << "REPEAT\n";
    }
}

trie 维护异或极值

问题:最长异或路径

给定一棵 \(n\) 个点的无向带权树,结点下标从 \(1\) 开始到 \(n\)。寻找树中找两个结点,求最长的异或路径。边权 \(w\)

\(w < 2^{31}\)

解析:记 \(d(u, v)\)\(u\to v\) 的异或路径值。

我们选择一个点为根,不妨为 \(1\)

那么我们有这个重要结论: \(d(1,u) \oplus d(1, v)=d(u,v)\) 这是由于 \(a\oplus a=0\)\(a\oplus 0=a\)\(u, v\) 的 lca 部分被抵消了。

那么问题就转化成了 枚举每个 \(d(1, u)\) 找到其对应的 \(d(1, v)\) 使得 \(d(1,u) \oplus d(1, v)\) 最大,更新答案。

如何快速做到这一点呢?我们可以注意到一个性质:对每个数补全到 \(31\) 位后,如果 \(u\)\(v\) 的最高位不同,但 \(u\)\(w\) 的最高位相同,那么 \(u\oplus v\) 一定大于 \(u\oplus w\)

基于这一点,我们可以将每个 \(d(1, u)\) 都将其视作一个 \(31\) 位的二进制数插入到 trie 中。枚举 \(d(1, u)\) 时从最高位开始采用贪心,尽可能地向不同于当前这一位的子树走,最后的结果就是最大的。

时空复杂度同上。这里 \(\Sigma\)\(2\)\(L\)\(31\)

代码:

#include<bits/stdc++.h>
using namespace std;
using pii = pair<int, int>;
#define fi first
#define se second
const int N = 1e5 + 10, M = 31;
int trie[N * M][2], tot;
int dis[N];
int n, ans;
vector<pii> g[N];
void insert(int x) {
    for (int i = 30, u = 0; i >= 0; i--) {
        int c = (x >> i) & 1; // i-th bit
        if (!trie[u][c]) trie[u][c] = ++tot;
        u = trie[u][c];   
    }
}
void update(int x) {
    int res = 0;
    for (int i = 30, u = 0; i >= 0; i--) {
        int c = (x >> i) & 1;
        if (trie[u][!c]) {
            u = trie[u][!c];
            res += (1 << i);
        }
        else u = trie[u][c];
    }
    ans = max(ans, res);
}
void dfs(int u, int fa) {
    insert(dis[u]);
    update(dis[u]);
    for (auto x : g[u]) {
        int v = x.fi, w = x.se;
        if (v == fa) continue;
        dis[v] = dis[u] ^ w;
        dfs(v, u);
    }
}
int main() {
    cin.tie(0)->sync_with_stdio(0);
    cin >> n;
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        cin >> u >> v >> w;
        auto add = [&] (int u, int v, int w) {
            g[u].push_back({v, w});
        };
        add(u, v, w);
        add(v, u, w);
    }
    dfs(1, 0);
    cout << ans << "\n";
    return 0;
}