C - Not So Consecutive

发布时间 2023-12-14 16:53:14作者: onlyblues

C - Not So Consecutive

Problem Statement

You are given an integer $N$. An integer sequence $x=(x_1,x_2,\cdots,x_N)$ of length $N$ is called a good sequence if and only if the following conditions are satisfied:

  • Each element of $x$ is an integer between $1$ and $N$, inclusive.
  • For each integer $i$ ($1 \leq i \leq N$), there is no position in $x$ where $i$ appears $i+1$ or more times in a row.

You are given an integer sequence $A=(A_1,A_2,\cdots,A_N)$ of length $N$. Each element of $A$ is $-1$ or an integer between $1$ and $N$. Find the number, modulo $998244353$, of good sequences that can be obtained by replacing each $-1$ in $A$ with an integer between $1$ and $N$.

Constraints

  • $1 \leq N \leq 5000$
  • $A_i=-1$ or $1 \leq A_i \leq N$.
  • All input values are integers.

Input

The input is given from Standard Input in the following format:

$N$
$A_1$ $A_2$ $\cdots$ $A_N$

Output

Print the answer.


Sample Input 1

2
-1 -1

Sample Output 1

3

You can obtain four sequences by replacing each $-1$ with an integer between $1$ and $2$.

$A=(1,1)$ is not a good sequence because $1$ appears twice in a row.

The other sequences $A=(1,2),(2,1),(2,2)$ are good.

Thus, the answer is $3$.


Sample Input 2

3
2 -1 2

Sample Output 2

2

Sample Input 3

4
-1 1 1 -1

Sample Output 3

0

Sample Input 4

20
9 -1 -1 -1 -1 -1 -1 -1 -1 -1 7 -1 -1 -1 19 4 -1 -1 -1 -1

Sample Output 4

128282166

 

解题思路

  纯动态规划优化题,硬是从 $O(n^4)$ 优化到 $O(n^2)$ 甚至是 $O(n^2 \log{n})$,超有意思的说

  状态还是很容易想到的,定义 $f(i,j)$ 表示由前 $i$ 个数构成且第 $i$ 个数是 $j$ 的所有合法方案的数量。根据序列最后一段有多少个连续 $j$(假设有 $k$ 个),以及第 $i-k$ 个数是哪个数(假设是 $u$,需满足 $u \ne j$)进行状态划分,状态转移方程就是$$f(i,j) = \sum\limits_{k=1}^{\min \{ i, j \}}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}}$$

  实际上这个状态转移方程是有问题的,因为默认了 $a_{i-1}, a_{i-2}, \ldots, a_{i- \min \{ i, j \}}$ 都是 $-1$ 的情况。考虑 $a_1 \sim a_{i-1}$,如果这些数中存在某些 $a_v \ne -1$ 且 $a_v \ne j$,不妨假设 $v$ 是这些数中的最大下标。如果不存在这样的 $v$,即该范围内的数均是 $-1$ 或 $j$,则令 $v = 0$,同时规定 $a_0 = 0$。分情况讨论,如果 $i-v \leq j$,那么很明显最后一段最多只能有 $i - v$ 个连续的 $j$,且第 $v$ 个数 $a_v$ 是固定的。否则连续一段 $j$ 的最大长度就是 $j$。另外如果存在 $a_{i-k} = j$ 的情况,跳过即可。因此正确的状态转移方程应该是

$$f(i,j) = \begin{cases}
\left( \sum\limits_{\begin{array}{c} k=1 \\ a_{i-k} \ne j \end{array}}^{i-v-1}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}} \right) + f(v, a_v), &i-v \leq j \\\\
\sum\limits_{\begin{array}{c} k=1 \\ a_{i-k} \ne j \end{array}}^{j}{\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}}, &\text{others}
\end{cases}$$

  同时规定 $f(0,0) = 0$,这样对于序列前 $i$ 个数都是 $j$ 的状态可以从 $f(0,0)$ 转移得到。

  容易知道整个 dp 的时间复杂度是 $O(n^4)$,不过一个很明显可以优化的地方是 $\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}$ 这部分。本质是累加所有第一维是 $i-k$ 的状态 $f(i-k, *)$,然后减去 $f(i-k, j)$,而 $f(i-k, *)$ 在之前就已经全部求出来了。所以定义 $s_i = \sum\limits_{k=1}^{n}{f(i, k)}$,那么 $\sum\limits_{\begin{array}{c} u=1 \\ u \ne j \end{array}}^{n}{f(i-k, u)}$ 就可以等价成 $s_{i-k} - f(i-k, j)$,而 $s_i$ 只需在计算完 $f(i, *)$ 时进行累加即可,这样时间复杂度就降到了 $O(n^3)$。

  对应的状态转移方程如下:

