「解题报告」AGC023E Inversions

发布时间 2023-05-19 20:32:10作者: APJifengc

好。

首先考虑怎么计算方案数。我们考虑按照 \(a_i\) 从小往大选,设排序后的下标为 \(b_i\),那么容易得出方案数为:

\[s = \prod_{i=1}^n (a_{b_i} - i + 1) \]

我们设 \(c_i = a_{b_i} - i + 1\),这代表着某个数的选择方案数。

然后考虑经典拆贡献,枚举每一对 \((i, j)\),求 \(p_i > p_j\) 的方案数,这样累加起来就是答案。

首先假设 \(a_i < a_j\),我们直接枚举 \(p_i, p_j\) 的选择方案,容易得出为 \(\dbinom{c_i}{2} = \dfrac{c_i(c_i - 1)}{2}\)。然后考虑这会导致 \(a_i\)\(a_j\) 之间的数的选择方案减少 \(1\),其他数的选择方案不变,那么可以写出式子:

\[\begin{aligned} f(i, j) &= \frac{c_i (c_i - 1)}{2} \times \frac{s}{c_i c_j} \times \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k}\\ &=\frac{s (c_i - 1)}{2 c_j} \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k} \end{aligned} \]

假如说 \(a_i > a_j\),那么我们可以先对称的求出 \(p_i > p_j\) 的方案数,然后总方案数减去它即可,即:

\[f(i, j) =s - \frac{s (c_i - 1)}{2 c_j} \prod_{a_i < a_k < a_j}\frac{c_k - 1}{c_k} \]

我们考虑维护这个东西。我们按照 \(a_i\) 从小往大的顺序依次计算答案,然后以 \(i\) 为下标建一棵线段树,这样我们就可以将 \(i < j\)\(i > j\) 的分开来计算。对于后者,我们还需要统计有多少个数,可以拿树状数组维护。这个式子容易通过区间乘,区间求和计算。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005, P = 1000000007;
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
int n, a[MAXN], b[MAXN], c[MAXN];
struct SegmentTree {
    struct Node {
        int val, tag;
        Node() : val(0), tag(1) {}
    } t[MAXN << 2];
#define lc (i << 1)
#define rc (i << 1 | 1)
    void tag(int i, int v) {
        t[i].tag = 1ll * t[i].tag * v % P;
        t[i].val = 1ll * t[i].val * v % P;
    }
    void pushDown(int i) { tag(lc, t[i].tag), tag(rc, t[i].tag), t[i].tag = 1; }
    void mul(int a, int b, int v, int i = 1, int l = 1, int r = n) {
        if (a > b) return;
        if (a <= l && r <= b) return tag(i, v);
        int mid = (l + r) >> 1;
        pushDown(i);
        if (a <= mid) mul(a, b, v, lc, l, mid);
        if (b > mid) mul(a, b, v, rc, mid + 1, r);
        t[i].val = (t[lc].val + t[rc].val) % P;
    }
    void set(int d, int v, int i = 1, int l = 1, int r = n) {
        if (l == r) {
            t[i].val = v;
            return;
        }
        int mid = (l + r) >> 1;
        pushDown(i);
        if (d <= mid) set(d, v, lc, l, mid);
        else set(d, v, rc, mid + 1, r);
        t[i].val = (t[lc].val + t[rc].val) % P;
    }
    int query(int a, int b, int i = 1, int l = 1, int r = n) {
        if (a > b) return 0;
        if (a <= l && r <= b) return t[i].val;
        int mid = (l + r) >> 1;
        pushDown(i);
        if (b <= mid) return query(a, b, lc, l, mid);
        if (a > mid) return query(a, b, rc, mid + 1, r);
        return (query(a, b, lc, l, mid) + query(a, b, rc, mid + 1, r)) % P;
    }
} st;
struct BinaryIndexTree {
    int a[MAXN];
#define lowbit(x) (x & (-x))
    void add(int d, int v) {
        while (d <= n) {
            a[d] += v;
            d += lowbit(d);
        }
    }
    int query(int d) {
        if (!d) return 0;
        int ret = 0;
        while (d) {
            ret += a[d];
            d -= lowbit(d);
        }
        return ret;
    }
} bit;
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        b[i] = i;
    }
    sort(b + 1, b + 1 + n, [&](int x, int y) { return a[x] < a[y]; });
    int s = 1;
    for (int i = 1; i <= n; i++) {
        c[i] = a[b[i]] - i + 1;
        if (c[i] <= 0) {
            printf("0\n");
            return 0;
        }
        s = 1ll * s * c[i] % P;
    }
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = (ans + 1ll * st.query(1, b[i] - 1) * qpow(2 * c[i], P - 2)) % P;
        ans = (ans - 1ll * st.query(b[i] + 1, n) * qpow(2 * c[i], P - 2) % P + P) % P;
        ans = (ans + 1ll * (i - bit.query(b[i]) - 1) * s) % P;
        st.mul(1, n, 1ll * (c[i] - 1) * qpow(c[i], P - 2) % P);
        st.set(b[i], 1ll * s * (c[i] - 1) % P);
        bit.add(b[i], 1);
    }
    printf("%d\n", ans);
    return 0;
}