【题解】CF1621G Weighted Increasing Subsequences

发布时间 2023-11-29 00:29:22作者: kymru

常规,但不常规。

思路来自 @gyh.

思路

BIT 优化计数。

本来考虑的是对 LIS 进行计数,得到一个对 \([]\) 形式的值套三层求和的方式,然后再瞪眼找优化方法,但是没有发现什么好的处理方法,于是只能考虑转换计数方法。

考虑通过每个位置对答案的贡献计数。假设某个位置 \(x\) 被一个合法的子序列 \(i_1, \cdots i_k\) 包含,考虑此时需要满足的限制。

其实很简单,考虑最后一个满足题目限制的位置,令 \(y\) 为最靠后的满足 \(a_y > a_x\) 的位置。只需限制 \(i_k < y\),就可以通过位置 \(y\) 满足题目的限制。同时容易发现 \(y\) 以及之后的位置都不可能满足限制,所以这个条件是充要的。

对于每个位置,考虑在统计以其为开头的上升子序列时计算它对答案的贡献。换言之,对于 \(1 \leq x \leq n\),统计 \([x, y - 1]\) 中包含 \(x\) 的上升子序列的个数。

通过树状数组优化朴素的 dp 做法,可以达到 \(O(n^2 \log n)\) 的复杂度。

同时可以观察得到合法的 \(y\) 一定是原序列的后缀最大值,可以进一步减小常数。

优化发现没法再转化统计方式,于是考虑通过容斥一类的方式做一些手脚。先假定所有包含 \(x\) 的上升子序列均符合限制,再容斥减去不符合限制的子序列。

分类讨论。当子序列以 \([x, y - 1]\) 中的位置结尾时,序列一定符合条件;当子序列以 \(y\) 结尾时,不符合题目限制;当子序列以 \([y + 1, n]\) 中的位置结尾时,因为 \(y\) 的定义,一定不可能从中找出比 \(a_y\) 大的值,所以这种情况矛盾,不可能统计。

这意味着我们只需要先预处理出以 \(x\) 为开头的上升子序列数量,再计算以 \(x\) 开头且 \(y\) 为结尾的上升子序列数量。

同时题目有进一步的性质:令 \(z\) 为满足 \(z > y\)\(a_z\) 最大的位置。对于某个 \(y\),其对应的所有 \(x\) 都应当满足 \(a_z \leq a_x < a_y\),也就是对于每个 \(y\),其对应的 \(a_x\) 一定在某个区间内,统计的时候只需要考虑这些位置,直接通过定义二分预处理出来。

又因为 \(x\)\(y\) 可以看成是一种单射关系,也就是每个 \(x\) 所在的被 \(y\) 确定的区间是唯一的,所以均摊的复杂度是 \(O(n \log n)\).

两次转化算是比较常规,就是瞪不出来题目的性质,跟做初中几何一个样。

代码

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

const int maxn = 2e5 + 5;
const int mod = 1e9 + 7;

int n;
int top, stk[maxn];
int a[maxn], p[maxn];
int pre[maxn], suf[maxn], f[maxn];
vector<int> seq[maxn];

namespace BIT
{
	int c[maxn];

	void init() { for (int i = 1; i <= n; i++) c[i] = 0; }

	int lowbit(int x) { return x & (-x); }

	void update(int p, int w) { for (int i = p; i <= n; i += lowbit(i)) c[i] = (c[i] + w) % mod; }

	int query(int p)
	{
		int res = 0;
		for (int i = p; i; i -= lowbit(i)) res = (res + c[i]) % mod;
		return res;
	}
}
using namespace BIT;

bool cmp(int x, int y) { return (a[x] == a[y] ? x > y : a[x] < a[y]); }

void solve()
{
	scanf("%d", &n);
	top = 0;
	for (int i = 1; i <= n; i++) seq[i].clear();
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]), p[i] = i;
	sort(p + 1, p + n + 1, cmp);
	for (int i = 1; i <= n; i++) a[p[i]] = i;
	init(); for (int i = 1; i <= n; i++) update(a[i], pre[i] = query(a[i] - 1) + 1);
	init(); for (int i = n; i >= 1; i--) update(n - a[i] + 1, suf[i] = query(n - a[i]) + 1);
	for (int i = n; i >= 1; i--) if (a[i] > a[stk[top]]) stk[++top] = i;
	for (int i = n; i >= 1; i--)
	{
		int l = 1, r = top;
		while (l < r)
		{
			int mid = (l + r) >> 1;
			if (a[i] <= a[stk[mid]]) r = mid;
			else l = mid + 1;
		}
		if (i != stk[l]) seq[stk[l]].push_back(i);
	}
	init();
	for (int i = 1; i <= top; i++)
	{
		update(n - a[stk[i]] + 1, f[stk[i]] = 1);
		for (int p : seq[stk[i]]) update(n - a[p] + 1, f[p] = query(n - a[p]));
		for (int p : seq[stk[i]]) update(n - a[p] + 1, -f[p]);
		update(n - a[stk[i]] + 1, -1);
	}
	int ans = 0;
	for (int i = 1; i <= n; i++) ans = (ans + 1ll * (suf[i] - f[i] + mod) % mod * pre[i] % mod) % mod;
	printf("%d\n", ans);
}

int main()
{
	int t;
	scanf("%d", &t);
	while (t--) solve();
	return 0;
}