$$f(i,j) = \begin{cases}
\left( \sum\limits_{k=1}^{i-v-1}{s_{i-k} - f(i-k,j)} \right) + f(v, a_v), &i-v \leq j \\\\
\sum\limits_{k=1}^{j}{s_{i-k} - f(i-k,j)}, &\text{others}
\end{cases}$$

  对于 $a_{i-k} = j$ 的情况原本是要跳过的,但对于这种情况必然有 $s_{i-k} = f(i-k, j)$,这是因为 $a_{i-k}$ 是定值,$f(i-k, u) = 0, \, u \ne j$,因此 $s_{i-k} - f(i-k, j) = 0$,并没有影响。

  先放出 TLE 代码,时间复杂度为 $O(n^3)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 5010, mod = 998244353;

int a[N];
int f[N][N];
int s[N];

int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    f[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        if (a[i] == -1) {
            for (int j = 1; j <= n; j++) {
                for (int k = 1; k <= j && k <= i; k++) {
                    if (a[i - k] != -1 && a[i - k] != j) {
                        f[i][j] = (f[i][j] + f[i - k][a[i - k]]) % mod;
                        break;
                    }
                    f[i][j] = ((LL)f[i][j] + s[i - k] - f[i - k][j] + mod) % mod;
                }
            }
        }
        else {
            int j = a[i];
            for (int k = 1; k <= j && k <= i; k++) {
                if (a[i - k] != -1 && a[i - k] != j) {
                    f[i][j] = (f[i][j] + f[i - k][a[i - k]]) % mod;
                    break;
                }
                f[i][j] = ((LL)f[i][j] + s[i - k] - f[i - k][j] + mod) % mod;
            }
        }
        for (int j = 1; j <= n; j++) {
            s[i] = (s[i] + f[i][j]) % mod;
        }
    }
    int ret = 0;
    for (int i = 1; i <= n; i++) {
        ret = (ret + f[n][i]) % mod;
    }
    printf("\n%d", ret);

    return 0;
}

  上述的代码是在枚举 $k$ 的过程中找到 $v$ 的。很明显如果还要优化的话那么就应该继续把求和符号去掉,求和的部分本质也是对第一维某个已求得区间的 $s_{i} - f(i,j)$ 进行累加,因此可以用前缀和进行优化。

  定义 $S_i = \sum\limits_{k=1}^{i}{s_i}$,$g(i,j) = \sum\limits_{k=1}^{i}{f(k,j)}$。

  那么 $\sum\limits_{k=1}^{i-v-1}{s_{i-k} - f(i-k,j)}$ 就等价于 $S_{i-1} - S_{v} - (g(i-1,j) - g(v,j))$。

  同理 $\sum\limits_{k=1}^{j}{s_{i-k} - f(i-k,j)}$ 就等价于 $S_{i-1} - S_{i-j-1} - (g(i-1,j) - g(i-j-1,j))$。

  状态转移方程变成了

