CF1073G Yet Another LCP Problem

发布时间:2023-10-27 18:12:48 作者:Ender_32k

一道 *2600 调了一年,代码细节是有点粪了,但自己菜也是挺菜的。/oh/oh

考虑容斥。对可重集 \(A\),令 \(f(A)\) 为 \(A\) 中所有无序元素对(按位置区分,包括每个元素与自身配对)的 \(\operatorname{lcp}\) 之和,那么答案就是 \(f(A\cup B)-f(A)-f(B)\)(这里的并表示可重集合并):\(A\)、\(B\) 内部的对都被减掉,只剩下跨两个集合的对,且每对恰好计一次。

令 \(A=\{a_1,a_2,\cdots ,a_m\}\),并且 \(a_1\le a_2\le\cdots\le a_m\),那么 \(f(A)=\sum\limits_{1\le i\le j\le m}\operatorname{lcp}(a_i,a_j)\)。

由于我们先前进行了容斥,\(i=j\) 的项在 \(f(A\cup B)\) 与 \(f(A)+f(B)\) 中完全相同、会被抵消,所以可以忽略 \(i=j\) 的贡献,只需要考虑 \(\sum\limits_{1\le i<j\le m}\operatorname{lcp}(a_i,a_j)\) 即可。但是注意到 \(A\) 是可重集,\(i<j\) 时 \(a_i\) 仍然有可能等于 \(a_j\)。

建出后缀数组。为了方便,令 \(a_i\gets \operatorname{rk}_{a_i}\)(并重新排序),那么对于 \(a_i< a_j\),\(\operatorname{lcp}(\operatorname{sa}_{a_i},\operatorname{sa}_{a_j})=\min\limits_{k\in [{a_i}+1,a_j]}\operatorname{height}_k\),变成了一个子区间最小值之和的问题,可以分治解决;对于 \(a_i=a_j\),单独计算每种 \(a_i\) 的贡献,只需要求出 \(a_j=k\) 的 \(j\) 的个数 \(c_{k}\),那么 \(k\) 的贡献即为 \(\dbinom{c_k}{2}(n-\operatorname{sa}_{k}+1)\)。

然后就做完了,复杂度 \(O(n\log n)\),注意分治时不能统计到 \(a_i=a_j\) 的贡献。

// Problem: G. Yet Another LCP Problem
// Contest: Codeforces - Educational Codeforces Round 53 (Rated for Div. 2)
// URL: https://codeforces.com/problemset/problem/1073/G
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2")
// Short aliases used in competitive-programming style code.
#define pb emplace_back
#define mt make_tuple
#define mp make_pair
#define fi first
#define se second

using namespace std;
typedef long long ll;
typedef pair<int, int> pi;
typedef tuple<int, int, int> tu;
// Marker object: main() prints a rough static-memory estimate from the distance
// between &Mbe and &Med (declared just before main).
bool Mbe;

// Array capacity. NOTE(review): presumably twice the maximum string length, so that
// id[sa[i] + w] in SA() (index up to 2n) stays in bounds — confirm against constraints.
const int N = 4e5 + 400;
// Number of sparse-table levels: 2^(M-1) = 524288 >= N.
const int M = 20;

// sum / pr / sf: scratch arrays for conq(). a: working buffer (positions, then ranks);
// b: raw query positions. f: sparse table of range minima over ht.
// n: string length; m: key range used by rst(); q: number of queries;
// len: not used by the visible code (qry() declares its own local len).
ll sum[N];
int n, m, q, len, a[N], b[N], pr[N], sf[N], f[N][M];
// Suffix-array state: sa[k] = start of the k-th smallest suffix, rk = inverse of sa,
// ht[k] = LCP of suffixes sa[k-1] and sa[k]; id / ct are radix-sort buffers.
int id[N], ct[N], rk[N], sa[N], ht[N];
// Input string, read 1-indexed into s[1..n]; s[0] is never written and stays '\0'.
char s[N];

void rst() {
	memset(ct, 0, sizeof(int) * (m + 5));
	for (int i = 1; i <= n; i++) ct[rk[i]]++;
	for (int i = 1; i <= m; i++) ct[i] += ct[i - 1];
	for (int i = n; i; i--) sa[ct[rk[id[i]]]--] = id[i];
}

