Luogu P6292 区间本质不同子串个数

发布时间 2023-12-02 01:50:18作者: zzxLLL

给定字符串 \(S\)\(m\) 次询问 \(S_{l_i}S_{l_i + 1} \cdots S_{r_i}\) 中本质不同的字符串个数。

\(|S| \le 10 ^ 5, m \le 2 \times 10 ^ 5\)

考虑将询问离线,右端点扫描线,维护 \(f_l\)\(S_l S_{l + 1} \cdots S_r\) 有多少个 \(x \in [l, r]\) 满足 \(S[l, x]\)\(S[l, r]\) 中仅出现一次。询问 \(l_i, r_i\) 时答案是 \(\sum\limits_{x = l_i} ^ {r_i} f_x\)

考虑暴力怎么维护 \(f\) 数组。每次向后添加一个字符 \(S_{r}\),然后枚举 \(1 \le len \le r\),求出 \(S[r - len + 1, r]\) 上一次出现位置 \(p\),然后 \(f_{p - len + 1} \gets f_{p - len + 1} - 1, f_{r - len + 1} \gets f_{r - len + 1} + 1\),因为有 \(S[p - len + 1, p] = S[r - len + 1, r]\)

建出 \(S\) 的 SAM,设 \(S_r\) 对应的节点是 \(u\),发现上述枚举 \(len\) 的过程可以看作 parent tree 上从 \(u\) 向上跳到根的过程。维护每个节点 \(x\) 对应的 \(p_x\),也就是当前它的 \(\operatorname{endpos}_x\) 集合内的最大值。对于 \(u\) 的一个祖先 \(v\),对于 \(i \in [p_v - len_v + 1, p_v - len_{fa_v}]\),都 \(f_i \gets f_i - 1\),其中 \(len\) 和 SAM 中的 \(len\) 定义相同。然后将 \(u\) 到根路径上的 \(p\) 全部置为 \(r\) 即可。

可以用 LCT 将 \(p\) 相同的一段节点串成一条实链,一条实链上的 \([p_v - len_v + 1, p_v - len_{fa_v}]\) 的并是一段连续的区间,直接用线段树维护区间减即可。将路径上的 \(p\) 置为 \(r\),直接 LCT 上 access 并打上懒标记。最后还要 \([1, r]\) 区间加 \(1\),因为 \(S[x, r]\)\(S[x, r]\) 仅出现了一次。

Code
#include <bits/stdc++.h>
const int M = 4e5 + 10;

 int read() {
	char ch = getchar();
	int x = 0;
	while (ch < '0' || ch > '9') ch = getchar();
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return x;
}

int n, m;
char str[M];
struct Q { int l, r, id; } q[M];
long long ans[M];

namespace Tree {
	// f[i]: 以 i 开头的,[i, r] 中仅出现一次的字符串个数
	struct Node { int l, r; long long sum, lzy; } tr[M << 2];
	void pushup(int x) { tr[x].sum = tr[x << 1].sum + tr[x << 1 | 1].sum; }
	void tag(int x, long long v) { tr[x].lzy += v, tr[x].sum += 1ll * (tr[x].r - tr[x].l + 1) * v; }
	void pushdown(int x) {
		if (tr[x].lzy != 0)
			tag(x << 1, tr[x].lzy), tag(x << 1 | 1, tr[x].lzy), tr[x].lzy = 0;
	}
	void build(int x, int l, int r) {
		tr[x].l = l, tr[x].r = r;
		if (l == r) return;
		int mid = (l + r) >> 1;
		build(x << 1, l, mid), build(x << 1 | 1, mid + 1, r);
	}
	void add(int x, int l, int r, int v) {
		if (l > r) return;
		if (l <= tr[x].l && tr[x].r <= r) return tag(x, v);
		pushdown(x);
		int mid = (tr[x].l + tr[x].r) >> 1;
		if (l <= mid) add(x << 1, l, r, v);
		if (r > mid)  add(x << 1 | 1, l, r, v);
		pushup(x);
	}
	long long qry(int x, int l, int r) {
		if (l <= tr[x].l && tr[x].r <= r) return tr[x].sum;
		pushdown(x);
		int mid = (tr[x].l + tr[x].r) >> 1;
		long long ret = 0;
		if (l <= mid) ret += qry(x << 1, l, r);
		if (r > mid)  ret += qry(x << 1 | 1, l, r);
		return ret;
	}
}

struct SAM_Node { int ch[26], fa, len; } v[M];
int last = 1, tot = 1, ep[M];
void Insert(int x) {
	int cur = ++tot, p = last;
	v[cur].len = v[p].len + 1;
	while (p && !v[p].ch[x]) v[p].ch[x] = cur, p = v[p].fa;
	if (!p) v[cur].fa = 1;
	else {
		int q = v[p].ch[x];
		if (v[q].len == v[p].len + 1) v[cur].fa = q;
		else {
			int nq = ++tot;
			v[nq] = v[q], v[nq].len = v[p].len + 1;
			while (p && v[p].ch[x] == q) v[p].ch[x] = nq, p = v[p].fa;
			v[q].fa = v[cur].fa = nq;
		}
	}
	last = cur;
}

struct LCT_Node { int ch[2], fa, lzy, lst; } tr[M];
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
bool isRoot(int x) { return lc(tr[x].fa) != x && rc(tr[x].fa) != x; }
bool get(int x) { return rc(tr[x].fa) == x; }
void tag(int x, int v) { tr[x].lzy = tr[x].lst = v; }
void pushdown(int x) {
	if (tr[x].lzy != 0) {
		if (lc(x)) tag(lc(x), tr[x].lzy);
		if (rc(x)) tag(rc(x), tr[x].lzy);
		tr[x].lzy = 0;
	}
}
void push(int x) {
	if (!isRoot(x)) push(tr[x].fa);
	pushdown(x);
}
void rotate(int x) {
	int y = tr[x].fa, z = tr[y].fa;
	int k = get(x), w = tr[x].ch[k ^ 1];
	if (!isRoot(y)) tr[z].ch[get(y)] = x;
	tr[x].fa = z, tr[x].ch[k ^ 1] = y, tr[y].fa = x, tr[y].ch[k] = w;
	if (w) tr[w].fa = y;
}
void splay(int x) {
	for (push(x); !isRoot(x); rotate(x))
		if (!isRoot(tr[x].fa)) rotate(get(x) ^ get(tr[x].fa) ? x : tr[x].fa);
}
void access(int x, int i) {
	int y;
	for (y = 0; x; y = x, x = tr[x].fa) {
		splay(x), rc(x) = y;
		if (tr[x].lst) {
			int l = tr[x].lst - v[x].len + 1, r = tr[x].lst - v[tr[x].fa].len;
			Tree::add(1, l, r, -1);
		}
	}
	tag(y, i), Tree::add(1, 1, i, 1);
}

int main() {
	scanf(" %s", str + 1), n = strlen(str + 1);
	
	for (int i = 1; i <= n; i++) Insert(str[i] - 'a'), ep[i] = last;
	for (int i = 1; i <= tot; i++) tr[i].fa = v[i].fa;
	Tree::build(1, 1, n);
	
	m = read();
	for (int i = 1; i <= m; i++) q[i] = { read(), read(), i };
	std::sort(q + 1, q + 1 + m, [&](Q A, Q B) { return A.r < B.r; } );
	for (int i = 1; i <= m; i++) {
		for (int j = q[i - 1].r + 1; j <= q[i].r; j++) access(ep[j], j);
		ans[q[i].id] = Tree::qry(1, q[i].l, q[i].r);
	}
	for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
	return 0;
}