题目链接:https://codeforces.com/contest/1910/problem/I
题意
有一个 \(n\) 个字符的字符串 \(S\),你需要不断从中删除一个长度为 \(k\) 的子串,直到串的长度变为 \(n \mathbin{\rm mod} k\),求能够产生的字典序最小的字符串 \(T\)。
但你要解决的不是这个问题,是这个问题的“逆问题”。已知字符串 \(T\),你要找出答案为 \(T\) 的原字符串 \(S\) 的个数。
\(n, k \le 10^6\),\(n \mathbin{\rm mod} k \ne 0\)。给定字符集大小 \(c \le 26\)。
题解(官解)
首先我们解决给出的“正问题”。首先可以看出,因为每一次删除只能删除连续的长为 \(k\) 的子串,因此留下的第一个字符的下标 \(p_0\) 一定是 \(k\) 的整数倍。同理,留下的第 \(i\) 个字符的下标 \(p_i\) 模 \(k\) 为 \(i\)。若记 \(p_i = j_i \cdot k + i\),容易看出任意 \(j_i\) 非递减的 \(p_i\) 序列都可以得到。
经过上面的分析,要得到字典序最小的 \(T\),有如下的贪心算法:
- 第 \(i\) 轮只考虑位置模 \(k\) 等于 \(i\) 的字符。取最左侧的最小字符为第一个留在 \(T\) 中的字符(删去左侧其它所有字符),并继续接下来的选择。
根据这个贪心算法进行 DP。记 \(t = n \mathbin{\rm mod} k\),\(m = \lceil \dfrac n k\rceil\)。记 \(dp_{i, j}\)(\(i \le t\),\(j < m\))为接下来将要进行第 \(i\) 轮选择,可以选择所有 \(p_i \ge j\) 的 \(S\) 的数量,则对 \(\forall j' \ge j\),有贡献:
这是因为 \(j\) 左侧的部分已经被删去,可以任意选择。而 \(j'\) 处已经选择了 \(s_i\),它前面的字符应大于 \(s_i\),后面的字符应大于或等于 \(s_i\)。
这个贡献可以拆成只与 \(j\) 和 \(j'\) 有关的两部分的乘积,因此可以使用前缀和优化计算,使得复杂度降为 \(O(tm) = O(n + k)\)。
注意另外还有 \(i \mathbin{\rm mod} k \ge t\) 的部分没有计算,它们可以任取,即将 \(c^{n- tm}\) 加入答案即可。
代码实现(Kotlin)
fun readInts() = readln().split(' ').map { it.toInt() }.toIntArray()
const val MOD = 998244353L
infix fun Int.modAdd(b: Int): Int {
return if (this + b >= MOD) {
this + b - MOD.toInt()
} else {
this + b
}
}
infix fun Int.modMul(b: Int) = (this.toLong() * b.toLong()).mod(MOD).toInt()
infix fun Int.modPow(b: Int): Int {
var at = this
var res = 1
var bt = b
while (bt > 0) {
if (bt and 1 != 0) {
res = res modMul at
}
at = at modMul at
bt = bt shr 1
}
return res
}
fun Int.modInv(): Int {
return this modPow (MOD.toInt() - 2)
}
fun main() {
val (n, k, c) = readInts()
val t = n.mod(k)
val m = n / k + 1
val s = readln()
val dp = MutableList(t + 1) { MutableList(m) { 0 } }
dp[0][0] = 1
for (i in 0..<t) {
val ci = s[i] - 'a'
for (j in 0..<m) {
if (ci != c - 1)
dp[i + 1][j] = dp[i + 1][j] modAdd (
dp[i][j] modMul ((c modMul (c - ci - 1).modInv()) modPow j)
)
else
dp[i + 1][j] = dp[i + 1][j] modAdd (dp[i][j] modMul (c modPow j))
}
if (ci != c - 1)
for (j in 1..<m) {
dp[i + 1][j] = dp[i + 1][j] modAdd dp[i + 1][j - 1]
}
for (j in 0..<m) {
if (ci != c - 1)
dp[i + 1][j] = dp[i + 1][j] modMul ((c - ci - 1) modPow j) modMul ((c - ci) modPow (m - j - 1))
else
dp[i + 1][j] = dp[i + 1][j] modMul ((c - ci) modPow (m - j - 1))
}
}
var ans = 0
for (j in 0..<m) {
ans = ans modAdd (dp[t][j] modMul (c modPow (n - t * m)))
}
println(ans)
}