CF960G Bandit Blues

发布时间 2023-08-11 13:39:45作者: Ender_32k

半个月前做的题,这段时间一直在颓所以没写题解,今天突然想起来才准备补上。

考虑枚举最大值 \(n\) 的位置 \(i\),那么排列就被分成 \(2\) 个段 \([1,i-1]\)\([i+1,n]\),而且 \(\forall k\in [i+1,n]\)\(k\) 不可能是前缀最大值;\(\forall k\in [1,i-1]\)\(k\) 不可能是后缀最大值。

于是两个段可以互相独立地进行计数。就是说 \([1,i-1]\) 中有 \(a-1\) 个前缀最大值,\([i+1,n]\) 中有 \(b-1\) 个后缀最大值。

\(f_{i,j}\) 表示 \(i\) 个数的排列前缀最大值有 \(j\) 个的方案数。枚举 \(n\) 的位置 \(i\),从剩下 \(n-1\) 个数中选 \(i-1\) 个放到 \([1,i-1]\),那么答案就是:

\[\text{ans}=\sum\limits_{i=1}^n\dbinom{n-1}{i-1}f_{i-1,a-1}f_{n-i,b-1} \]

考虑 \(f\) 怎么求,我们从大到小往当前排列 \(P\) 中插入最小值 \(k\),考虑 \(n\) 个数的排列有 \(n+1\) 个位置可以插入:

  • \(k\) 插入到 \(P\) 的开头,新增一个前缀最大值,则 \(f_{i,j}\gets f_{i-1,j-1}\)
  • \(k\) 插入到 \(P\) 的其它位置,前缀最大值个数不变,则 \(f_{i,j}\gets (i-1)f_{i-1,j}\)

于是:

\[f_{i,j}=f_{i-1,j-1}+(i-1)f_{i-1,j} \]

看得出来 \(f_{i,j}\) 这东西就是第一类斯特林数 \(\begin{bmatrix}i\\j\end{bmatrix}\)

那么答案就是:

\[\text{ans}=\sum\limits_{i=1}^n\dbinom{n-1}{i-1}\begin{bmatrix}i-1\\a-1\end{bmatrix}\begin{bmatrix}n-i\\b-1\end{bmatrix} \]

至此,不怕麻烦的同学就可以直接计算 \(a-1,b-1\) 两列的斯特林数,然后卷一下就行了。

如果继续化简,有 \(3\) 种方式:

  • 代数推导(天地灭)
  • 生成函数(天地灭灭灭灭灭灭灭灭)
  • 组合意义(非常的新鲜,非常的美味啊!)

所以我们考虑组合意义。

求和里面那坨就相当于从 \(n-1\) 个数中选 \(i-1\) 个生成 \(a-1\) 个圆排列,剩下 \(n-i\) 个数生成 \(b-1\) 个圆排列的方案数。

按照范德蒙德卷积的思路,相当于 \(n-1\) 个数中先生成 \(a+b-2\) 个圆排列,再从这么多圆排列中选出 \(a-1\) 个的方案数。由于选出来后 \(a-1\) 个圆排列和 \(i\)\(a-1\) 个圆排列总大小)是对应的,所以既不会重也不会漏。

于是答案就是 \(\begin{bmatrix}n-1\\a+b-2\end{bmatrix}\dbinom{a+b-2}{a-1}\)

考虑到第一类斯特林数没有实用的通项公式,直接大力把第 \(n-1\) 行算出来就行了。复杂度 \(O(n\log n)\)

typedef vector<int> poly;
const int G = 114514;
const int P = 998244353;
const int N = 6e5 + 600;

int n, a, b, fac[N], ifac[N];
poly bs;

int qpow(int p, int q) {
    int res = 1;
    for (; q; q >>= 1, p = 1ll * p * p % P)
        if (q & 1) res = 1ll * res * p % P;
    return res;
}

const int iG = qpow(G, P - 2);

void init(int lim) {
    fac[0] = 1;
    for (int i = 1; i <= lim; i++)
        fac[i] = 1ll * fac[i - 1] * i % P;
    ifac[lim] = qpow(fac[lim], P - 2);
    for (int i = lim - 1; ~i; i--)
        ifac[i] = 1ll * ifac[i + 1] * (i + 1) % P;
}

void NTT(int *f, int len, int lim, int op) {
    static int tr[N];
    for (int i = 1; i < len; i++)
        tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (lim - 1));
    for (int i = 0; i < len; i++) 
        if (i < tr[i]) swap(f[i], f[tr[i]]);
    for (int o = 2, k = 1; k < len; o <<= 1, k <<= 1) {
        int tg = qpow(~op ? G : iG, (P - 1) / o);
        for (int i = 0; i < len; i += o) {
            for (int j = 0, w = 1; j < k; j++, w = 1ll * w * tg % P) {
                int x = f[i + j], y = 1ll * w * f[i + j + k] % P;
                f[i + j] = (x + y) % P, f[i + j + k] = (x - y + P) % P;
            }
        }
    }
    if (~op) return;
    int iv = qpow(len, P - 2);
    for (int i = 0; i < len; i++)
        f[i] = 1ll * f[i] * iv % P;
}

poly Mul(poly x, poly y) {
    static int tx[N], ty[N];
    int len = 1, lim = 0, sx = x.size(), sy = y.size();
    while (len < (sx + sy)) len <<= 1, lim++;
    memset(tx, 0, sizeof(int) * (len + 10));
    memset(ty, 0, sizeof(int) * (len + 10));
    for (int i = 0; i < sx; i++) tx[i] = x[i];
    for (int i = 0; i < sy; i++) ty[i] = y[i];
    NTT(tx, len, lim, 1), NTT(ty, len, lim, 1);
    for (int i = 0; i < len; i++)
        tx[i] = 1ll * tx[i] * ty[i] % P;
    NTT(tx, len, lim, -1);
    poly res;
    for (int i = 0; i < sx + sy - 1; i++) res.pb(tx[i]);
    return res;
}

poly conq(int len) {
    if (len == 1) return bs;
    if (len & 1) {
        poly tp = conq(len - 1), res;
        res.pb(1ll * tp[0] * (len - 1) % P);
        for (int i = 1; i < len; i++) 
            res.pb((tp[i - 1] + 1ll * tp[i] * (len - 1) % P) % P);
        res.pb(tp[len - 1]);
        return res;
    }
    poly tp = conq(len >> 1), ta, tb, tc;
    int res = 1;
    for (int i = 0; i <= (len >> 1); i++)
        ta.pb(1ll * tp[i] * fac[i] % P), tb.pb(1ll * res * ifac[i] % P), res = 1ll * res * (len >> 1) % P;
    reverse(ta.begin(), ta.end());
    ta = Mul(ta, tb);
    for (int i = 0; i <= (len >> 1); i++)
        tc.pb(1ll * ifac[i] * ta[(len >> 1) - i] % P);
    return Mul(tp, tc);
}

int C(int n, int m) {
    if (n < m || m < 0) return 0;
    return 1ll * fac[n] * ifac[m] % P * ifac[n - m] % P;
}

int main() {
    n = rd(), a = rd(), b= rd(), init(n << 1), bs = poly(2, 1), bs[0] = 0;
    if (n == 1) {
        if (a == 1 && b == 1) puts("1");
        else puts("0");
        return 0;
    }
    if (a + b - 2 > n) return puts("0"), 0;
    poly ans = conq(n - 1);
    wr(1ll * ans[a + b - 2] * C(a + b - 2, a - 1) % P);
	return 0;
}