AtCoder Regular Contest 139 D Priority Queue 2

发布时间 2023-05-24 18:35:41作者: zltzlt

洛谷传送门

AtCoder 传送门

看成方案数想了 114514 年……

这个东西看起来就很不好做,但是如果 \(a\)\(01\) 序列,就好做很多,事实上如果 \(a\)\(01\) 序列,只需要维护 \(1\) 的个数 \(c\),每次操作后当 \(c > n + 1 - X\)\(c\) 就减 \(1\)

考虑运用一个经典套路

\[\sum\limits_{i=1}^m i \times c(i) = \sum\limits_{i=1}^m c(\ge i) \]

也就是枚举 \(x\),把 \(\ge x\) 的数设为 \(1\)\(< x\) 的数设为 \(0\),最后求的是 \(1\) 的个数。

这个是 trivial 的。枚举 \(K\) 次操作中有多少次往序列里加了个 \(1\),最后的 \(c\) 是能确定的。

时间复杂度 \(O(m(n + K) \log K)\)

code
// Problem: D - Priority Queue 2
// Contest: AtCoder - AtCoder Regular Contest 139
// URL: https://atcoder.jp/contests/arc139/tasks/arc139_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 2020;
const int N = 2000;
const ll mod = 998244353;

ll n, m, K, fac[maxn], ifac[maxn], a[maxn], P;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

void init() {
	fac[0] = 1;
	for (int i = 1; i <= N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[N] = qpow(fac[N], mod - 2);
	for (int i = N - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
}

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

void solve() {
	scanf("%lld%lld%lld%lld", &n, &m, &K, &P);
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
	}
	ll ans = 0;
	for (int x = 1; x <= m; ++x) {
		ll k = 0;
		for (int i = 1; i <= n; ++i) {
			k += (a[i] >= x);
		}
		for (int i = 0; i <= K; ++i) {
			ans = (ans + qpow(m - x + 1, i) * qpow(x - 1, K - i) % mod * C(K, i) % mod * max(k + i - K, min(n + 1 - P, k + i)) % mod) % mod;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	init();
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}