Solution to AT_abc310_f Make 10 Again

发布时间 2023-07-21 15:57:37作者: escap1st

Statement

你有 \(N\) 个骰子和一个序列 \(A_i\),第 \(i\) 个骰子能等概率掷出 \(1 \sim A_i\) 的点数。

在同时掷出 \(N\) 个骰子后,求下面所述的条件被满足的概率模 \(998,244,353\) 的值:

能够选出这些骰子的一个子集,使得子集内的骰子掷出的点数和为 \(10\)

保证 \(1 \le N \le 100\)\(1 \le A_i \le {10}^6\)

Solution

发现要求掷出的点数为 \(10\),所以 \(A_i\) 中比 \(10\) 大的数都没有贡献。

当我们投掷一些骰子,选出它们的所有子集,对每个子集中骰子的掷出点数和取并,可以得到这些骰子可能的贡献集合。根据前文,贡献集合中大于 \(10\) 的元素没有意义,因此我们只考虑元素 \(0 \sim 10\) 是否存在于贡献集合中。这启发我们状态压缩表示这个集合。

设计 DP 状态 \(f_{i, T}\) 表示:投掷前 \(i\) 个骰子,使得这些骰子的贡献集合为 \(T\) 的概率。

考虑转移。如果前 \(i-1\) 个骰子的贡献集合中存在 \(x\),掷第 \(i\) 个骰子得到点数 \(j\) 后,意味着前 \(i\) 个骰子的贡献集合必然存在 \(x+j\)\(x\) 对应的子集选中骰子 \(i\)) 和 \(x\)\(x\) 对应的子集不选骰子 \(x\))。即:

\[T = \mathrm{trans}( T', j) = (\{x+j|x\in T'\} \cup T') \cap \{0,1,\cdots,10\} \]

对第 \(i\) 个骰子而言,掷出某个点数的概率都是 \(\frac{1}{A_i}\)

对于掷第 \(i\) 个骰子得到 \(j \ge 10\) 时,发现新的贡献集合等于原来的(前 \(i-1\) 个骰子)贡献集合,不妨认为 \(j \ge 10\) 是同一个事件,其概率为 \(\frac{A_i-10}{A_i}\)

对初始情况,掷了 \(0\) 个骰子的贡献集合显然为 \(\{0\}\),其概率为 \(1\)

据此,我们得到状态转移方程

\[f_{0, T} = \begin{cases} 0, &\textrm{when T } = \{0\}\\ 1, &\textrm{otherwise} \end{cases} \]

\[\begin{cases} f_{i, \mathrm{trans}( T, j )} \gets f_{i-1, T} \times \frac{1}{A_i}, j \in \{1,2,\cdots,10\}, &\textrm{for all } A_i\\ f_{i, T} \gets f_{i-1, T} \times \frac{A_i-10}{A_i}, &\textrm{when }A_i > 10 \end{cases} \]

其中

\[\mathrm{trans}( T, j) = (\{x+j|x\in T\} \cup T) \cap \{0,1,\cdots,10\} \]

答案就是掷前 \(N\) 个骰子,贡献集合包含 \(10\) 的概率,即 \(\sum_{10 \in T} f_{N, T}\)

转移时枚举 \(i, j, T\),时间复杂度 \(\mathcal{O}( N \cdot 2^{w} \cdot w)\),空间复杂度 \(\mathcal{O}( N \cdot 2^{w})\),这里 \(w = 11\)

空间复杂度可以用滚动数组优化至 \(\mathcal{O}( 2^{w})\)。具体见代码实现。

Code

#include <array>
#include <iostream>
#include <vector>

typedef long long LL;

const LL p = 998244353;
const int mask = (1 << 11) - 1;

LL qpow(LL a, LL b) {
	return b ? (b & 1 ? a * qpow(a, b - 1) % p : qpow(a * a % p, b / 2) % p) : 1ll;
}

int main() {
	int n;
	std::cin >> n;

	std::vector<std::array<LL, (1 << 11)>> f(2);

	f[0][1] = 1;
	bool last = 0, now = 1;

	for (int i = 1; i <= n; ++i) {
		int x;
		std::cin >> x;
		LL inv = qpow(x, p - 2);
		f[now].fill(0);

		for (int j = 1; j <= std::min(x, 10); ++j) {
			for (LL k = 0; k <= mask; ++k) {
				LL trans = (((k << j) | k) & mask);
				(f[now][trans] += f[last][k] * inv) %= p;
			}
		}
		for (LL k = 0; k <= mask; ++k) {
			(f[now][k] += f[last][k] * inv % p * std::max(x - 10, 0)) %= p;
		}

		std::swap(now, last);
	}

	LL ans = 0;
	for (int i = 0; i < (1 << 10); ++i) {
		(ans += f[last][i | (1 << 10)]) %= p;
	}

	std::cout << ans << '\n';
	return 0;
}