「解题报告」ARC124E Pass to Next

发布时间 2023-03-27 10:12:53作者: APJifengc

我果然还是无脑型选手。

首先还是考虑设 \(\{x_i\}\) 表示第 \(i\) 个位置的人往后传了几个球,那么考虑如何将操作序列与最终局面一一对应。发现如果 \(\{x_i\}\) 中的所有数都 \(\ge 1\),那么我们可以直接全部减去一个 \(1\),最终局面不变。所以,我们只需要统计所有 \(\min x_i = 0\) 时的答案即可。这个可以简单容斥求出,即求出 \(x_i \in [0, a_i]\) 时的答案,减去 \(x_i \in [1, a_i]\) 时的答案。下面以前者为例子。

考虑我们要求的答案,实际上等于 \(\displaystyle \prod_{i=1}^n (a_i - x_i + x_{i - 1})\)(当 \(i=1\) 时为 \((a_1 - x_1 + x_n)\))。

直接统计不好统计,而这是一个三项式乘积,我们可以把乘积拆开,每次从三项中选一项乘起来,这样同样也可以得到答案。

考虑每一位怎么选。

  • 假如我们选了 \(a_i\),那么这一位对于 \(x_i \in [0, a_i]\) 中的任意一种情况贡献都为 \(a_i\)
  • 假如我们选了 \(-x_i\),那么对于 \(x_i \in [0, a_i]\) 中的任意一种情况的贡献都为 \(x_i\),等差数列求和;
  • 假如我们选了 \(x_{i - 1}\),那么我们需要知道上一次选的是否为 \(x_i\),如果是,那么两个数的贡献需要一起来算,即求平方和;否则,这一位还是自己算自己的贡献,求等差数列和。同时,需要乘上 \(x_i\) 的方案数。

发现第三种情况很难考虑,我们不妨在选当前一项的时候就先确定下一项是否选 \(x_{i - 1}\),如果选就直接从 \(f_i\) 转移到 \(f_{i + 2}\),否则转移到 \(f_{i + 1}\)

具体的,我们可以设 \(f_i\) 为考虑到第 \(i\) 项,上一项选择了 \(x_i\) 且这一项不选择 \(x_{i-1}\) 的权值和;\(g_i\) 为考虑到第 \(i\) 项,上一项没有选择 \(x_i\),这一项可以选择 \(x_{i-1}\) 的权值和。那么转移直接考虑上述几种情况即可。

但是这个转移是一个环,所以我们需要先断开,钦定好断点处的选择方案,然后再 DP 统计方案。


更简单的做法是考虑组合意义,\(\prod b_i\) 相当于在每种最终局面中,每个人从自己手中的球中选取一个的方案数。发现选取的球要不然是选取自己剩下的,要不然是选取上一个人传下来的,那么我们可以根据这两种情况设计 DP,然后转移即可。本质好像和这种做法差不多,但是我写的转移麻烦的很。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005, P = 998244353;
int n;
int a[MAXN];
int f[MAXN], g[MAXN];
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
const int INV2 = (P + 1) / 2, INV6 = qpow(6, P - 2), INV3 = qpow(3, P - 2);
int solve0() {
    int ans = 0;
    auto dp = [&] {
        for (int i = 1; i <= n; i++) {
            // sel a_i
            g[i + 1] = (g[i + 1]
                + 1ll * (f[i] + g[i]) * (a[i] + 1) % P * a[i]
            ) % P;
            // sel -x_i and no x_i-1
            f[i + 1] = (f[i + 1]
                - 1ll * (f[i] + g[i]) * a[i] % P * (a[i] + 1) % P * INV2 % P + P
            ) % P;
            // sel -x_i and x_i-1
            g[i + 2] = (g[i + 2] 
                - 1ll * (f[i] + g[i]) * a[i] % P * (a[i] + 1) % P * (2 * a[i] + 1) % P * INV6 % P
                * (a[i + 1] + 1) % P + P
            ) % P;
            // sel x_i-1
            g[i + 1] = (g[i + 1]
                + 1ll * g[i] * a[i - 1] % P * INV2 % P * (a[i] + 1) % P
            ) % P;
        }
    };
    // case 1: no x_n
    for (int i = 1; i <= n + 1; i++) 
        f[i] = g[i] = 0;
    f[1] = 1;
    dp();
    ans = (ans + (f[n + 1] + g[n + 1]) % P) % P;
    // case 2: x_n
    for (int i = 1; i <= n + 1; i++) 
        f[i] = g[i] = 0;
    f[1] = g[1] = 0;
    g[2] = a[1] + 1;
    dp();
    ans = (ans + 
        1ll * f[n + 1] * (2 * a[n] + 1) % P * INV3 +
        1ll * g[n + 1] * a[n] % P * INV2
    ) % P;
    return ans;
}
int solve1() {
    int ans = 0;
    auto dp = [&] {
        for (int i = 1; i <= n; i++) {
            // sel a_i
            g[i + 1] = (g[i + 1]
                + 1ll * (f[i] + g[i]) * a[i] % P * a[i]
            ) % P;
            // sel -x_i and no x_i-1
            f[i + 1] = (f[i + 1]
                - 1ll * (f[i] + g[i]) * a[i] % P * (a[i] + 1) % P * INV2 % P + P
            ) % P;
            // sel -x_i and x_i-1
            g[i + 2] = (g[i + 2] 
                - 1ll * (f[i] + g[i]) * a[i] % P * (a[i] + 1) % P * (2 * a[i] + 1) % P * INV6 % P
                * a[i + 1] % P + P
            ) % P;
            // sel x_i-1
            g[i + 1] = (g[i + 1]
                + 1ll * g[i] * (a[i - 1] + 1) % P * INV2 % P * a[i] % P
            ) % P;
        }
    };
    // case 1: no x_n
    for (int i = 1; i <= n + 1; i++) 
        f[i] = g[i] = 0;
    f[1] = 1;
    dp();
    ans = (ans + (f[n + 1] + g[n + 1]) % P) % P;
    // case 2: x_n
    for (int i = 1; i <= n + 1; i++) 
        f[i] = g[i] = 0;
    f[1] = g[1] = 0;
    g[2] = a[1];
    dp();
    ans = (ans + 
        1ll * f[n + 1] * (2 * a[n] + 1) % P * INV3 +
        1ll * g[n + 1] * (a[n] + 1) % P * INV2
    ) % P;
    return ans;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    // printf("%d\n", solve0());
    // printf("%d\n", solve1());
    printf("%d\n", (solve0() - solve1() + P) % P);
    return 0;
}