$$f(i,j) = \begin{cases}
S_{i-1} - S_{v} - (g(i-1,j) - g(v,j)) + f(v, a_v), &i-v \leq j \\\\
S_{i-1} - S_{i-j-1} - (g(i-1,j) - g(i-j-1,j)), &\text{others}
\end{cases}$$

  现在关键的问题对于每个状态 $f(i,j)$ 如何快速确定对应的 $v$。本质是在 $a_0 \sim a_{i-1}$ 中找到同时满足 $a_v \ne -1$ 且 $a_v \ne j$ 的最大下标 $v$,所以可以用 std::set<std::pair<int, int>> 来动态维护 $0 \sim n$ 每个值出现的最大下标,其中第一个关键字是下标,第二个关键字是值,按第一个关键字降序排序。另外开一个数组 $p$ 表示每个值对应的最大下标。

  当枚举到 $j$ 时,查看 st.begin()->second,如果不等于 $j$,则对应的 $v$ 就是 st.begin()->first,否则就是 next(st.begin())->first

  当枚举到的 $a_i$ 是一个定值,那么只需从 std::set 中删除原本的数对 $(p_{a_i}, a_i)$,并重新插入 $(i, a_i)$,同时更新 $p_{a_i} \gets i$。

  AC 代码如下,时间复杂度为 $O(n^2 \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef pair<int, int> PII;

const int N = 5010, mod = 998244353;

int a[N], p[N];
int f[N][N], g[N][N];
int s[N];

int main() {
    int n;
    scanf("%d", &n);
    set<PII> st({{0, 0}});
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
        st.insert({0, i});
    }
    f[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        if (a[i] == -1) {
            for (int j = 1; j <= n; j++) {
                int x = -st.begin()->first, y = st.begin()->second;
                if (st.begin()->second == j) x = -next(st.begin())->first, y = next(st.begin())->second;
                if (x < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
                else f[i][j] = ((LL)s[i - 1] - s[x] - g[i - 1][j] + g[x][j] + f[x][y]) % mod;
            }
        }
        else {
            int j = a[i];
            int x = -st.begin()->first, y = st.begin()->second;
            if (st.begin()->second == j) x = -next(st.begin())->first, y = next(st.begin())->second;
            if (x < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
            else f[i][j] = ((LL)s[i - 1] - s[x] - g[i - 1][j] + g[x][j] + f[x][y]) % mod;
            st.erase({-p[j], j});
            st.insert({-i, j});
            p[j] = i;
        }
        s[i] = s[i - 1];
        for (int j = 1; j <= n; j++) {
            s[i] = (s[i] + f[i][j]) % mod;
            g[i][j] = (g[i - 1][j] + f[i][j]) % mod;
        }
    }
    int ret = 0;
    for (int i = 1; i <= n; i++) {
        ret = (ret + f[n][i]) % mod;
    }
    ret = (ret + mod) % mod;
    printf("%d", ret);
    
    return 0;
}

  其实 $O(n^2 \log{n})$ 的复杂度已经可以过了,实际上还可以优化到 $O(n^2)$,如果有兴趣可以继续往下看。

  上面的 $p_i$ 表示值 $i$ 的最大下标,可以反过来考虑,变成对于值不为 $i$ 的最大下标。那么对于 $f(i,j)$,对应的 $v$ 就直接等于 $p_j$。另外可以发现只有 $a_i \ne -1$ 的情况才需要更新 $p$ 数组,只需暴力枚举 $k$,令 $p_k = i, \, k \ne j$ 即可。

  AC 代码如下,时间复杂度为 $O(n^2)$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef pair<int, int> PII;

const int N = 5010, mod = 998244353;

int a[N];
int f[N][N], g[N][N];
int s[N];
int p[N];

int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    f[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        if (a[i] == -1) {
            for (int j = 1; j <= n; j++) {
                if (p[j] < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
                else f[i][j] = ((LL)s[i - 1] - s[p[j]] - g[i - 1][j] + g[p[j]][j] + f[p[j]][a[p[j]]]) % mod;
            }
        }
        else {
            int j = a[i];
            if (p[j] < i - j) f[i][j] = ((LL)s[i - 1] - s[max(0, i - j - 1)] - g[i - 1][j] + g[max(0, i - j - 1)][j]) % mod;
            else f[i][j] = ((LL)s[i - 1] - s[p[j]] - g[i - 1][j] + g[p[j]][j] + f[p[j]][a[p[j]]]) % mod;
            for (int k = 1; k <= n; k++) {
                if (k != j) p[k] = i;
            }
        }
        s[i] = s[i - 1];
        for (int j = 1; j <= n; j++) {
            s[i] = (s[i] + f[i][j]) % mod;
            g[i][j] = (g[i - 1][j] + f[i][j]) % mod;
        }
    }
    int ret = 0;
    for (int i = 1; i <= n; i++) {
        ret = (ret + f[n][i]) % mod;
    }
    ret = (ret + mod) % mod;
    printf("%d", ret);
    
    return 0;
}

 

参考资料

  Editorial - estie Programming Contest 2023 (AtCoder Regular Contest 169):https://atcoder.jp/contests/arc169/editorial/7911

  AtCoder Regular Contest 169(A~D):https://zhuanlan.zhihu.com/p/671467218