AtCoder Grand Contest 021 E Ball Eat Chameleons

发布时间 2023-07-01 19:46:13作者: zltzlt

洛谷传送门

AtCoder 传送门

容易发现一个变色龙是红色当且仅当,设 \(R\) 为红球数量,\(B\) 为蓝球数量,那么 \(R \ge B\)\(R = B\) 且最后一个球是蓝球。

考虑如何判定一个颜色序列是否可行。

考虑贪心。

  • \(R < B\) 显然不行。
  • \(R \ge B + n\),每个变色龙都可以分到比蓝球数多 \(1\) 的红球,答案为 \(\binom{R + B}{R}\)
  • \(R = B\),考虑删除颜色序列中的最后一位(这一位必然是 \(\texttt{B}\)),转化为 \(B < R < B + n\) 的情况。
  • \(B < R < B + n\),考虑分配得尽量平均,每个变色龙分到的蓝球数恰好等于红球数或等于红球数减 \(1\)。那么每个蓝球数等于红球数的变色龙,我们贪心地给它分配 \(\texttt{RB}\) 即可,剩下的全部给蓝球数等于红球数或等于红球数减 \(1\) 的变色龙。

于是现在问题转化为了,求满足可以取出 \(n - (R - B)\)\(\texttt{RB}\)\(\texttt{RB}\) 序列。

考虑转化成,每个前缀的红球数减蓝球数不小于 \(n - (R - B) - B = n - R\),也就是说不能匹配的蓝球数 \(\le R - n\)

考虑画出折线图,\(\texttt{R}\) 就是 \((x, y) \to (x + 1, y + 1)\)\(\texttt{B}\) 就是 \((x, y) \to (x + 1, y - 1)\)。那么就是要求从 \((0, 0)\) 走到 \((R + B, R - B)\) 且与 \(y = n - R - 1\) 没有交点的路径数。

考虑容斥,总路径数 \(\binom{R + B}{R}\) 减去与 \(y = n - R - 1\) 有交点的路径数。把在 \(y = n - R - 1\) 下方的折线翻上去,等价于从 \((0, 2n - 2R - 2)\)\((R + B, R - B)\)。设 \(x\) 为向右上走的步数,\(y\) 为向右下走的步数,那么 \(x + y = R + B, x - y = 3R - B - 2n + 2\),可得 \(x = 2R - n + 1\)。所以,与 \(y = n - R - 1\) 有交点的路径数就是 \(\binom{x + y}{x} = \binom{R + B}{2R - n + 1}\)

时间复杂度线性。

code
// Problem: E - Ball Eat Chameleons
// Contest: AtCoder - AtCoder Grand Contest 021
// URL: https://atcoder.jp/contests/agc021/tasks/agc021_e
// Memory Limit: 256 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 2000100;
const int N = 2000000;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, fac[maxn], ifac[maxn];

inline void init() {
	fac[0] = 1;
	for (int i = 1; i <= N; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[N] = qpow(fac[N], mod - 2);
	for (int i = N - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
}

inline ll C(ll n, ll m) {
	if (n < m || n < 0 || m < 0) {
		return 0;
	} else {
		return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
	}
}

void solve() {
	scanf("%lld%lld", &n, &m);
	ll ans = 0;
	for (int i = n; i <= m; ++i) {
		int j = m - i;
		if (i < j) {
			continue;
		}
		if (i >= j + n) {
			ans = (ans + C(i + j, i)) % mod;
		} else {
			if (i == j) {
				--j;
			}
			ans = (ans + C(i + j, i) - C(i + j, i * 2 - n + 1) + mod) % mod;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	init();
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}