KMP 算法

发布时间 2023-08-16 22:17:20作者: susenyang
  • KMP 算法

一个人能走的多远不在于他在顺境时能走的多快,而在于他在逆境时多久能找到曾经的自己。 —— KMP

例题

【模板】KMP 字符串匹配

原理

朴素算法的缺陷

设主串与模式串的长度分别为 \(m\)\(n\),那么完成一次匹配的最坏时间复杂度将是 \(O(mn)\)

匹配算法的改进

我们思考,当一次匹配失败时,如何利用失配的信息

假设在模式串的第 \(j\) 位与主串某一位比对时失配,说明模式串的前j-1位是与主串匹配的。我们希望找到最长的公共真前后缀,然后让模式串移动到这个后缀的起点进行下一次比较,从而避免重复比较。类似的,如果成功匹配,我们只需要将模式串移动到整个字符串的最长公共真前后缀的后缀起点。

这么做的原因是因为当模式串的前 \(j-1\) 位与主串匹配时该子串的前后缀显然与对应位置的主串匹配将下一次匹配的起点选在后缀的起点,可以省去后缀(下一次匹配的前缀)与子串匹配的步骤,只需要比较(下一次匹配的)前缀以后的部分与主串是否匹配。此外,取最长的公共真前后缀可以避免在移动中漏解。

所以我们预处理模式串,计算它每个前缀的最长真公共前后缀的长度,从而根据模式串与主串失配的位置查询到模式串下一次比对的位置

void KMP()
{
	//ne[j]表示模式串P[1, j]中公共真前后缀的最长长度
	m = strlen(S) - 1;
	for (int i = 1, j = 0; i <= m; i++)
	{
		while (j && S[i] != P[j + 1])j = ne[j];
		if (S[i] == P[j + 1])j++;
		if (j == n)cout << i - n + 1 << endl;
	}
	//输出模式串的前i位的最长公共真前后缀的长度
	for (int i = 1; i <= n; i++)cout << ne[i] << " ";
	cout << endl;
}

ne[N]数组的预处理

ne[j] 表示模式串 P[1, j] 中公共真前后缀的最长长度)

最朴素的想法当然是暴力做法,对每个 \(j\) 遍历公共真前后缀,找到最长的长度。显然这么做时间开销过大,所以我们也要对其优化。

首先,对于 P[1, j + 1],设它的最长公共真前后缀长度为 \(l + 1\),那么对于 P[1, j]它的前 \(l\) 位与后 \(l\) 位一定是匹配的不能保证是最长的,例如 "abaabab")。这就说明,要找 P[1, j + 1] 的最长公共真前后缀,我们可以找 P[1, j] 的所有公共真前后缀,然后判断各自加上下一位之后,两个字符串是否仍然相等

我们使用 \(i\) 扫描模式串,\(j\) 扫描前缀。在处理 ne[j] 的时候,我们考虑利用 ne[1]ne[2] ... ne[j - 1] 。设 \(k =\) ne[j - 1],这就说明模式串的前 \(k\) 位与第 \(j - k\) 位至第 \(j - 1\) 位是匹配的。由上述讨论,设 \(j\) 指向要判断前缀的前一位,\(i\) 指向 \(P\) 的第 \(i\) 位,我们只需要判断 P[i]P[j + 1] 是否相等,如果相等,就让 \(j\) 自增并赋值给 ne[i],否则让 \(j\) 跳到 ne[j] 上,直到满足P[i]P[j + 1] 相等或者 \(j=0\)即模式串的前 \(j\) 位不存在公共前后缀)。因为每一次 \(j\) 都会跳到当前 \(j\) 位前缀的最长公共真前后缀的前缀末尾,所以不会漏掉 P[1, j] 的前缀,也就不会漏解,又因为是从大到小遍历,所以我们找到的一定是 P[1, j + 1] 的最长公共真前后缀。

void init()
{
	P[0] = S[0] = ' ';
	cin >> S + 1 >> P + 1;
	n = strlen(P) - 1;
	//预处理模式串,ne[i]表示模式串P[1,i]中相等前后缀的最长长度、
	//i扫描模式串,j扫描前缀
	ne[1] = 0;
	for (int i = 2, j = 0; i <= n; i++)
	{
		while (j && P[i] != P[j + 1])j = ne[j];
		if (P[i] == P[j + 1])j++;
		ne[i] = j;
	}
}

代码

#include <iostream>
#include <cstring>
using namespace std;
const int N = 1000005;
char S[N], P[N];//分别表示主串与模式串(在主串中寻找模式串)
int m, n;//分别表示主串与模式串的长度
int ne[N];
void init()
{
	P[0] = S[0] = ' ';
	cin >> S + 1 >> P + 1;
	n = strlen(P) - 1;
	//预处理模式串,ne[i]表示模式串P[1,i]中相等前后缀的最长长度、
	//i扫描模式串,j扫描前缀
	ne[1] = 0;
	for (int i = 2, j = 0; i <= n; i++)
	{
		while (j && P[i] != P[j + 1])j = ne[j];
		if (P[i] == P[j + 1])j++;
		ne[i] = j;
	}
}
void KMP()
{
	//ne[j]表示模式串P[1, j]中公共真前后缀的最长长度
	m = strlen(S) - 1;
	for (int i = 1, j = 0; i <= m; i++)
	{
		while (j && S[i] != P[j + 1])j = ne[j];
		if (S[i] == P[j + 1])j++;
		if (j == n)cout << i - n + 1 << endl;
	}
	//输出模式串的前i位的最长公共真前后缀的长度
	for (int i = 1; i <= n; i++)cout << ne[i] << " ";
	cout << endl;
}
int main()
{
	init();
	KMP();
	return 0;
}