CF1910I Inverse Problems

发布时间 2024-01-02 22:08:27作者: cccpchenpi

题目链接:https://codeforces.com/contest/1910/problem/I

题意

有一个 \(n\) 个字符的字符串 \(S\),你需要不断从中删除一个长度为 \(k\) 的子串,直到串的长度变为 \(n \mathbin{\rm mod} k\),求能够产生的字典序最小的字符串 \(T\)

但你要解决的不是这个问题,是这个问题的“逆问题”。已知字符串 \(T\),你要找出答案为 \(T\) 的原字符串 \(S\) 的个数。

\(n, k \le 10^6\)\(n \mathbin{\rm mod} k \ne 0\)。给定字符集大小 \(c \le 26\)

题解(官解)

首先我们解决给出的“正问题”。首先可以看出,因为每一次删除只能删除连续的长为 \(k\) 的子串,因此留下的第一个字符的下标 \(p_0\) 一定是 \(k\) 的整数倍。同理,留下的第 \(i\) 个字符的下标 \(p_i\)\(k\)\(i\)。若记 \(p_i = j_i \cdot k + i\),容易看出任意 \(j_i\) 非递减的 \(p_i\) 序列都可以得到。

经过上面的分析,要得到字典序最小的 \(T\),有如下的贪心算法:

  • \(i\) 轮只考虑位置模 \(k\) 等于 \(i\) 的字符。取最左侧的最小字符为第一个留在 \(T\) 中的字符(删去左侧其它所有字符),并继续接下来的选择。

根据这个贪心算法进行 DP。记 \(t = n \mathbin{\rm mod} k\)\(m = \lceil \dfrac n k\rceil\)。记 \(dp_{i, j}\)\(i \le t\)\(j < m\))为接下来将要进行第 \(i\) 轮选择,可以选择所有 \(p_i \ge j\)\(S\) 的数量,则对 \(\forall j' \ge j\),有贡献:

\[dp_{i , j} \cdot c^j \cdot (c - s_i - 1)^ {j' - j} \cdot (c - s_i)^ {m - j' - 1} \to dp_{i +1 , j'} \]

这是因为 \(j\) 左侧的部分已经被删去,可以任意选择。而 \(j'\) 处已经选择了 \(s_i\),它前面的字符应大于 \(s_i\),后面的字符应大于或等于 \(s_i\)

这个贡献可以拆成只与 \(j\)\(j'\) 有关的两部分的乘积,因此可以使用前缀和优化计算,使得复杂度降为 \(O(tm) = O(n + k)\)

注意另外还有 \(i \mathbin{\rm mod} k \ge t\) 的部分没有计算,它们可以任取,即将 \(c^{n- tm}\) 加入答案即可。

代码实现(Kotlin)

fun readInts() = readln().split(' ').map { it.toInt() }.toIntArray()

const val MOD = 998244353L

infix fun Int.modAdd(b: Int): Int {
    return if (this + b >= MOD) {
        this + b - MOD.toInt()
    } else {
        this + b
    }
}

infix fun Int.modMul(b: Int) = (this.toLong() * b.toLong()).mod(MOD).toInt()

infix fun Int.modPow(b: Int): Int {
    var at = this
    var res = 1
    var bt = b
    while (bt > 0) {
        if (bt and 1 != 0) {
            res = res modMul at
        }
        at = at modMul at
        bt = bt shr 1
    }
    return res
}

fun Int.modInv(): Int {
    return this modPow (MOD.toInt() - 2)
}

fun main() {
    val (n, k, c) = readInts()
    val t = n.mod(k)
    val m = n / k + 1
    val s = readln()
    val dp = MutableList(t + 1) { MutableList(m) { 0 } }
    dp[0][0] = 1
    for (i in 0..<t) {
        val ci = s[i] - 'a'
        for (j in 0..<m) {
            if (ci != c - 1)
                dp[i + 1][j] = dp[i + 1][j] modAdd (
                        dp[i][j] modMul ((c modMul (c - ci - 1).modInv()) modPow j)
                        )
            else
                dp[i + 1][j] = dp[i + 1][j] modAdd (dp[i][j] modMul (c modPow j))
        }
        if (ci != c - 1)
            for (j in 1..<m) {
                dp[i + 1][j] = dp[i + 1][j] modAdd dp[i + 1][j - 1]
            }
        for (j in 0..<m) {
            if (ci != c - 1)
                dp[i + 1][j] = dp[i + 1][j] modMul ((c - ci - 1) modPow j) modMul ((c - ci) modPow (m - j - 1))
            else
                dp[i + 1][j] = dp[i + 1][j] modMul ((c - ci) modPow (m - j - 1))
        }
    }
    var ans = 0
    for (j in 0..<m) {
        ans = ans modAdd (dp[t][j] modMul (c modPow (n - t * m)))
    }
    println(ans)
}