\(d = 1\) 时答案显然为 \(k^n\)。
下面只讨论 \(d = 3\) 的情况,\(d = 2\) 类似。
设每个人的指数型生成函数(EGF)为 \(G(x) = \sum\limits_{i = 0}^{+\infty} [3 \mid i] \frac{x^i}{i!}\),即只保留次数为 \(3\) 的倍数的项。
欲求:
\[n! [x^n] G(x)^k
\]
先化简 \(G(x)\)。由单位根反演 \([3 \mid i] = \frac{1}{3} \sum\limits_{j = 0}^2 \omega_3^{ij}\),有:
\[G(x) = \frac{1}{3} \sum\limits_{i = 0}^{+\infty} \sum\limits_{j = 0}^2 \frac{(\omega_3^jx)^i}{i!}
\]
\[= \frac{1}{3} \sum\limits_{j = 0}^2 \sum\limits_{i = 0}^{+\infty} \frac{(\omega_3^jx)^i}{i!}
\]
\[= \frac{1}{3} \sum\limits_{j = 0}^2 e^{\omega_3^j x}
\]
运用的公式是 \(\sum\limits_{i = 0}^{\infty} \frac{(cx)^i}{i!} = e^{cx}\)(\(c\) 为任意常数)。
将 \(G(x)^k\) 展开,枚举 \(k\) 个因子中有 \(a_0\) 个选了 \(j = 0\) 的项,\(a_1\) 个选了 \(j = 1\) 的项,\(a_2\) 个选了 \(j = 2\) 的项,那么:
\[n! G(x)^k = \frac{n!}{3^k} \sum\limits_{a_0 + a_1 + a_2 = k} e^{(a_0 \omega_3^0 + a_1 \omega_3^1 + a_2 \omega_3^2)x} \binom{k}{a_0, a_1, a_2}
\]
把上面的公式逆过来,可得 \(e^{cx} = \sum\limits_{i = 0}^{\infty} \frac{(cx)^i}{i!}\),即 \(n! [x^n] e^{cx} = c^n\),因此:
\[n! [x^n] G(x)^k = \frac{1}{3^k} \sum\limits_{a_0 + a_1 + a_2 = k} (a_0 \omega_3^0 + a_1 \omega_3^1 + a_2 \omega_3^2)^n \binom{k}{a_0, a_1, a_2}
\]
暴力枚举 \(a_0, a_1, a_2\) 即可(最后一个变量由 \(a_0 + a_1 + a_2 = k\) 确定,只需枚举 \(d - 1\) 个),每项做一次快速幂,时间复杂度 \(O(k^{d - 1} \log n)\)。
code
// Problem: #450. 【集训队作业2018】复读机
// Contest: UOJ
// URL: https://uoj.ac/problem/450
// Memory Limit: 512 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
// Shorthand macros: container append, pair members, and whole-array memset.
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
// Short aliases for the arithmetic types used throughout the file.
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
// Note: despite the name, both members are 64-bit (ll), not int.
typedef pair<ll, ll> pii;
// Capacity of the factorial / inverse-factorial tables (fac, ifac).
const int maxn = 500100;
// mod: modulus for all arithmetic. G: base whose powers G^((mod-1)/d) are used
// as d-th roots of unity in solve() (presumably a primitive root of mod).
const ll mod = 19491001, G = 7;
/**
 * @brief Modular exponentiation: returns b^p modulo `mod` via binary exponentiation.
 *
 * The base is first normalised into [0, mod), so any ll value (including
 * negative or very large ones) is accepted without overflowing the
 * intermediate products; with mod ~ 1.95e7 every product stays below 2^63.
 *
 * @param b Base; any 64-bit integer.
 * @param p Exponent; expected to be non-negative. A non-positive exponent
 *          yields 1 (the empty product) instead of looping forever.
 * @return b^p mod `mod`, in the range [0, mod).
 */
inline ll qpow(ll b, ll p) {
    // Reduce the base up front so that b * b below cannot overflow.
    b %= mod;
    if (b < 0) {
        b += mod;
    }
    ll res = 1;
    while (p > 0) {
        if (p & 1) {
            res = res * b % mod;
        }
        b = b * b % mod;
        p >>= 1;
    }
    return res;
}
// Problem input (n, K, d; read by solve()) and the tables fac[i] = i! mod `mod`,
// ifac[i] = (i!)^{-1} mod `mod`, which solve() fills for 0 <= i <= K.
ll n, K, d, fac[maxn], ifac[maxn];
void solve() {
scanf("%lld%lld%lld", &n, &K, &d);
fac[0] = 1;
for (int i = 1; i <= K; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[K] = qpow(fac[K], mod - 2);
for (int i = K - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
if (d == 1) {
printf("%lld\n", qpow(K, n));
} else if (d == 2) {
ll ans = 0, w = qpow(G, (mod - 1) / 2);
for (int i = 0; i <= K; ++i) {
int j = K - i;
ans = (ans + fac[K] * ifac[i] % mod * ifac[K - i] % mod * qpow((i + j * w) % mod, n) % mod) % mod;
}
ans = ans * qpow(qpow(2, mod - 2), K) % mod;
printf("%lld\n", ans);
} else if (d == 3) {
ll ans = 0, w = qpow(G, (mod - 1) / 3);
for (int i = 0; i <= K; ++i) {
for (int j = 0; i + j <= K; ++j) {
int k = K - i - j;
ans = (ans + fac[K] * ifac[i] % mod * ifac[j] % mod * ifac[k] % mod * qpow((i + j * w + k * w % mod * w % mod) % mod, n) % mod) % mod;
}
}
ans = ans * qpow(qpow(3, mod - 2), K) % mod;
printf("%lld\n", ans);
}
}
/**
 * @brief Program entry point: runs solve() once per test case.
 *
 * The input holds a single test case, so the case count is fixed at 1
 * rather than being read from stdin.
 */
int main() {
    const int cases = 1;
    for (int done = 0; done < cases; ++done) {
        solve();
    }
    return 0;
}