KMP学习笔记(再回首)+ AC自动机学习笔记

发布时间 2023-06-07 11:14:00作者: frostwood

一.KMP

引入

我们经常遇到字符串匹配问题。比如求一个长为 \(m\) 的串 \(a\) 在长度为 \(n\) 的串 \(b\) 中是否出现,或求出现多少次,等等。我们很容易想到 \(n*m\) 的做法,就是以每一位为起点,一直向后匹配,直到失配或匹配成功。显然,这样的复杂度是无法接受的。

KMP

那么,有没有一种可以快速匹配的算法呢?这里我们介绍KMP。首先,我们来明确一些概念。我们把前面的 \(a\) 叫做模式串,\(b\) 叫做主串(AC自动机中也会出现)。

我们发现,在前文的暴力算法中,我们很可能匹配上主串的很长一部分然后失配,这样的话就浪费了一段匹配,因为这个匹配中很可能还存在可以作为模式串匹配起点的子串。我们又发现,这个可以作为起点的串满足它既是 \(a\) 的前缀,也是 \(a\) 的后缀。这样的最长子串,我们称之为 \(border\) 。于是乎,我们可以预处理出来在 \(a\) 的每个前缀中上所对应 \(border\),然后在匹配主串时,如果失配,先跳 \(border\),直到匹配上或者完全失配。

特别地,对于每一个 \(border​\),从第一个元素开始,不包括最后一个元素结束,和从最后一个元素开始,不包括第一元素

至于 \(border\),我们可以用一个 \(next\) 数组来记录。这个数组记录的既是 \(border\) 的长度,也是其在 \(a\) 的前缀上的结束位置。

处理 \(border\) 代码(这也大概是这个算法的核心部分):

void KMP(){
    int j = 0;
    for(int i = 2; i<=n; i++){
        while(j && a[j+1]!=a[i]){
            j = nxt[j];
        }
        if(a[j+1]==a[i]){
            j++;
        }
        nxt[i] = j;
    }
}

至于匹配,和处理 \(border\) 类似。每次失配往回跳即可。当 \(j=n​\) 时,表示匹配成功。

void work(){
   int j = 0;
   for(int i = 1; i<=m; i++){
       while(j && a[j+1]!=b[i]){
           j = nxt[j];
       }
       if(a[j+1] == b[i]){
           j++;
       }
       if(j == n){
           printf("%d\n", i-n+1);
       }
   }
   for(int i = 1; i<=n; i++){
       printf("%d ", nxt[i]);
   }
}

例题:

洛谷P3375(模板题)

我才不说就是把上面两个代码整一块儿就行了

#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+100;

int n, m;
int nxt[N];
char a[N];
char b[N];
void KMP(){
    int j = 0;
    for(int i = 2; i<=n; i++){
        while(j && a[j+1]!=a[i]){
            j = nxt[j];
        }
        if(a[j+1]==a[i]){
            j++;
        }
        nxt[i] = j;
    }
}
void work(){
    int j = 0;
    for(int i = 1; i<=m; i++){
        while(j && a[j+1]!=b[i]){
            j = nxt[j];
        }
        if(a[j+1] == b[i]){
            j++;
        }
        if(j == n){
            printf("%d\n", i-n+1);
        }
    }
    for(int i = 1; i<=n; i++){
        printf("%d ", nxt[i]);
    }
}
int main(){
    scanf("%s%s", b+1, a+1);
    n = strlen(a+1);
    m = strlen(b+1);
    KMP();
    work();
    return 0;
}

洛谷P2375 动物园

这个题要求找出长度不超过当前串一半的 \(border\) 数量。我们可以先跑一遍 KMP,并对每一个点记录一个 \(cnt\),表示这个点被当作 \(border\) 的次数。然后再跑一遍,这次在定位 \(j\) 后,再把 \(j\) 缩短到小于串长一半,统计答案即可。

#include<bits/stdc++.h>
using namespace std;
const int N = 1e6+1000, mod = 1e9+7;

