CF213E Two Permutation 题解

发布时间 2023-12-13 19:50:38作者: 小超手123

CF213E Two Permutations 题解

题意:

给出两个排列$a,b $,长度分别为 \(n,m\),你需要计算有多少个 $ x $,使得 \(a_1 + x,a_2 + x,...a_n + x\)\(b\) 的子序列。
\(n \leq m \leq 2 \times 10^5\)

分析:

一个很自然的思路是直接枚举 \(x\),然后只保留 \(b\) 中值域在 \([x+1,x+n]\) 的数,然后利用哈希判断 \(a\)\(b\) 是否相同。

显然是可行的,看看如何维护 \(a\)\(b\) 的哈希值。

维护 \(a\) 的哈希值是简单的,记 \(S = \sum_{i=1}^{n}a_i \times base^{n-i+1}\),显然

\[\sum_{i=1}^{n}(a_i+x) \times base^{n-i+1}=S+x\sum_{i=1}^{n}base^{n-i+1} \]

\(b\) 的哈希也不难,考虑保留值域从 \([l,r]\)\([l+1,r+1]\) 的过程,明显删掉了 \(l\),加上了 \(r\)

我们可以利用线段树动态维护每个点的 \(b_i \times base^{n-i+1}\),删掉一个数后前面的数都要除以 \(base\),加入一个数后前面的数都要乘上 \(base\)

在线段树里面维护两个值 \(c_1,c_2\) 分别记录这个区间的 \(\sum b_i \times base^{n-i+1}\)里面有值的数的数量

这个线段树需要支持区间乘、单点赋值、区间查询。

时间复杂度为 \(O(n \log n)\)

代码:

#include<bits/stdc++.h>
#define int long long
#define base 1000000007
#define mod 998244353
#define N 200005
using namespace std;
int read() {
	char ch = getchar(); int x = 0, f = 1;
	while(ch < '0' || ch > '9') { if(ch == '-') f = -1; ch = getchar(); }
	while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return x * f;
}
void write(int x) {
	if(x < 0) putchar('-'), x = -x;
	if(x > 9) write(x / 10);
	putchar('0' + x % 10);
}
int Pow(int a, int n) {
	if(n == 0) return 1;
	if(n == 1) return a % mod;
	int x = Pow(a, n / 2);
	if(n % 2 == 0) return x * x % mod;
	else return x * x % mod * a % mod;
}
int inv(int x) {
	return Pow(x, mod - 2);
}
int n, m, hx, ans;
int a[N], b[N], h[N], f[N], S;
int c1[N * 4], c2[N * 4], tag[N * 4]; //c1记录hash值,c2记录有值的个数 
void pushup(int u) {
	c1[u] = (c1[u * 2] + c1[u * 2 + 1]) % mod;
	c2[u] = c2[u * 2] + c2[u * 2 + 1];
}
void maketag(int u, int x) {
	if(tag[u] != -1) tag[u] = tag[u] * x % mod; 
	else tag[u] = x % mod;
	if(c1[u] != -1) c1[u] = c1[u] * x % mod;
	else c1[u] = x % mod;
}
void pushdown(int u) {
	if(tag[u] == -1) return;
	maketag(u * 2, tag[u]);
	maketag(u * 2 + 1, tag[u]);
	tag[u] = -1;
}
void update1(int u, int L, int R, int x, int y) { //c1[x]=y 
	if(L == R) {
		if(c1[u] != 0 && y == 0) c2[u] = 0;
		else if(c1[u] == 0 && y != 0) c2[u] = 1;
		c1[u] = y;
		return;
	}
	int mid = (L + R) / 2;
	pushdown(u);
	if(x <= mid) update1(u * 2, L, mid, x, y);
	else update1(u * 2 + 1, mid + 1, R, x, y);
	pushup(u);
}
void update2(int u, int L, int R, int l, int r, int x) { //c1[l...r]乘上x 
	if(r < L || R < l) return;
	if(l <= L && R <= r) {
		maketag(u, x);
		return; 
	}
	pushdown(u);
	int mid = (L + R) / 2;
	update2(u * 2, L, mid, l, r, x);
	update2(u * 2 + 1, mid + 1, R, l, r, x);
	pushup(u);
}
int query(int u, int L, int R, int l, int r) { //查询[l,r]的c2 
	if(r < L || R < l) return 0;
	if(l <= L && R <= r) return c2[u];
	pushdown(u);
	int mid = (L + R) / 2;
	return query(u * 2, L, mid, l, r) + query(u * 2 + 1, mid + 1, R, l, r);
}
void Insert(int x, int y) { //在x处插入y 
    int z = query(1, 1, m, x + 1, m);
	update1(1, 1, m, x, y * h[z] % mod);
	update2(1, 1, m, 1, x - 1, base);
}
void Delete(int x) {
	update1(1, 1, m, x, 0);
	update2(1, 1, m, 1, x - 1, inv(base) % mod);
}
signed main() {
	memset(tag, -1, sizeof(tag));
    n = read(), m = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    for(int i = 1; i <= m; i++) b[i] = read(), f[b[i]] = i;
    h[0] = 1; 
	for(int i = 1; i <= 200000; i++) h[i] = h[i - 1] * base % mod;
	for(int i = 1; i <= n; i++) {
		S += h[n - i]; S %= mod;
		hx += a[i] * h[n - i] % mod; hx %= mod;
	}
	for(int i = 1; i <= n; i++) Insert(f[i], b[f[i]]);
	for(int x = 0; x <= m - n; x++) { //保留b值域为 [1+x, x+n] 的数 
		if(x) {
			hx += S % mod; hx %= mod;
			Delete(f[x]);
			Insert(f[x + n], b[f[x + n]]);
		}
		if(c1[1] == hx) ans++;
	}
	write(ans);
	return 0;
}