P9576 「TAOI-2」Ciallo~(∠・ω< )⌒★

发布时间 2024-01-11 09:01:24作者: cxqghzj

题意

求字符串 \(s\) 删去每个区间后字符串 \(t\) 出现的次数之和。

Sol

不难注意到答案分为两类:

  • 删去区间后,\(t\) 的一个非空前缀(位于保留的左半部分末尾)和一个非空后缀(位于保留的右半部分开头)恰好拼成 \(t\);
  • \(t\) 完整地出现在保留的左半部分之中,或完整地出现在保留的右半部分之中。

第二类明显是 \(trivial\) 的。

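对 \(t\) 在 \(s\) 中的每次完整出现 \([i, i + |T| - 1]\),删去的区间必须整个落在它左侧的 \([1, i - 1]\) 或右侧的 \([i + |T|, n]\) 内,因此贡献为 \(\binom{i}{2} + \binom{n - i - |T| + 2}{2}\),枚举 \(i\) 直接算就好。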

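仔细思考。不难发现可以预处理 \(f_i, g_i\):\(f_i\) 表示从 \(s_i\) 开始、与 \(t\) 的前缀匹配的最长长度,\(g_i\) 表示在 \(s_i\) 结束、与 \(t\) 的后缀匹配的最长长度。由于拼接的前后两段都必须非空,二者都要对 \(|T| - 1\) 取 \(\min\)。对于 \(j \ge i + |T|\),前缀段长度 \(a\) 可以取 \([|T| - g_j, f_i]\) 中的任意值(此时删去的区间为 \([i + a, j - |T| + a]\)),共 \(f_i + g_j - |T| + 1\) 种。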

答案即为:

\[\sum_{i = 1} ^ {n} \sum_{j = i + |T|} ^ {n} [f_i + g_j \ge |T|] (f_i + g_j - |T| + 1) \]

考虑扫描线快速计算。

从右向左枚举 \(i\),每次把 \(g_{i + |T|}\) 插入两个以 \(g\) 为下标的树状数组(一个维护个数,一个维护 \(g\) 之和),再查询满足 \(g_j \ge |T| - f_i\) 的 \(j\) 的个数与 \(g_j\) 之和即可。

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <cassert>
#define int long long
using namespace std;
#ifdef ONLINE_JUDGE

// Buffered replacement for getchar(), enabled only when ONLINE_JUDGE is
// defined. When the buffer is exhausted (p1 == p2) it refills `buf` with up
// to 1 << 21 bytes via fread; if nothing was read it yields EOF, otherwise it
// yields the next buffered byte and advances p1.
// NOTE(review): the byte is read through `char`, so where char is signed a
// 0xFF input byte would compare equal to EOF.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
// ubuf / u (presumably an output buffer and its cursor) are not used anywhere
// in the visible code.
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
// Reads the next decimal integer from stdin.
// Non-digit characters before the first digit are skipped; any '-' among them
// makes the result negative. Returns 0 if EOF is reached before a digit
// (previously the skip loop spun forever at EOF because EOF is never a digit).
int read() {
	int p = 0, flg = 1;
	int c = getchar();  // int, not char, so EOF stays distinguishable
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-') flg = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		p = p * 10 + c - '0';
		c = getchar();
	}
	return p * flg;
}
// Reads the next maximal run of lowercase letters 'a'..'z' from stdin,
// skipping any other characters before it. The character that ends the run
// is consumed. Returns an empty string if EOF is reached before any letter
// (previously the skip loop spun forever at EOF).
string read_() {
	string ans;
	int c = getchar();  // int, not char, so EOF stays distinguishable
	while (c != EOF && (c < 'a' || c > 'z'))
		c = getchar();
	while (c >= 'a' && c <= 'z') {
		ans += static_cast<char>(c);
		c = getchar();
	}
	return ans;
}
// Writes x in decimal to stdout, with a leading '-' if negative and no
// trailing newline. The magnitude is computed in unsigned arithmetic so the
// most negative value of the type prints correctly (negating it in the
// signed type, as before, is undefined behaviour).
void write(int x) {
	unsigned long long mag = x < 0 ? 0ULL - static_cast<unsigned long long>(x)
	                               : static_cast<unsigned long long>(x);
	if (x < 0) putchar('-');
	char digits[24];
	int len = 0;
	do {
		digits[len++] = static_cast<char>('0' + mag % 10);
		mag /= 10;
	} while (mag);
	while (len) putchar(digits[--len]);  // digits were produced least-significant first
}
const int N = 8e5 + 5, mod = 147744151;

// Polynomial rolling hash with base 131 modulo `mod`. Strings handled by
// getarray() are 1-indexed: index 0 holds a sentinel that is not hashed.
// Every stored prefix hash and every returned hash lies in [0, mod).
namespace Hash {

// Hash of all characters of `s`, fully reduced so it compares equal to the
// query() result for the same characters (previously the final `+ x` was left
// unreduced and could reach or exceed `mod`).
int gethash(const string &s) {
	long long ans = 0;
	for (char x : s)
		ans = (ans * 131 + x) % mod;
	return ans;
}

// idx[k] = 131^k mod `mod` for 0 <= k < N; filled by init(), which must run
// before query().
array <int, N> idx;

void init() {
	idx[0] = 1;
	for (int i = 1; i < N; i++)
		idx[i] = idx[i - 1] * 131ll % mod;
}

// Fills isl[i] with the reduced hash of s[1..i] for 1 <= i < |s|; s[0] is a
// sentinel and isl[0] is set to the hash of the empty string (0).
void getarray(array <int, N> &isl, string &s) {
	isl[0] = 0;
	for (int i = 1; i < (int)s.size(); i++)
		isl[i] = (isl[i - 1] * 131ll + s[i]) % mod;
}

// Hash of the 1-indexed range [l, r] described by prefix table `hs`.
// l == 0 returns 0 (the value callers get for a left end of 0); l must not be
// negative, since hs[l - 1] would then be out of bounds.
int query(array <int, N> &hs, int l, int r) {
	if (!l) return 0;
	// Widen before multiplying so the product cannot overflow even when
	// `int` is 32-bit.
	long long head = 1ll * hs[l - 1] * idx[r - l + 1] % mod;
	return (hs[r] - head + mod) % mod;
}

}

namespace Bit1 {

// 1-indexed Fenwick tree over positions 1..n (n chosen per call);
// edge[k] accumulates the values added to positions (k - lowbit(k), k].
array <int, N> edge;

// Value of the lowest set bit of x.
int lowbit(int x) {
	return x & -x;
}

// Adds y at position x; x must be non-zero (checked by assert).
void modify(int x, int y, int n) {
	assert(x);
	for (int pos = x; pos <= n; pos += lowbit(pos))
		edge[pos] += y;
}

// Sum of the values at positions 1..x.
int query(int x) {
	int total = 0;
	for (int pos = x; pos != 0; pos -= lowbit(pos))
		total += edge[pos];
	return total;
}

}

namespace Bit2 {

// 1-indexed Fenwick tree over positions 1..n (n chosen per call);
// edge[k] accumulates the values added to positions (k - lowbit(k), k].
array <int, N> edge;

// Value of the lowest set bit of x.
int lowbit(int x) {
	return x & -x;
}

// Adds y at position x.
// x must be non-zero: lowbit(0) == 0, so the update loop would never advance.
// The assert matches the guard Bit1::modify already has.
void modify(int x, int y, int n) {
	assert(x);
	while (x <= n) {
		edge[x] += y;
		x += lowbit(x);
	}
}

// Sum of the values at positions 1..x.
int query(int x) {
	int ans = 0;
	while (x) {
		ans += edge[x];
		x -= lowbit(x);
	}
	return ans;
}

}

array <int, N> f, g;
array <int, N> hsS, hsT;

/*
 * Reads s and t and prints the sum, over every non-empty interval [l, r] of s,
 * of the number of occurrences of t in s with s[l..r] deleted.
 *
 * Both strings are stored 1-indexed (leading ' ' sentinel); n = |s|, m = |t|.
 *   f[i]: longest prefix of t that matches s starting at position i,
 *   g[i]: longest suffix of t that matches s ending at position i,
 * both capped at m - 1 because an occurrence straddling the deleted interval
 * needs a non-empty piece on each side of it.
 */
signed main() {
	string s = " " + read_(), t = " " + read_();
	int n = s.size() - 1, m = t.size() - 1;
	if (m == 0) {
		// Empty pattern: nothing to count, and the Fenwick trees below would
		// otherwise be asked to update index 0.
		write(0), puts("");
		return 0;
	}
	Hash::init();
	Hash::getarray(hsS, s), Hash::getarray(hsT, t);
	for (int i = 1; i <= n; i++) {
		// f[i]: binary search the longest s[i..mid] equal to t[1..mid-i+1];
		// the upper bound keeps the t index within [1, m].
		int lo = i, hi = min<int>(n, i + m - 1), stop = i;
		while (lo <= hi) {
			int mid = (lo + hi) >> 1;
			if (Hash::query(hsS, i, mid) == Hash::query(hsT, 1, mid - i + 1))
				stop = mid + 1, lo = mid + 1;
			else
				hi = mid - 1;
		}
		f[i] = stop - i;
		// g[i]: binary search the longest s[mid..i] equal to t[m-i+mid..m].
		// The lower bound keeps both s and t indices >= 1; the previous bound
		// i - m - 1 read hsS/hsT at negative (out-of-bounds) indices.
		int start = i;
		lo = max<int>(1, i - m + 1), hi = i;
		while (lo <= hi) {
			int mid = (lo + hi) >> 1;
			if (Hash::query(hsS, mid, i) == Hash::query(hsT, m - i + mid, m))
				start = mid - 1, hi = mid - 1;
			else
				lo = mid + 1;
		}
		g[i] = i - start;
		f[i] = min(f[i], m - 1);
		g[i] = min(g[i], m - 1);
	}
	int ans = 0;
	// Case 2: t occurs entirely inside the kept left or right part. For an
	// occurrence at [i, i + m - 1] the deleted interval must lie completely in
	// [1, i - 1] or in [i + m, n].
	const int whole = Hash::query(hsT, 1, m);
	for (int i = 1; i + m - 1 <= n; i++)
		if (Hash::query(hsS, i, i + m - 1) == whole) {
			int left = i - 1, right = n - (i + m - 1);
			ans += left * (left + 1) / 2 + right * (right + 1) / 2;
		}
	// Case 1: t = s[i..i+a-1] + s[j-m+a+1..j] with 1 <= a <= m - 1, where the
	// deleted interval [i + a, j - m + a] is non-empty, i.e. j >= i + m.
	// There are f[i] + g[j] - m + 1 valid a when that is positive, so
	//   ans += sum_i sum_{j >= i + m} [f[i] + g[j] >= m] (f[i] + g[j] - m + 1).
	// Sweep i right to left, inserting g[i + m] into two Fenwick trees keyed by
	// g + 1 (Bit1: count, Bit2: sum of g), then query all g >= m - f[i].
	// `i >= 1` (not `i != 0`) so the loop is skipped when n < m.
	for (int i = n - m; i >= 1; i--) {
		int j = i + m;
		Bit1::modify(g[j] + 1, 1, m + 1);
		Bit2::modify(g[j] + 1, g[j], m + 1);
		int tp = m - f[i];  // g + 1 > tp  <=>  g >= m - f[i]
		int cnt = Bit1::query(m + 1) - Bit1::query(tp);
		int sumG = Bit2::query(m + 1) - Bit2::query(tp);
		ans += cnt * (f[i] - m + 1) + sumG;
	}
	write(ans), puts("");
	return 0;
}