int q, n;
char s[N];
int nxt[N];
long long cnt[N];
long long ans = 1;
long long KMP()
{
	ans = 1;
	int pos = 0;
	nxt[1] = 0;
	cnt[1] = 1;
	for(int i = 2; i<=n; i++)
	{
		while(pos&&s[i]!=s[pos+1]) pos = nxt[pos];
		if(s[i]==s[pos+1]) pos++;
		nxt[i] = pos;
		cnt[i] = cnt[pos]+1;//因为每一个border也可能有小的border,故数量是可以累加的。
	}//第一遍
	pos = 0;
	for(int i = 2; i<=n; i++)
	{
		while(pos&&s[i]!=s[pos+1]) pos = nxt[pos];
		if(s[i]==s[pos+1]) pos++;
		while(pos&&pos*2>i) pos = nxt[pos];
		ans=ans*(cnt[pos]+1)%mod;//cnt表示在pos点上的border总数。
	}
	return ans;
}
int main()
{
    scanf("%d", &q);
    while(q--)
	{
		scanf("%s", s+1);
		n = strlen(s+1);
		long long res = KMP();
		printf("%lld\n", res);
	}	
	return 0;
}

CF149E Martian Strings

题意:给定一个主串 \(s\) 和一些模式串 \(p_i\),问主串中是否存在两个不相交的非空子串,拼起来和模式串相同。

考虑如何拼接 \(p_i\)。我们可以从前向后匹配一遍主串,找到 \(p_i\) 的所有长度的前缀在主串中最先出现的位置,并记录下来;然后再从后向前跑匹配,每次匹配上一个后缀,就判断该后缀在主串中的起始位置是否大于这一后缀对应前缀的结束位置,如果是则答案加一。注意判断边界条件。匹配这里用的是KMP算法。

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+100;

char b[N], s[N];
int n, m, lth;
int fnxt[N], bnxt[N], posl[1005], posr[1005];
bool KMP(){
    int j = 0;
    for(int i = 2; i<=lth; i++){
        while(j && b[j+1]!=b[i]){
            j = fnxt[j];
        }
        if(b[j+1]==b[i]){
            j++;
        }
        fnxt[i] = j;
    }
    j = lth+1;bnxt[lth] = lth+1;
    for(int i = lth-1; i>=1; i--){
        while(j<=lth&&b[j-1]!=b[i]){
            j = bnxt[j];
        }
        if(b[j-1]==b[i]){//只记录第一次出现的结束位置
            j--;
        }
        bnxt[i] = j;
    }
    memset(posl, 0, sizeof(posl));
    j = 0;
    for(int i = 1; i<=n; i++){
        while(j && s[i]!=b[j+1]){
            j = fnxt[j];
        }
        if(s[i]==b[j+1]){
            j++;
        }
        if(!posl[j]&&j){
            posl[j] = i;
        }
    }
    j = lth+1;
    
    for(int i = n; i>=1; i--){
        while(j<=lth&&b[j-1]!=s[i]){
            j = bnxt[j];
        }
        if(s[i]==b[j-1]){
            j--;
        }
        if((j-1)&&posl[j-1]&&posl[j-1]<i&&(j<=lth)){//注意前后缀都要非空
            return true;
        }
    }
    return false;
}
int ans;
int main(){
    scanf("%s", s+1);
    n = strlen(s+1);
    scanf("%d", &m);
    while(m--){
        scanf("%s", b+1);
        lth = strlen(b+1);
        if(KMP()){
            ans++;
        }
    }
    printf("%d\n", ans); 
    return 0;
}

二.AC自动机

引入/前置知识

我们通过KMP可以做到主串匹配单模式串。那么,当模式串多起来后,有什么解决方案呢?

答案是KMP上树。没错,AC自动机就是Trie树和KMP的结合。所以在学习之前,需要先学习KMP和Trie树。

AC自动机

