KMP相关模板

发布时间 2023-03-28 14:01:08 作者: Galetx

KMP

洛谷P3375

#include <bits/stdc++.h>
#define int long long
using namespace std;

// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; only a '-' immediately
// preceding the first digit makes the result negative (any other skipped
// character resets the sign). If end of input is reached before any digit,
// returns 0 instead of looping forever.
int read() {
	int s = 0, f = 1;
	// Keep getchar()'s result in an int so EOF stays distinguishable from
	// every real byte; a plain char can never be compared to EOF reliably.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		f = (ch == '-' ? -1 : 1);
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * f;
}

int n;                 // length of the text s1
int m;                 // length of the pattern s2
int tot;               // number of characters placed in s (static storage starts at 0)
char s1[1000005];      // text, stored from index 1
char s2[1000005];      // pattern, stored from index 1
char s[2000005];       // s2, then '#', then s1, stored from index 1
int nxt[2000005];      // nxt[i]: longest proper border length of s[1..i]

// Luogu P3375. Reads text s1 and pattern s2 (both stored from index 1).
// Prints every 1-based start position of s2 in s1, one per line. Then prints
// the border lengths nxt[1..m] of s2, separated by spaces.
//
// Approach: run the prefix function over s = s2 + '#' + s1. The separator
// keeps every border no longer than m. So a border of exactly m ending at
// index i of s means s2 ends at index i - (m + 1) of s1, i.e. starts at
// i - 2m.
// NOTE(review): this is only correct if '#' never occurs in the input.
signed main() {
	scanf("%s\n%s", s1 + 1, s2 + 1);
	n = strlen(s1 + 1);
	m = strlen(s2 + 1);

	// Lay out s[1..tot] = s2, '#', s1.
	for (int k = 1; k <= m; ++k)
		s[k] = s2[k];
	s[m + 1] = '#';
	for (int k = 1; k <= n; ++k)
		s[m + 1 + k] = s1[k];
	tot = m + 1 + n;

	// Prefix function. nxt[1] stays 0 from static initialisation.
	int len = 0;
	for (int i = 2; i <= tot; ++i) {
		while (len > 0 && s[len + 1] != s[i])
			len = nxt[len];
		if (s[len + 1] == s[i])
			++len;
		nxt[i] = len;
	}

	// A full match can only end inside the s1 part, at index 2m + 1 or later.
	for (int i = 2 * m + 1; i <= tot; ++i) {
		if (nxt[i] != m)
			continue;
		printf("%lld\n", i - 2 * m);
	}
	for (int k = 1; k <= m; ++k)
		printf("%lld ", nxt[k]);
	return 0;
}

失配树

洛谷P5829

#include <bits/stdc++.h>
#define int long long
using namespace std;

// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; only a '-' immediately
// preceding the first digit makes the result negative (any other skipped
// character resets the sign). If end of input is reached before any digit,
// returns 0 instead of looping forever.
int read() {
	int s = 0, f = 1;
	// Keep getchar()'s result in an int so EOF stays distinguishable from
	// every real byte; a plain char can never be compared to EOF reliably.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		f = (ch == '-' ? -1 : 1);
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * f;
}

// Capacity limits. N bounds the string length and the failure-tree node ids
// 0..n. M bounds the number of tree edges (one per prefix).
constexpr int N = 1000005;
constexpr int M = 1000005;
// Binary-lifting levels 0..LOG-1. update/mov/lca never use a level above 20,
// so 21 levels suffice (the previous 30 left 9 levels of an[][] unused).
constexpr int LOG = 21;
// Any node depth is at most n + 1 < N, so 2^(LOG-1) > N guarantees that
// update's "pw[i] <= depth" loop stops below index LOG.
static_assert((1LL << (LOG - 1)) > N, "LOG levels must cover every depth < N");

int n;          // length of s
char s[N];      // input string, stored from index 1
int nxts[N];    // nxts[i]: longest proper border length of s[1..i]

// Failure tree as a forward-star adjacency list (edge ids start at 1; 0 ends a list).
int to[M], nxt[M], head[N], tot = 0;
// dep[x]: depth of node x (the root 0 gets depth 1).
// an[x][i]: 2^i-th ancestor of x.
// Static storage starts all of these at zero.
int dep[N], an[N][LOG];
int pw[30];     // pw[i] = 2^i, filled by main for i = 0..20

// Appends the directed edge u -> v to the forward-star adjacency list.
// Edge ids start at 1, so 0 marks the end of a list. head[u] is u's newest
// edge, to[e] is the edge's target, and nxt[e] is the next edge out of u.
void add(int u, int v){
	const int e = ++tot;
	to[e] = v;
	nxt[e] = head[u];
	head[u] = e;
}

// Fills dep[] and the binary-lifting table an[][] for every node in the
// subtree rooted at p, where fa is p's parent. The root is passed as its own
// parent, as in update(0, 0), so it gets dep = 0 + 1 = 1 on a zeroed dep[].
// an[x][0] is x's parent, and an[x][i] = an[an[x][i-1]][i-1] is its 2^i-th
// ancestor. Levels with pw[i] > dep[x] are not written.
//
// Uses an explicit stack instead of recursion. The failure tree of a string
// like "aaaa...a" is a single chain of depth n + 1, and a recursive DFS that
// deep overflows an ordinary call stack.
void update(int p, int fa){
	std::vector<std::pair<int, int> > stk;  // pending (node, parent) pairs
	stk.push_back(std::make_pair(p, fa));
	while (!stk.empty()) {
		int u = stk.back().first, par = stk.back().second;
		stk.pop_back();
		dep[u] = dep[par] + 1, an[u][0] = par;
		// Level 20 is the highest one: main only fills pw[0..20], and
		// mov/lca never read above it.
		for (int i = 1; i <= 20 && pw[i] <= dep[u]; i++)
			an[u][i] = an[an[u][i - 1]][i - 1];
		// Children are pushed only after u is finished, so every ancestor's
		// entries are final by the time a child is processed.
		for (int e = head[u]; e; e = nxt[e])
			if (to[e] != par)
				stk.push_back(std::make_pair(to[e], u));
	}
}

int mov(int p, int d){
	int ans = p;
	for (int j = 20; j >= 0; j--)
		if (pw[j] <= d)
			ans = an[ans][j], d -= pw[j];
	return ans;
}

// Lowest common ancestor of u and v, using binary lifting over an[][] and dep[].
int lca(int u, int v){
	// Lift the deeper node until both are at the same depth.
	if (dep[u] > dep[v])
		u = mov(u, dep[u] - dep[v]);
	else
		v = mov(v, dep[v] - dep[u]);
	if (u != v) {
		// Climb both nodes while their 2^k-th ancestors still differ. They end
		// up as distinct children of the LCA.
		for (int k = 20; k >= 0; k--) {
			if (an[u][k] == an[v][k])
				continue;
			u = an[u][k];
			v = an[v][k];
		}
		u = an[u][0];
	}
	return u;
}

signed main() {
	scanf("%s", s + 1);
	n = strlen(s + 1);
	add(0, 1);
	for (int i = 2, j = 0; i <= n; i++) {
		while (j && s[j + 1] != s[i])
			j = nxts[j];
		nxts[i] = j += (s[j + 1] == s[i]);
		add(j, i);
	}
	pw[0] = 1;
	for (int i = 1; i <= 20; i++)
		pw[i] = pw[i - 1] << 1;
	update(0, 0);
	int m = read();
	while (m--) {
		int p = read(), q = read(), l = lca(p, q);
		printf("%lld\n", (l == p || l == q) ? an[l][0] : l);
	}
	return 0;
}