KMP字符串匹配算法 整理

发布时间 2023-12-02 22:13:28作者: 2020luke

KMP 整理

题面
视频详解

KMP 的作用

KMP 算法的主要作用是求出一个字符串(模式串)是否为另一个字符串(主串)的子串,并同时求出它出现的位置,也即字符串匹配问题。

算法解析

暴力

先说暴力算法:

  • 从头开始枚举模式串位置的起点,然后遍历从起点往后 \(m\) 个字符,检查它是否与模式串完全相同。

  • 复杂度为 \(O(nm)\)\(O(n)\) 遍历起点,对于每个起点用 \(O(m)\) 检查是否与模式串相同

我们发现时间复杂度奇高,于是便有三位 dalao (Knuth,Morris,Pratt) 共同发明了 KMP 算法,他的核心思想就是利用已知的数据,利用好,不往回头走,从而做到 \(O(n+m)\) 的线性复杂度。

KMP

KMP 算法思路:

我们设当前比对主串(str1)的下标是 \(i\),模式串(str2)的下标是 \(j\)。当 \(j>m\) 时则字符串匹配;若 \(str1[i]==str2[j]\) 则当前字符匹配,\(i\)\(j\) 同时 ++,开始比较下一个字符 ;若 \(str1[i]!=str2[j]\) ,则说明当前字符不匹配,这时我们跳过一部分字符,\(i\) 不变,继续比较。

举个栗子:

第一次匹配
模式串: ABABC
主串:  ABABABCAB
这时我们发现模式串的C与主串的A不匹配,但是由于我们已经知道了主串的前5个字符,发现模式串的[1,2]串和[3,4]串相同,所以我们直接将模式串跳过前面的[1,2]串AB再进行比较。

第二次匹配
模式串:   ABABC
主串:  ABABABCAB
这时我们发现模式串和主串匹配了,于是说明模式串为主串的字串。

那我们如何知道该跳过夺少字符呢?这时候就要用到 \(next\) 数组了。

next 数组

\(next\) 数组的作用是当我们发现模式串中的第 \(j\) 个字符与主串的第 \(i\) 个字符不匹配时,我们直接跳过前 \(next[j - 1]\) 个字符,\(i\) 不变,这样就可以非常有效地降低时间复杂度。

那我们如何知道 \(next\) 数组的值呢?\(next[i]\) 的本质其实是前 \(i\) 个字符中前缀和后缀完全相同的情况下前后缀的长度最长是多少。举个栗子:

模式串:   A B A B C
next数组: 0 0 1 2 0

\(next\) 数组的前两位为 \(0\) 原因是都没有相同的前后缀而且 \(next[i]\) 不能为 \(i\) ,原因是不能跳过整个子串,不然就没有跳过的意义了。

\(next\) 数组的值我们改如何计算呢?我们用递推的方法解决这个问题。首先我们定义一个 \(i\)\(j\) ,分别代表当前要修改哪个 \(next\) 的值和 \(next[i - 1]\) 的值。首先 \(next[1]\) 肯定等于 \(0\) ,原因上面说了是不能跳过整个子串。所以我们从 i = 2, j = 0 开始枚举。由于我们要用递推的算法来计算,于是我们比较 str[i]str[j + 1] ,若两者相等,则说明当前最长前后缀其实是上一个最长前后缀个添加一个相同的字符,此时前后缀依然相等,于是我们执行 j++; next[i] = j; ;而如果两者不同,我们就反复执行 j = next[j]; ,直到 str[i]==str[j + 1] 或者是 j <= 0 ,因为此时我们就已经找到了最长相同前后缀或者当前字串根本不存在最长相同前后缀。

实现流程:

(1)设 \(j\)\(next[i - 1]\) ,比较 str[i]str[j + 1],若相等,那么执行 j++; next[i] = j;

(2)若不相等,\(j\) 需要变得更小,以满足找到最长相同前后缀的最大值,即重复执行j = next[j]; 直到 str[i]==str[j + 1]j <= 0 为止。

Code:

#include<bits/stdc++.h>
using namespace std;
string str1, str2;
const int N = 1e6 + 10;
int nxt[N];
int main() {
	cin >> str1 >> str2;
	for(int i = str1.size(); i >= 1; i--) {
		str1[i] = str1[i - 1];
	}
	for(int i = str2.size(); i >= 1; i--) {
		str2[i] = str2[i - 1];
	}
	for(int i = 2, j = 0; i <= str2.size(); i++) {
		while(j && str2[j + 1] != str2[i]) {
			j = nxt[j];
		}
		if(str2[j + 1] == str2[i]) j++;
		nxt[i] = j;
	}
	for(int i = 1, j = 0; i <= str1.size(); i++) {
		while(j && str2[j + 1] != str1[i]) {
			j = nxt[j];
		}
		if(str2[j + 1] == str1[i]) j++;
		if(j >= str2.size()) {
			cout << i - j + 1 << "\n";
			j = nxt[j]; 
		}
	}
	for(int i = 1; i <= str2.size(); i++) cout << nxt[i] << " ";
	return 0;
}