首先我们引入一个概念:失配指针 \(fail\)。这个指针的作用相当于KMP中的 \(next\) 数组。某种意义上,它就是。当然,这里的 \(fail\) 指针指向的是 所有模式串的前缀中匹配当前状态的最长后缀。举个栗子:对于Trie树上的一个字符串路径abcdef,如果你的Trie树上还有一条路径是def,那你的 \(fail\) 就应指向这个路径的终点。这样构建,是考虑到许多模式串会有相同后缀,但是前缀可能不同的情况。

那么怎么去构建 \(fail\) 数组呢?首先我们来讨论最简单的方式。

和KMP类似,就是在Trie树上不断跳 \(fail​\)。但这样的话,我们会发现,有的 \(fail​\) 是叶子节点,这样会导致跳的时候要不断跳很多次。于是,我们有了优化——构建Trie图。

我们可以采用类似并查集路径压缩的方式,将所有节点的空儿子,如 \(tr[u][v]​\) 直接指向 \(tr[fail[u]][v]​\),这样可以让跳多次变为只跳一次。

Trie树的构建:

struct Trie{
    int cnt;
    int son[26];
}tr[N];
int idx;

void insert(char s[]){
    int lth = strlen(s);
    int u = 0, v;
    for(int i = 0; i<lth; i++){
        v = s[i]-'a';
        if(!tr[u].son[v]){
            tr[u].son[v] = ++idx;
        }
        u = tr[u].son[v];
    }
    tr[u].cnt++;
}

\(fail\) 指针的构建+Trie图的构建(bfs):

int fail[N];
queue<int> q;
void build(){
    for(int i = 0; i<26; i++){
        if(tr[0].son[i]) q.push(tr[0].son[i]);
    }
    while(q.size()){
        int u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(tr[u].son[i]){
                fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]);
            }
            else{
                tr[u].son[i] = tr[fail[u]].son[i];
            }
        }
    }
}

至于查询,每新增一个字符,都要把整个Trie跳一遍。这里要统计出现的模式串数量,故每次需要清空。

查询:

int query(char t[]){
    int u = 0, ret = 0,v, lth = strlen(t);
    for(int i = 0; i<lth; i++){
        int v = t[i]-'a';
        u = tr[u].son[v];
        for(int j = u; j && tr[j].cnt!=-1; j = fail[j]){//遍历过的模式串没必要再跳
            ret+=tr[j].cnt;
            tr[j].cnt = -1;
        }
    }
    return ret;
}

例题

洛谷P3808 模板1

将以上模板套用即可。

#include<bits/stdc++.h>
using namespace std;
const int N = 8e6+100;

struct Trie{
    int cnt;
    int son[26];
}tr[N];
int idx;

void insert(char s[]){
    int lth = strlen(s);
    int u = 0, v;
    for(int i = 0; i<lth; i++){
        v = s[i]-'a';
        if(!tr[u].son[v]){
            tr[u].son[v] = ++idx;
        }
        u = tr[u].son[v];
    }
    tr[u].cnt++;
}
int fail[N];
queue<int> q;
void build(){
    for(int i = 0; i<26; i++){
        if(tr[0].son[i]) q.push(tr[0].son[i]);
    }
    while(q.size()){
        int u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(tr[u].son[i]){
                fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]);
            }
            else{
                tr[u].son[i] = tr[fail[u]].son[i];
            }
        }
    }
}
int query(char t[]){
    int u = 0, ret = 0, v, lth = strlen(t);
    for(int i = 0; i<lth; i++){
        v = t[i]-'a';
        u = tr[u].son[v];
        for(int j = u; j && tr[j].cnt!=-1; j = fail[j]){
            ret+=tr[j].cnt;
            tr[j].cnt = -1;
        }
    }
    return ret;
}
int n, lth;
char s[1000010];
int main(){
    scanf("%d", &n);
    for(int i = 1; i<=n; i++){
        scanf("%s", s);
        insert(s);
    }
    build();
    scanf("%s", s);
    printf("%d\n", query(s));
    return 0;
}

