AtCoder Grand Contest 023 E Inversions

发布时间 2023-09-20 10:16:54作者: zltzlt

洛谷传送门

AtCoder 传送门

首先将 \(a\) 从小到大排序,设 \(p_i\) 为排序后的 \(a_i\) 位于原序列第 \(p_i\) 个位置,\(x_i\) 为要填的排列的第 \(i\) 个数。

\(A = \prod\limits_{i = 1}^n (a_i - i + 1)\),则 \(A\) 为排列的总方案数(考虑按 \(a_i\) 从小到大填即得)。

套路地,统计每对 \((i, j), i < j\) 造成的逆序对贡献。设 \(f(i, j)\)\((p_i, p_j)\) 在排列中构成逆序对的方案。

\(p_i < p_j\),则 \(x_i > x_j\) 有:

\[\begin{aligned}f(i, j) & = \frac{(a_i - i + 1)(a_i - i)}{2} \times \frac{A}{(a_i - i + 1)(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1} \\ & = \frac{(a_i - i)A}{2(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1}\end{aligned} \]

考虑在 \([1, a_i]\) 中选出两个数分配给 \(x_i\)\(x_j\),在总方案数中去除 \(x_i, x_j\) 造成的贡献,对于 \(k \in [i + 1, j - 1]\)\(x_k\) 能选的数少了 \(1\) 个,故减去。然后约分化简得上式。

\(p_i > p_j\),我们计算 \((i, j)\) 构成顺序对的方案数再减去,有:

\[f'(i, j) = A - \frac{(a_i - i)A}{2(a_j - j + 1)} \times \prod\limits_{k = i + 1}^{j - 1} \frac{a_k - k}{a_k - k + 1} \]

看到式子有个 product 很不顺眼,考虑设 \(b_i = \prod\limits_{j = 1}^i \frac{a_j - j}{a_j - j + 1}\)\(c_i = \frac{1}{b_i} = \prod\limits_{j = 1}^i \frac{a_j - j + 1}{a_j - j}\)。那么:

\[f(i, j) = A \times ((a_i - i) \times c_i) \times \frac{b_{j - 1}}{a_j - j + 1} \]

这是一个二维偏序的形式(\(i < j \land p_i < p_j\))。树状数组维护 \((a_i - i) \times c_i\) 的和,在 \(j\) 处乘上 \(\frac{b_{j - 1}}{a_j - j + 1}\) 并加入最终答案即可。

对于 \(f'(i, j)\),我们还需要计算 \(i < j \land p_i > p_j\) 的数量,可以再开一个树状数组。

但是这样有个问题,可能存在 \(a_i - i = 0\),因此可能存在 \(b_i = 0\)。为了不影响前缀积,考虑强制把 \(a_i - i = 0\) 的位置当作 \(1\) 乘进去,然后规定计算 \(f(i, j)\) 时,若 \(\exists k \in [i + 1, j - 1], a_k - k = 0\),就使 \(f(i, j) = 0\)。那我们可以把 \(a_k - k = 0\) 的位置看作一个挡板,把序列分成若干个块,每次只计算块内互相贡献的答案即可。

目前是 AtCoder 最优解。

code
// Problem: E - Inversions
// Contest: AtCoder - AtCoder Grand Contest 023
// URL: https://atcoder.jp/contests/agc023/tasks/agc023_e
// Memory Limit: 256 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
inline int read() {
    char c = getchar();
    int x = 0;
    for (; !isdigit(c); c = getchar()) ;
    for (; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    return x;
}

const int maxn = 200100;
const ll mod = 1000000007;
const ll inv2 = (mod + 1) / 2;

ll n, inv[maxn], b[maxn], c[maxn], d[maxn], f[maxn], g[maxn];

struct node {
	ll x, i;
} a[maxn];

inline void upd(ll &x, ll y) {
	((x += y) >= mod) && (x -= mod);
}

struct BIT {
	ll c[maxn];
	
	inline void update(int x, ll d) {
		for (int i = x; i <= n; i += (i & (-i))) {
			upd(c[i], d);
		}
	}
	
	inline ll query(int x) {
		ll res = 0;
		for (int i = x; i; i -= (i & (-i))) {
			upd(res, c[i]);
		}
		return res;
	}
	
	inline ll query(int l, int r) {
		return (query(r) - query(l - 1) + mod) % mod;
	}
} t1, t2;

void solve() {
	n = read();
	inv[0] = inv[1] = 1;
	for (int i = 2; i <= n; ++i) {
		inv[i] = (mod - mod / i) * inv[mod % i] % mod;
	}
	for (int i = 1; i <= n; ++i) {
		a[i].x = read();
		a[i].i = i;
	}
	sort(a + 1, a + n + 1, [&](const node &a, const node &b) {
		return a.x < b.x;
	});
	ll A = 1;
	for (int i = 1; i <= n; ++i) {
		A = A * (a[i].x - i + 1) % mod;
	}
	if (!A) {
		puts("0");
		return;
	}
	ll B = A * inv2 % mod;
	b[0] = c[0] = 1;
	for (int i = 1; i <= n; ++i) {
		b[i] = b[i - 1] * max(a[i].x - i, 1LL) % mod * inv[a[i].x - i + 1] % mod;
		c[i] = c[i - 1] * inv[a[i].x - i] % mod * (a[i].x - i + 1) % mod;
		f[i] = b[i - 1] * inv[a[i].x - i + 1] % mod;
		g[i] = (a[i].x - i) * c[i] % mod;
	}
	ll ans = 0;
	for (int i = 1, j = 1; i <= n; ++i) {
		ans = (ans + B * t1.query(a[i].i - 1) % mod * f[i] % mod) % mod;
		ll res = B * t1.query(a[i].i + 1, n) % mod * f[i] % mod;
		ans = (ans + A * t2.query(a[i].i + 1, n) % mod - res + mod) % mod;
		t1.update(a[i].i, g[i]);
		t2.update(a[i].i, 1);
		if (a[i].x == i) {
			while (j < i) {
				t1.update(a[j].i, mod - g[j]);
				++j;
			}
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}