AtCoder Regular Contest 133 E Cyclic Medians

发布时间 2023-05-19 19:57:34作者: zltzlt

洛谷传送门

AtCoder 传送门

其实是套路题,但是为什么做不出来啊

第一步就是经典套路。枚举 \(k\),统计中位数 \(> k\) 的方案数,加起来就是中位数的总和。

那么现在 \(x_{1 \sim n}, y_{1 \sim m}\) 就变成了 \(0/1\) 序列,考虑一次操作,如果 \((x,y) = (0,0)\),那么 \(a\) 会变成 \(0\);如果 \((x,y) = (1,1)\),那么 \(a\) 会变成 \(1\);否则 \(a\) 不变。

到这里我就卡住了。想着枚举最后一次 \((0,0)\)\((1,1)\) 的操作,然后发现根本算不了。

其实,发现如果出现了一次 \(x = y\) 的操作,最后 \(a\) 的取值就跟 \(a\) 原来的值无关了。进一步发现,由于对称性\(k = p\) 时最后一次 \(x = y\) 的操作是 \((0,0)\) 的方案数和 \(k = V - p\) 时最后一次 \(x = y\) 的操作是 \((1,1)\) 的方案数相等

现在考虑统计只出现 \(x \ne y\) 的操作的方案数。

考虑一些特殊情况,例如 \(\gcd(n, m) = 1\),那么每个数对 \((p,q), p \in [0,n), q \in [0,m)\)\(\{(i \bmod n, i \bmod m) | i \in [0,nm)\}\) 中出现且仅出现一次。

那么对于一般性的情况,考虑计算 \(g = \gcd(n, m)\),那么 \(i + gt, i \in [0, g), t \in [0, \frac{n}{g})\) 只有可能跟 \(i + gt, i \in [0, g), t \in [0, \frac{m}{g})\) 配对。那么就是要求,\(\forall i \in [0, g)\)

  • \(x_i = x_{i + g} = x_{i + 2g} = \cdots = x_{i + (\frac{n}{g} - 1) g}\)
  • \(y_i = y_{i + g} = y_{i + 2g} = \cdots = y_{i + (\frac{m}{g} - 1) g}\)
  • \(x_i \ne y_i\)

这个的方案数容易统计,\([1, V]\)\(\le k\) 的数有 \(k\) 个,\(> k\) 的数有 \(V - k\) 个,那么分别讨论 \((x, y)\)\((0, 0)\)\((1, 1)\) 的情况,方案数即为:

\[(k^{\frac{n}{g}} (V - k)^{\frac{m}{g}} + (V - k)^{\frac{n}{g}} k^{\frac{m}{g}})^g \]

那么这种情况时 \(a\) 不变,只有原来 \(a\)\(1\) 时才产生贡献。还要计算出现过 \(x = y\) 操作的情况。总方案数 \(V^{n + m}\) 减去上面的式子,得出的就是出现过 \(x = y\) 操作的方案数总和,除以 \(2\) 就是 \((1, 1)\) 的情况。

然后我们就以 \(O(V \log (n + m))\) 的时间复杂度做完了。

code
// Problem: E - Cyclic Medians
// Contest: AtCoder - AtCoder Regular Contest 133
// URL: https://atcoder.jp/contests/arc133/tasks/arc133_e
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const ll mod = 998244353;
const ll inv2 = (mod + 1) / 2;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, V, A;

void solve() {
	scanf("%lld%lld%lld%lld", &n, &m, &V, &A);
	ll ans = qpow(V, n + m), g = __gcd(n, m);
	for (int i = 1; i < V; ++i) {
		ll t = qpow((qpow(i, n / g) * qpow(V - i, m / g) % mod + qpow(V - i, n / g) * qpow(i, m / g) % mod) % mod, g);
		if (A > i) {
			ans = (ans + t) % mod;
		}
		ll all = qpow(V, n + m);
		t = (all - t + mod) % mod;
		ans = (ans + t * inv2 % mod) % mod;
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}