洛谷P3796 模板2(加强版)

这次是让你统计出现次数了。发现模式串很少,又发现没有相同的模式串(我一开始还傻傻地开了vector),直接开个桶记录一下次数,最后暴力扫即可。

#include<bits/stdc++.h>
using namespace std;
const int N = 12000;

int n;
struct Trie{
    int son[26];
    int have;
}tr[N];
int idx;
int fail[N], cnt[160];
void init(){
    idx = 0;
    memset(fail, 0, sizeof(fail));
    memset(tr, 0, sizeof(tr));
    memset(cnt, 0, sizeof(cnt));
}
void insert(char s[], int id){
    int lth = strlen(s);
    int u = 0, v;
    for(int i = 0; i<lth; i++){
        v = s[i]-'a';
        if(!tr[u].son[v]){
            tr[u].son[v] = ++idx;
        }
        u = tr[u].son[v];
    }
    tr[u].have = id;
}
queue<int> q;
void build(){
    for(int i = 0; i<26; i++){
        if(tr[0].son[i]){
            q.push(tr[0].son[i]);
        }
    }
    int u;
    while(!q.empty()){
        u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(tr[u].son[i]){
                fail[tr[u].son[i]] = tr[fail[u]].son[i], q.push(tr[u].son[i]);
            }
            else{
                tr[u].son[i] = tr[fail[u]].son[i];
            }
        }
    }
}
void query(char t[]){
    int lth = strlen(t), u = 0, v;
    for(int i = 0; i<lth; i++){
        v = t[i]-'a';
        u = tr[u].son[v];
        for(int j = u; j; j = fail[j]){
            if(tr[j].have){
                cnt[tr[j].have]++;
            }
        }
    }
}
char tmp[160][80], t[1000050];
int mx;
int main(){
    scanf("%d", &n);
    while(n){
        init();
        for(int i = 1; i<=n; i++){
            scanf("%s", tmp[i]);
            insert(tmp[i], i);
        }
        build();
        scanf("%s", t);
        query(t);
        mx = 0;
        for(int i = 1; i<=n; i++){
            mx = max(cnt[i], mx);
        }
        printf("%d\n", mx);
        for(int i = 1; i<=n; i++){
            if(cnt[i] == mx){
                printf("%s\n", tmp[i]);
            }
        }
        scanf("%d", &n);
    }
    return 0;
}

洛谷P5357 模板3(二次加强)

乍眼一看,这题和上一道题不一样吗?虽然有重复串,但完全可以直接记录一个串,再让其他相同串指向这个串即可。于是乎——

76pts

代码还是放一下毕竟写半天不容易

#include<bits/stdc++.h>
using namespace std;
const int N = 2e6+100;

int tr[N][26], e[N], idx, he[N];
void insert(char s[], int x){
    int lth = strlen(s), u = 0, v;
    for(int i = 0; i<lth; i++){
        int v = s[i]-'a';
        if(!tr[u][v]){
            tr[u][v] = ++idx;
        }
        u = tr[u][v];
    }
    if(!e[u]){
        e[u] = x;
    }
    else{
        he[x] = e[u];
    }
}
int fail[N];
queue<int> q;
void build(){
    for(int i = 0; i<26; i++){
        if(tr[0][i]){
            q.push(tr[0][i]);
        }
    }
    int u;
    while(!q.empty()){
        u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(tr[u][i]){
                fail[tr[u][i]] = tr[fail[u]][i];
                q.push(tr[u][i]);
            }
            else{
                tr[u][i] = tr[fail[u]][i];
            }
        }
    }
}
int cnt[N];
void query(char t[]){
    int lth = strlen(t);
    int u = 0, v;
    for(int i = 0; i<lth; i++){
        v = t[i]-'a';
        u = tr[u][v];
        for(int j = u; j; j = fail[j]){
            if(e[j]){
                cnt[e[j]]++;
            }
        }
    }
}
int n;
char s[N];
int main(){
    scanf("%d", &n);
    for(int i = 1; i<=n; i++){
        scanf("%s", s);
        insert(s, i);
    }
    scanf("%s", s);
    build();
    query(s);
    for(int i = 1; i<=n; i++){
        if(he[i]){
            printf("%d\n", cnt[he[i]]);
        }
        else{
            printf("%d\n", cnt[i]);
        }
    }
    return 0;
}//只有76pts qwq

