AC自动机

发布时间 2023-09-27 14:36:51作者: _ZRJ

写在前面

本篇代码来源于yyb大佬的博客(指路
加上了一些自己的理解,重写了代码注释,可能算转载plus罢。


代码思路

说到AC自动机,总会提起这个老生常谈的前置知识:Trie+KMP
事实上,它的代码也几乎就是这两者的组合形式。

主体部分:

建Trie树,求失配指针,查询
(即build,get_fail,query三个函数)

  1. 建Trie树:单纯的模板。

  2. 失配指针:祖先失配节点的同字符子节点。
    (简单来说是↑,不过严格来讲,失配指针应当指向“沿着其父节点 的 失配指针,一直向上,直到找到拥有当前这个字母的子节点 的节点 的那个子节点”)

  3. 查询:扫一遍文本串,维护Trie树上的指针,进行统计(见注释)

简化部分:

自定义数据类型
(struct Tree)

注意事项:

输入string类型用cin,方便又快捷
字典树数组开26,则下标从0开始


原理(极简)

Trie+KMP=AC自动机


代码注释

洛谷P3808为例。

#include <bits/stdc++.h>
using namespace std;
//建Trie树,求失配指针,查询
const int N = 1e6+5; // 数据范围
int n, cnt = 0; // cnt:Trie树大小计数
struct Tree{
    int fail; // 失配指针
    int vis[26]; // 字典树每个节点(至多)有26个儿子
    int end; // 是多少个字符串的末字符
    bool fid; // 用于在查询时判重
}t[N];

void build(string s){ // 建Trie树,这里的s是模式串
    int l = s.length(); // 模式串长度
    int now = 0; // Trie树特有的跳来跳去下标
    for(int i = 0; i < l; i++){ // 从零开始的异世界字符串
        if(!t[now].vis[s[i]-'a']) // 如果还没有对应字符的子节点
            t[now].vis[s[i]-'a'] = ++cnt; // 新建一个节点
        now = t[now].vis[s[i]-'a']; // 下移一层
    }
    t[now].end++; // 累计字符串末位字符
}

int que[N]; // 队列,用于BFS
void get_fail(){ // 求失配指针
    int hd = 1, tl = 0; // 队首 队尾
    for(int i = 0; i < 26; ++i){ // 枚举树的第一层
        if(t[0].vis[i]){ // 如果存在字符'a'+i
            t[t[0].vis[i]].fail = 0; // 令其失配指针指向根节点
            que[++tl] = t[0].vis[i]; // 放入BFS队列
        }
    }
    while(hd<=tl){ // BFS
        int u = que[hd++]; // 从队首取一枚节点
        for(int i = 0; i < 26; ++i){ // 枚举该节点的儿子
            if(t[u].vis[i]){ // 如果存在一个'a'+i的子节点
                t[t[u].vis[i]].fail = t[t[u].fail].vis[i]; // 求子节点的失配指针
                que[++tl] = t[u].vis[i]; // 放入BFS队列
            }else{ // 不存在儿子则
                    t[u].vis[i] = t[t[u].fail].vis[i];
                                // 为了后续求失配指针,将儿子指向失配指针对应的儿子。
            }
        }
    }
}

int query(string s){ // 查询(这里的s是文本串)
    int l = s.length(); // 文本串长度
    int now = 0, ans = 0; // Trie树指针 统计答案
    // 
    for(int i = 0; i < l; i++){
        now = t[now].vis[s[i]-'a'];
        for(int j = now; j&&!t[j].fid; j = t[j].fail){
            ans += t[j].end;
            t[j].fid = 1;
        }
    }
    return ans;
}

int main(){
    scanf("%d", &n);
    string s;
    for(int i = 1; i <= n; i++){
        cin >> s;
        build(s);
    }
    t[0].fail = 0; get_fail();
    cin >> s;
    printf("%d\n", query(s));
}