// Builds the suffix array of s[1..n] by prefix doubling with radix sort (rst()):
//   sa[k] = start position of the k-th smallest suffix (1-indexed),
//   rk[i] = rank of the suffix starting at i (inverse of sa),
//   ht[k] = length of the common prefix of suffixes sa[k-1] and sa[k] (Kasai).
// ht[1] is compared against sa[0] == 0, i.e. against the sentinel s[0].
void SA() {
	// Initial keys are single characters: 'a'..'z' map to 1..26, so 30 buckets suffice.
	// NOTE(review): assumes the input consists of lowercase letters only.
	m = 30;
	for (int i = 1; i <= n; i++) rk[i] = s[i] - 'a' + 1, id[i] = i;
	rst();
	// Each round sorts by (rank of first w chars, rank of next w chars). Stop once every
	// rank is distinct (p == n) or w exceeds n; m = p shrinks the key range for rst().
	for (int w = 1, p = 0; w <= n && p != n; w <<= 1, m = p) {
		p = 0;
		// id = positions ordered by the second key: suffixes with no second half
		// (i + w > n) come first, then the rest in the order of the previous sa.
		for (int i = n - w + 1; i <= n; i++) id[++p] = i;
		for (int i = 1; i <= n; i++) if (sa[i] > w) id[++p] = sa[i] - w;
		// Stable sort by the first key, then swap the whole arrays so id holds the
		// previous ranks while rk is rebuilt.
		rst(), swap(rk, id), p = rk[sa[1]] = 1;
		// Adjacent suffixes share a new rank iff both halves of the key match.
		// id[sa[i] + w] may index past n (up to 2n); those entries are expected to be 0.
		for (int i = 2; i <= n; i++) rk[sa[i]] = (id[sa[i]] == id[sa[i - 1]] && id[sa[i] + w] == id[sa[i - 1] + w]) ? p : ++p;
	}
	// Kasai: moving from suffix i to i + 1, the LCP with the rank-predecessor drops by
	// at most 1, so j is decremented once per step and extended from there.
	for (int i = 1, j = 0; i <= n; i++) {
		if (j) j--;
		while (s[i + j] == s[sa[rk[i] - 1] + j]) j++;
		ht[rk[i]] = j;
	}
}

// Range minimum of ht[l..r] through the sparse table f (f[i][k] covers
// ht[i .. i + 2^k - 1]). An empty range (l > r) yields 0.
int qry(int l, int r) {
	if (r < l) return 0;
	// floor(log2(length)) for a positive 32-bit length.
	const int k = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
	const int fromLeft = f[l][k];
	const int fromRight = f[r - (1 << k) + 1][k];
	return fromLeft < fromRight ? fromLeft : fromRight;
}

// Divide and conquer over the sorted rank buffer a[l..r].
// Returns the sum, over index pairs l <= i < j <= r with a[i] != a[j], of
// min(ht[a[i]+1 .. a[j]]), i.e. the LCP of the suffixes with ranks a[i] and a[j].
// Pairs with a[i] == a[j] (same suffix) are NOT counted here; calc() adds them.
// Uses the global scratch arrays sf, pr and sum.
ll conq(int l, int r) {
	if (l == r) return 0;
	int mid = (l + r) >> 1;
	// sum[] is a prefix sum over pr[mid+1..]; sum[mid] is its zero base (reset after
	// the recursive calls, which reuse the array).
	ll res = conq(l, mid) + conq(mid + 1, r); sum[mid] = 0;
	// sf[i] = min(ht[a[i]+1 .. a[mid+1]]): LCP of a[i] with the first right-half value (0 if equal).
	for (int i = mid; i >= l; i--) sf[i] = qry(a[i] + 1, a[mid + 1]);
	// pr[j] = min(ht[a[mid]+1 .. a[j]]): LCP of a[mid] with a[j] (0 if equal).
	for (int i = mid + 1; i <= r; i++) pr[i] = qry(a[mid] + 1, a[i]), sum[i] = sum[i - 1] + pr[i];
	// p: first right-half index whose value differs from a[mid].
	// q: last left-half index whose value differs from a[mid+1] (may end at l-1).
	int p = mid + 1, q = mid;
	while (p <= r && a[p] == a[mid]) p++;
	while (q >= l && a[q] == a[mid + 1]) q--;
	// A run of equal values straddles mid: every left element in (q, mid] equals a[mid]
	// and pairs with each right element j >= p with LCP pr[j]. Right elements in
	// [mid+1, p) equal a[mid] and are skipped.
	if (a[mid] == a[mid + 1]) res += 1ll * (sum[r] - sum[p - 1]) * (mid - q);
	// Recompute pr relative to a[q]: for i <= q we have a[i] <= a[q] < a[mid+1] <= a[j],
	// so min(sf[i], pr[j]) covers exactly ht[a[i]+1 .. a[j]].
	// NOTE: if q == l-1 this reads a[l-1] (outside [l, r]); the loop below then does not
	// run, so those pr values are unused.
	for (int i = mid + 1; i <= r; i++) pr[i] = qry(a[q] + 1, a[i]), sum[i] = sum[i - 1] + pr[i];
	// Two pointers: sf[i] is non-increasing as i decreases and pr[j] is non-increasing
	// as j increases, so the right elements with pr[j] >= sf[i] form a growing prefix.
	// Those contribute sf[i] each; the remaining ones contribute their own pr[j].
	for (int i = q, j = mid + 1; i >= l; i--) {
		while (j <= r && sf[i] <= pr[j]) j++;
		res += 1ll * (j - mid - 1) * sf[i] + sum[r] - sum[j - 1];
	}
	return res;
}

