【题解】CF193D Two Segments

发布时间 2023-05-20 20:46:32作者: jiangchenyangsong

题意

给定一个\(1\sim N\)的排列,在这个排列中选出两段互不重叠的区间,求使选出的元素排序后构成公差为1的等差数列的方案数。选出的两段区间中元素构成的集合相同时视为同一种方案。\(1\le N\le 3\times 10^5\)

传送门

分析

如果考虑怎么优化枚举的两个区间的话,发现不太好搞(反正我只会暴力)。

于是考虑枚举连续的值域区间,再判断一下连续的值域区间是由原排列中几段连续的区间构成,如果 \(\le 2\),就是可行的方案。

对于这种区间问题,一般套路是确定一个点,然后对其他点算贡献。

\(f[l][r]\) 表示值域 \([l,r]\) 是由几段构成的,\(pos[i]\) 表示 \(i\) 这个值在原序列的位置,我们从 \(1\)\(n\) 依次枚举右端点 \(i\),考虑从 \(i - 1\) 转移到 \(i\),那如何在 \(O(i)\) 的时间内转移呢?可以找到如下规律:

  • 如果原序列中 \(i\) 在的位置左右两个数都 \(\le i\) ,那么肯定在之前加入了,而现在加入 \(i\) 会使 \([l, i], l \in [1, \min(a[pos[i]-1], a[pos[i] + 1])]\) 值域的段数 \(-1\)\([l, i], l \in (\min(a[pos[i]-1], a[pos[i] + 1]), \max(a[pos[i]-1], a[pos[i] + 1])]\)值域的段数不变,\([l, i], l \in (\max(a[pos[i]-1], a[pos[i] + 1]), i - 1]\) 值域的段数 \(+1\)
  • 如果只有一个,设那个数的位置为 \(x\),那么对于 \([l, i] ,l\in [1, x]\)值域的段数不变,\([l, i], l \in (x, i - 1]\) 的段数 \(+1\)
  • 如果没有,那么 \([l, i] ,l\in i - 1\)的值域的段数会 \(+1\)

这样是 \(O(n ^ 2)\),区间加,区间减,很容易想到用线段树优化。

设枚举\(i\),线段树的区间 \([l,r]\),表示 \([x, i], x\in [l,r]\) 的各种信息。

我们需要线段树维护区间的最小段数的值,是这个最小段数的值的区间个数,和是次小段数这个值的区间个数。

最后询问$1∼i−1 $需要分的段数是否小于等于 \(2\) 即可

如何维护详见代码注释。

#include<bits/stdc++.h>
#define N 300005
#define int long long
#define ls u << 1
#define rs u << 1 | 1 
using namespace std;
int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, ans; 
int a[N], pos[N];
int minn[N << 2], cnt0[N << 2], cnt1[N << 2], lazy[N << 2];
struct Segment{
	void pushup(int u){
		minn[u] = min(minn[ls], minn[rs]);
		cnt0[u] = (minn[ls] == minn[u]) * cnt0[ls] + (minn[rs] == minn[u]) * cnt0[rs];
		cnt1[u] = (minn[ls] == minn[u]) * cnt1[ls] + (minn[ls] == minn[u] + 1) * cnt0[ls]; 
		cnt1[u] += (minn[rs] == minn[u]) * cnt1[rs] + (minn[rs] == minn[u] + 1) * cnt0[rs];
		//如果左/右区间的最小值等于整个区间的最小值,那么左/右区间次小值就是整个区间的次小值,统计个数
		//如果左/右区间的最小值等于整个区间的最小值 + 1,那么最小值就是整个区间的次小值,因为每次枚举 $i$ 时,
		//值的变化最多加减1,所以次小值就是最小值 + 1。 
	}
	void pushdown(int u){
		minn[ls] += lazy[u], lazy[ls] += lazy[u];
		minn[rs] += lazy[u], lazy[rs] += lazy[u];
		lazy[u] = 0;
	}
 	void build(int u, int l, int r){
		if(l == r) return cnt0[u] = 1, void(); //初始都为 1 
		int mid = (l + r) >> 1;
		build(ls, l, mid), build(rs, mid + 1, r);
		pushup(u);
	}
	void update(int u, int l, int r, int L, int R, int val){
		if(L <= l && r <= R) return minn[u] += val, lazy[u] += val, void();
		pushdown(u);
		int mid = (l + r) >> 1;
		if(L <= mid) update(ls, l, mid, L, R, val);
		if(R > mid) update(rs, mid + 1, r, L, R, val);
		pushup(u);
	}
	int query(int u, int l, int r, int L, int R){
		if(L <= l && r <= R) return cnt0[u] * (minn[u] <= 2) + cnt1[u] * (minn[u] <= 1);
		//如果最小值小于等于 2 ,说明最小值是符合的,统计进去。
		//如果最小值小于等于 1 , 次小值 = 最小值 + 1 , 次小值也符合。 
		pushdown(u);
		int mid = (l + r) >> 1, res = 0;
		if(L <= mid) res += query(ls, l, mid, L, R);
		if(R > mid) res += query(rs, mid + 1, r, L, R);
		return res; 
	}
}tr;
signed main(){
	n = read();
	for(int i = 1; i <= n; ++i) a[i] = read(), pos[a[i]] = i;
	tr.build(1, 1, n);
	for(int i = 1; i <= n; ++i){
		tr.update(1, 1, n, 1, i, 1);
		if(a[pos[i] - 1] < i && a[pos[i] - 1]) tr.update(1, 1, n, 1, a[pos[i] - 1], -1);
		if(a[pos[i] + 1] < i && a[pos[i] + 1]) tr.update(1, 1, n, 1, a[pos[i] + 1], -1); 
		if(i >= 2) ans += tr.query(1, 1, n, 1, i - 1);//[i,i]这段是不能算进去的 
	}
	printf("%lld\n", ans);
	return 0;
}