让我们来分析一下为什么:因为对于每个点我们都要完整地跳一遍 \(fail\),那么最坏情况就是每次深度只减小 \(1\)。这样算下来,我们的最坏复杂度就是 O(模式串长 \(\cdot\) 主串长)。那为什么模板 1 的复杂度是对的呢?因为模板 1 中,每个Trie上的点只会经过一次(因为只需要统计出现的串数,统计过就不用再统计了,即将 \(cnt\) 赋为 \(-1\));而在这个题的程序中,每个点会经过不止一次,所以时间复杂度就爆炸了(

那可不可以让每个点只经过一次呢?答案是可以的(为什么我一开始想到了给否了qwq)。

答案是拓扑排序。

其实一开始我就在想,既然有的模板串是包含在另一些模板串中,那我们是不是只需要标记一个模板串,然后向上回溯,做一个树形dp就行。然鹅我觉得不好实现,因为 \(fail\) 数组乱跳的话没有边界,而且叶子节点也不好确定。于是,我们就有了拓扑排序。

很明显,\(fail\) 数组只会往上跳,换句话说,就是深度只会减小,那这样的话,其实仅仅 \(fail\) 指针就可以看作一个 DAG(连图都不用建了)。

这样的话,我们每次只需要在一个节点上修改权值。因为这个节点的权值会贡献给它所有的子串,而它所对应的字符串的所有的子串一定是这个串某一部分的后缀,所以一定是能通过跳 \(fail\) 跳到的。

代码:

#include<bits/stdc++.h>
using namespace std;
const int N = 2e6+100;

int tr[N][26], e[N], idx, he[N];
int inde[N];
bool vis[N];
void insert(char s[], int x){
    int lth = strlen(s), u = 0, v;
    for(int i = 0; i<lth; i++){
        int v = s[i]-'a';
        if(!tr[u][v]){
            tr[u][v] = ++idx;
        }
        u = tr[u][v];
    }
    e[x] = u;
}
int fail[N];
queue<int> q;
void build(){
    for(int i = 0; i<26; i++){
        if(tr[0][i]){
            q.push(tr[0][i]);
        }
    }
    int u;
    while(!q.empty()){
        u = q.front();
        q.pop();
        for(int i = 0; i<26; i++){
            if(tr[u][i]){
                fail[tr[u][i]] = tr[fail[u]][i];
                inde[fail[tr[u][i]]]++;
                q.push(tr[u][i]);
            }
            else{
                tr[u][i] = tr[fail[u]][i];
            }
        }
    }
}
int cnt[N];
void query(char t[]){
    int lth = strlen(t);
    int u = 0, v;
    for(int i = 0; i<lth; i++){
        v = t[i]-'a';
        u = tr[u][v];
        cnt[u]++;
    }
    for(int i = 1; i<=idx; i++){
        if(!inde[i]){
            q.push(i);
        }
    }
    while(!q.empty()){
        int u = q.front();
        q.pop();
        int v = fail[u];
        inde[v]--;
        cnt[v]+=cnt[u];
        if(!inde[v]){
            q.push(fail[u]);
        }
    }
}
int n;
char s[N];
int main(){
    scanf("%d", &n);
    for(int i = 1; i<=n; i++){
        scanf("%s", s);
        insert(s, i);
    }
    scanf("%s", s);
    build();
    query(s);
    for(int i = 1; i<=n; i++){
        printf("%d\n", cnt[e[i]]);
    }
    return 0;
}