「解题报告」CF1621G Weighted Increasing Subsequences

发布时间 2023-06-22 16:12:21作者: APJifengc

比较套路的拆贡献题。

考虑直接枚举那个 \(j\),求有多少包含 \(j\) 的上升子序列满足这个子序列最后一个数的后面有大于 \(a_j\) 的数。

首先对于 \(j\) 前面的选择方案是没有影响的,可以直接拿树状数组 DP 一遍得到。后面的过程我们可以找到从后往前第一个大于 \(a_j\) 的数的位置 \(x\),那么后面的方案就是 \(\lbrack j, x)\) 中包含 \(j\) 的上升子序列数。这个东西直接求不好求,考虑任意一个以 \(j\) 开头的上升子序列,由于 \(x\) 是第一个大于 \(a_j\) 的数,说明这个数后面的数都比 \(a_j\) 小,那么以 \(j\) 开头的上升子序列不可能包含这些数。那么实际上只有两种序列:要不然不包含 \(x\),要不然以 \(x\) 结尾。我们容斥一下,求所有以 \(x\) 结尾的上升子序列数。发现,我们要求的就是以 \(j\) 开头,以 \(x\) 结尾的上升子序列数。看起来还是不可做,但是考虑到 \(x\) 表示的数一定是第一个大于 \(a_j\) 的数,也就是说假如我们把后缀最大值写成一个序列 \(b_1 < b_2 < \cdots < b_m\),那么在这个子序列中的数值域一定在 \((b_{x - 1}, b_x\rbrack\),容易发现这样的值域总数是 \(O(n)\) 的。那么我们直接拿树状数组跑 DP,就也是 \(O(n \log n)\) 的了。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 1000000007;
int T, n, m, a[MAXN], b[MAXN];
struct BinaryIndexTree {
    int a[MAXN];
#define lowbit(x) (x & (-x))
    void init() { for (int i = 1; i <= n; i++) a[i] = 0; }
    void add(int d, int v) {
        while (d <= n) {
            a[d] = (a[d] + v) % P;
            d += lowbit(d);
        }
    }
    int query(int d) {
        if (!d) return 0;
        int ret = 0;
        while (d) {
            ret = (ret + a[d]) % P;
            d -= lowbit(d);
        }
        return ret;
    }
} bit;
int f[MAXN], g[MAXN], h[MAXN];
pair<int, int> q[MAXN];
int main() {
    scanf("%d", &T);
    while (T--) {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]), b[i] = a[i];
        sort(b + 1, b + 1 + n);
        int m = unique(b + 1, b + 1 + n) - b - 1;
        for (int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + 1 + n, a[i]) - b;
        bit.init();
        for (int i = 1; i <= n; i++) {
            f[i] = (1 + bit.query(a[i] - 1)) % P;
            bit.add(a[i], f[i]);
        }
        bit.init();
        for (int i = n; i >= 1; i--) {
            g[i] = (1 + bit.query(n) - bit.query(a[i]) + P) % P;
            bit.add(a[i], g[i]);
        }
        m = 0;
        for (int i = n; i >= 1; i--) {
            if (!m || q[m].first < a[i]) q[++m] = { a[i], i }, b[i] = -1;
        }
        bit.init();
        for (int i = n; i >= 1; i--) {
            auto it = lower_bound(q + 1, q + 1 + m, make_pair(a[i], INT_MAX));
            if (it != q + m + 1 && it->second > i) {
                h[i] = (bit.query(a[it->second]) - bit.query(a[i]) + P) % P;
                auto it2 = lower_bound(q + 1, q + 1 + m, make_pair(a[i], 0));
                if (it2->first != a[i]) bit.add(a[i], h[i]);
            } else if (b[i] == -1) {
                h[i] = 1;
                bit.add(a[i], h[i]);
            } else {
                h[i] = 0;
            }
        }
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            auto it = lower_bound(q + 1, q + 1 + m, make_pair(a[i], INT_MAX));
            if (it != q + m + 1 && it->second > i) {
                // printf("%d: %d %d %d\n", i, f[i], g[i], h[i]);
                ans = (ans + 1ll * f[i] * (g[i] - h[i] + P)) % P;
            }
        }
        printf("%d\n", ans);
    }
    return 0;
}