Balance Addicts 题解

发布时间 2023-11-09 09:43:42作者: TKXZ133

Balance Addicts

题目大意

给定序列 \(a\),求有多少种合法的划分方案。

定义一种划分方案是合法的当且仅当划分出的各段序列的和构成回文序列。

思路分析

一种不太一样的做法。

我们先对 \(a\) 做一遍前缀和,得到 \(s\)

观察各段序列的和形式:

\[s_{p_1},s_{p_2} - s_{p_1},s_{p_3}-s_{p_2},...,s_{p_k}-s_{p_{k-1}} \]

其中,\(p_i\) 是第 \(i\) 段的结尾下标,\(k\) 是划分出的段数。

因为这个序列构成回文序列,因此我们有:

\[\begin{cases} s_{p_1}=s_{p_k}-s_{p_{k-1}}\\ s_{p_2}-s_{p_1}=s_{p_{k-1}}-s_{p_{k-2}}\\ ...\\ s_{p_{m}}-s_{p_{m-1}} = s_{p_{k-m+1}} - s_{p_{k-m}} \end{cases}\]

移项,得到以下等式:

\[s_{p_m}+s_{p_{k-m}}=s_{p_{m-1}}+s_{p_{k-m+1}}=...=s_{p_2}+s_{p_{k-2}}=s_{p_1}+s_{p_{k-1}}=s_{p_k}=s_n \]

也就是说,我们选出的序列是回文序列的充要条件是:

\[\boxed{s_{p_i}+s_{p_{k-i}}=s_n} \]

因为 \(a_i\ge0\),所以 \(s\) 单调不减,也就是说,\(s\) 中相同的值均相邻

那么我们对 \(s\) 的每一段值分别考虑。我们可以发现,\(s\) 的每一段值之间互不影响,也就是说,对于 \(s\) 中的一种值我们计算且只计算一次答案。

对于一种值 \(x\),设 \(cnt(x)\) 表示 \(x\)\(s\) 中的出现次数,那么它对答案的贡献就是:

\[\sum_{i=1}^{cnt(x)}{cnt(x)\choose i}{cnt(s_n-x)\choose i} \]

也就是枚举这种值放几个到回文序列中去,用乘法原理和加法原理组合出结果。这里默认 \(cnt(x)\le cnt(s_n-x)\),不满足交换一下就可以了。

然后考虑范德蒙德卷积,也就是:

\[\sum_{i=1}^{cnt(x)}{cnt(x)\choose i}{cnt(s_n-x)\choose i}=\sum_{i=1}^{cnt(x)}{cnt(x)\choose cnt(x)-i}{cnt(s_n-x)\choose i}={cnt(x)+cnt(s_n-x)\choose cnt(x)} \]

(这里其实麻烦了,直接枚举也是可以的,因为 \(\sum cnt(x)=n\)。是我脑抽了。)

因此,我们可以得出答案的最终表达式为:

\[\prod_{x\in V} {cnt(x)+cnt(s_n-x)\choose cnt(x)} \]

其中,\(V\) 表示 \(s\) 的值域集合,注意 \(x\)\(s_n-x\) 只能计算一次。

\(s_n\) 是偶数时,会出现 \(s_{p_i}=s_{p_{k-i}}=\dfrac{1}{2}s_n\) 这样的情况,因此这时需要额外乘上 \(2^{cnt(\frac{1}{2}s_n)}\),也就是等于 \(\dfrac{1}{2}s_n\) 的数可以随便选的方案数。

代码

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <map>

using namespace std;
const int N = 200200, L = 200000, mod = 998244353;
#define inf 0x3f3f3f3f
#define int long long

int n, ans = 1, T;
int a[N], sum[N], fac[N], inv[N];

map <int, int> mp;

int q_pow(int a, int b){
    int res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

int C(int n, int m){
    if (m > n) return 0;
    return fac[n] * inv[n - m] % mod * inv[m] % mod;  
}

signed main(){
    fac[0] = 1;
    for (int i = 1; i <= L; i ++) fac[i] = fac[i - 1] * i % mod;
    inv[L] = q_pow(fac[L], mod - 2);
    for (int i = L; i >= 1; i --) inv[i - 1] = inv[i] * i % mod;
    scanf("%lld", &T);
    while (T --) {
        mp.clear(); ans = 1;
        scanf("%lld", &n);
        for (int i = 1; i <= n; i ++) scanf("%lld", &a[i]);
        for (int i = 1; i <= n; i ++) sum[i] = sum[i - 1] + a[i];
        for (int i = 1; i < n; i ++) mp[sum[i]] ++;
        int pos = 1;
        while (sum[pos] * 2 < sum[n]) {
            if (sum[pos] != sum[pos + 1]) 
                ans = ans * C(mp[sum[pos]] + mp[sum[n] - sum[pos]], mp[sum[pos]]) % mod;
            pos ++;
        }
        if (sum[n] % 2 == 0) ans = ans * q_pow(2, mp[sum[n] / 2]) % mod;
        cout << ans << '\n';
    }
    return 0;
}