UOJ450 【集训队作业 2018】复读机

发布时间 2023-07-20 19:15:40 作者: zltzlt

UOJ 传送门

\(d = 1\) 时没有任何限制,\(n\) 个位置各自在 \(k\) 个人中任选一个,答案显然为 \(k^n\)。

下面只讨论 \(d = 3\) 的情况,\(d = 2\) 类似。

设每个人的指数型生成函数(EGF)为 \(G(x) = \sum\limits_{i = 0}^{+\infty} [3 \mid i] \frac{x^i}{i!}\)

欲求:

\[n! [x^n] G(x)^k \]

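先用单位根反演化简 \(G(x)\)。设 \(\omega_3\) 为三次单位根,则 \([3 \mid i] = \frac{1}{3} \sum\limits_{j = 0}^{2} \omega_3^{ij}\),于是: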

\[G(x) = \frac{1}{3} \sum\limits_{i = 0}^{+\infty} \sum\limits_{j = 0}^2 \frac{(\omega_3^jx)^i}{i!} \]

\[= \frac{1}{3} \sum\limits_{j = 0}^2 \sum\limits_{i = 0}^{+\infty} \frac{(\omega_3^jx)^i}{i!} \]

\[= \frac{1}{3} \sum\limits_{j = 0}^2 e^{\omega_3^j x} \]

运用的公式是 \(\sum\limits_{i = 0}^{\infty} \frac{(cx)^i}{i!} = e^{cx}\)(\(c\) 为任意常数,这里取 \(c = \omega_3^j\))。

我们枚举 \(k\) 次方的 \(k\) 个因子中有 \(a_0\) 个选了 \(j = 0\)、\(a_1\) 个选了 \(j = 1\)、\(a_2\) 个选了 \(j = 2\),那么:

\[n! G(x)^k = \frac{n!}{3^k} \sum\limits_{a_0 + a_1 + a_2 = k} e^{(a_0 \omega_3^0 + a_1 \omega_3^1 + a_2 \omega_3^2)x} \binom{k}{a_0, a_1, a_2} \]

把上面的公式反过来用,可得 \(e^{cx} = \sum\limits_{i = 0}^{\infty} \frac{(cx)^i}{i!}\),即 \(n! [x^n] e^{cx} = c^n\),因此:

\[n! [x^n] G(x)^k = \frac{1}{3^k} \sum\limits_{a_0 + a_1 + a_2 = k} (a_0 \omega_3^0 + a_1 \omega_3^1 + a_2 \omega_3^2)^n \binom{k}{a_0, a_1, a_2} \]

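暴力枚举 \(a_0, a_1, a_2\) 即可,时间复杂度 \(O(k^{d - 1} \log n)\)。由于模数 \(p = 19491001\) 满足 \(2 \mid p - 1\) 且 \(3 \mid p - 1\),在模 \(p\) 意义下可取 \(\omega_d = g^{(p - 1)/d}\),其中 \(g = 7\) 为原根(即代码中的 G)。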

code
// Problem: #450. 【集训队作业2018】复读机
// Contest: UOJ
// URL: https://uoj.ac/problem/450
// Memory Limit: 512 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
using ll = long long;
using ull = unsigned long long;
using db = double;
using ldb = long double;
using pii = pair<ll, ll>;  // note: a pair of ll despite the name

// Capacity of the factorial tables (must exceed the largest K read).
constexpr int maxn = 500100;
// Modulus of the problem and the generator used to build roots of unity.
// The code relies on mod being prime (inverses are taken as x^(mod-2)) and on
// G being a primitive root, so that G^((mod-1)/d) is a primitive d-th root of
// unity for d | mod - 1.
constexpr ll mod = 19491001, G = 7;

// Returns b^p modulo `mod` by binary exponentiation; qpow(x, 0) == 1 (incl. x = 0).
// b may be any ll value: it is first reduced into [0, mod), so every product
// below stays under mod^2 and cannot overflow, and negative bases are handled.
// Precondition: p >= 0 (a negative exponent is not supported and yields 1).
inline ll qpow(ll b, ll p) {
	b %= mod;
	if (b < 0) {
		b += mod;  // C++ '%' keeps the sign of the dividend
	}
	ll res = 1;
	for (; p > 0; p >>= 1) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
	}
	return res;
}

ll n, K, d, fac[maxn], ifac[maxn];

void solve() {
	scanf("%lld%lld%lld", &n, &K, &d);
	fac[0] = 1;
	for (int i = 1; i <= K; ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[K] = qpow(fac[K], mod - 2);
	for (int i = K - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	if (d == 1) {
		printf("%lld\n", qpow(K, n));
	} else if (d == 2) {
		ll ans = 0, w = qpow(G, (mod - 1) / 2);
		for (int i = 0; i <= K; ++i) {
			int j = K - i;
			ans = (ans + fac[K] * ifac[i] % mod * ifac[K - i] % mod * qpow((i + j * w) % mod, n) % mod) % mod;
		}
		ans = ans * qpow(qpow(2, mod - 2), K) % mod;
		printf("%lld\n", ans);
	} else if (d == 3) {
		ll ans = 0, w = qpow(G, (mod - 1) / 3);
		for (int i = 0; i <= K; ++i) {
			for (int j = 0; i + j <= K; ++j) {
				int k = K - i - j;
				ans = (ans + fac[K] * ifac[i] % mod * ifac[j] % mod * ifac[k] % mod * qpow((i + j * w + k * w % mod * w % mod) % mod, n) % mod) % mod;
			}
		}
		ans = ans * qpow(qpow(3, mod - 2), K) % mod;
		printf("%lld\n", ans);
	}
}

// Program entry: calls solve() once per test case. The input holds a single
// case; enabling the commented read would take a leading case count instead.
int main() {
	int remaining = 1;
	// scanf("%d", &remaining);
	for (; remaining != 0; --remaining) {
		solve();
	}
	return 0;
}