【CF1621G】Weighted Increasing Subsequences 题解(优化树状数组)

发布时间 2023-07-03 11:30:40作者: 伍叁壹

CF 传送门 | LG 传送门

优化树状数组 + 反向处理。

Solution

  • 发现直接做不好下手。难点主要在求出所有的上升子序列并计算它们分别的贡献。
    所以需要反向考虑每个单点在什么情况下产生贡献。一个单点会产生多少贡献。
  • 一个单点产生贡献的条件很容易得到。一个是在一个上升子序列中;一个是它小于该序列后面的最大值。
    发现第二个条件是可以转化为一个“一般化”的限制的:对于每个点 \(i\) 都能找到 \(r_i\) 表示最后一个大于它的项。
    所以能让 \(i\) 产生贡献的一定是结尾在 \(r_i\) 前、包含 \(i\) 的上升子序列。
  • 然后考虑能使 \(i\) 产生贡献的上升子序列个数。
    根据乘法分配律,这个个数为:以 \(i\) 结尾的上升子序列个数 \(\times\)\(i\) 开头结尾在 \(r_i\) 前的上升序列个数。
  • 前者可以直接使用树状数组 \(O(n\log n)\) 得到。
    后者可以转化一下。使用同样的方法求出以 \(i\) 为开头的上升子序列个数 \(-\)\(i\) 开头且 \(r_i\) 结尾的上升序列个数。
    显然 \(r_i\) 之后的数对上升序列不会产生任何贡献影响。
  • 如果直接暴力求“以 \(i\) 开头且 \(r_i\) 结尾的上升序列个数”,那么复杂度将是 \(O(n^2\log n)\) 的。
    考虑优化。考虑省去一些无用的转移——即转移项对 \(i\) 的上升序列不会产生贡献。
  • 考虑什么样的项会对“以 \(i\) 开头 \(r_i\) 结尾的上升序列”产生贡献。
    不妨设这样的项下标为 \(j\),那么:\(a_i<a_j<a_{r_i}\)。所以显然 \(r_j=r_i\)
    \(r_i=x\)
    所以对于一个 \(x\),所有能对以它结尾的上升序列产生影响的项 \(i\)\(r_i=x\)
    所以只需要在所有 \(r_i=x\) 的项 \(i\) 考虑上升序列的方案数,一起考虑它们“以自身开头以 \(x\) 结尾的上升序列”的个数即可。那么对这些项遍历用树状数组求一遍即可。
  • 因为每个 \(i\) 只会有一个对应的 \(r_i\),所以单次询问复杂度是 \(O(n\log n)\) 的。

Code

#include<bits/stdc++.h>
using namespace std;

#define int long long 
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define per(i, a, b) for(int i = a; i >= b; --i) 
#define init rep(i, 1, n) t[i] = 0
const int maxn = 2e5 + 5, mod = 1e9 + 7;
int n, a[maxn], id[maxn];
int t[maxn], pre[maxn], suf[maxn], f[maxn];
int s[maxn], tp;
vector<int> d[maxn];

inline int lb(int x){ return x & (-x);}
inline void add(int x, int k){
	for(int i = x; i <= n; i += lb(i)) (t[i] += k) %= mod;
}
inline int qry(int x){ int res = 0;
	for(int i = x; i; i -= lb(i)) (res += t[i]) %= mod;
	return res;
}

inline bool cmp(int x, int y){
	return a[x] == a[y] ? x > y : a[x] < a[y];
}
inline void slv(){
	scanf("%lld", &n), tp = 0;
	rep(i, 1, n) d[i].clear();
	rep(i, 1, n) scanf("%lld", &a[i]), id[i] = i;
	sort(id + 1, id + n + 1, cmp); rep(i, 1, n) a[id[i]] = i;//离散化 
	init; rep(i, 1, n) add(a[i], pre[i] = qry(a[i] - 1) + 1);
	init; per(i, n, 1) add(n - a[i] + 1, suf[i] = qry(n - a[i]) + 1);//nlogn 求上升序列个数 
	per(i, n, 1) if(a[i] > a[s[tp]]) s[++tp] = i;
	per(i, n, 1){ int l = 1, r = tp, pos;
		while(l <= r){ 
			int mid = (l + r) >> 1;
			if(a[i] <= a[s[mid]]) r = mid - 1, pos = mid;
			else l = mid + 1;
		}
		if(i != s[pos]) d[s[pos]].push_back(i);
	} init;
	rep(i, 1, tp){
		add(n - a[s[i]] + 1, f[s[i]] = 1);
		for(int v : d[s[i]]) add(n - a[v] + 1, f[v] = qry(n - a[v]));
		for(int v : d[s[i]]) add(n - a[v] + 1, -f[v]);
		add(n - a[s[i]] + 1, -1);
	}
	int ans = 0;
	rep(i, 1, n) (ans += 1ll * (suf[i] - f[i] + mod) % mod * pre[i] % mod) %= mod;
	printf("%lld\n", ans);
}

signed main(){
	int T; scanf("%lld", &T);
	while(T--) slv();
	return 0;
}

Thanks for reading.