// Computes f for the multiset of suffix starts b[l..r]: the sum of lcp over all pairs of
// distinct indices in [l, r]. Pairs naming the same start position p are counted here
// directly (that suffix has lcp n - p + 1 with itself). All other pairs are delegated to
// conq() after mapping positions to ranks.
// Side effect: a[l..r] is left holding the sorted ranks.
ll calc(int l, int r) {
	std::copy(b + l, b + r + 1, a + l);
	std::sort(a + l, a + r + 1);
	ll dup = 0;
	// Walk runs of equal positions; a run of length cnt gives C(cnt, 2) identical pairs.
	for (int lo = l, hi; lo <= r; lo = hi) {
		for (hi = lo + 1; hi <= r && a[hi] == a[lo]; ++hi) {}
		const ll cnt = hi - lo;
		dup += cnt * (cnt - 1) / 2 * (n - a[lo] + 1);
	}
	std::transform(a + l, a + r + 1, a + l, [](int pos) { return rk[pos]; });
	std::sort(a + l, a + r + 1);
	return dup + conq(l, r);
}

// Reads n, q and the string; builds the suffix array and a sparse table of range minima
// over ht. Each query gives a set A (l1 suffix starts) and a set B (l2 suffix starts).
// It is answered as f(A + B) - f(A) - f(B), which leaves exactly the cross pairs.
void solve() {
	cin >> n >> q;
	// Read through std::string: the `istream >> char*` overload (formerly used as
	// `cin >> (s + 1)`) was removed in C++20, so that form no longer compiles there.
	std::string str;
	cin >> str;
	std::copy(str.begin(), str.end(), s + 1);
	s[str.size() + 1] = '\0';
	SA();
	// f[i][j] = min(ht[i .. i + 2^j - 1]).
	for (int i = 1; i <= n; i++) f[i][0] = ht[i];
	for (int j = 1; (1 << j) <= n; j++)
		for (int i = 1; i + (1 << j) - 1 <= n; i++)
			f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
	while (q--) {
		int l1, l2; cin >> l1 >> l2;
		// b[1..l1] holds set A and b[l1+1 .. l1+l2] holds set B.
		for (int i = 1; i <= l1; i++) cin >> b[i];
		for (int i = 1; i <= l2; i++) cin >> b[l1 + i];
		// Each calc() reloads a[l..r] from b, so the three calls do not interfere.
		ll s1 = calc(1, l1 + l2), s2 = calc(1, l1), s3 = calc(l1 + 1, l1 + l2);
		cout << s1 - s2 - s3 << '\n';
	}
}

// Marker object paired with Mbe for the rough static-memory estimate printed below.
bool Med;
// Program entry: fast I/O setup, optional local file redirection, then solve().
// Diagnostics (rough static memory, elapsed time) go to stderr only.
int main() {
	ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
	// Rough size of the globals declared between the markers Mbe and Med. Addresses are
	// compared as integers because subtracting pointers to two distinct objects is
	// undefined behaviour. The layout is up to the compiler/linker, so this is only an
	// estimate.
	const std::uintptr_t addrBegin = reinterpret_cast<std::uintptr_t>(&Mbe);
	const std::uintptr_t addrEnd = reinterpret_cast<std::uintptr_t>(&Med);
	const std::uintptr_t gap = addrBegin > addrEnd ? addrBegin - addrEnd : addrEnd - addrBegin;
	cerr << gap / 1048576.0 << " MB\n";
	// Local testing only: compile with -DLOCAL_FILE_IO to redirect stdio to files.
	// (The former guard `FILE` could never be enabled: FILE is the <cstdio> stream
	// type, so defining it as a macro breaks the standard headers.)
	#ifdef LOCAL_FILE_IO
		freopen("1.in", "r", stdin);
		freopen("1.out", "w", stdout);
	#endif
	int T = 1;
	// cin >> T;
	while (T--) solve();
	cerr << static_cast<int>(1e3 * clock() / CLOCKS_PER_SEC) << " ms\n";
	return 0;
}