Luogu P5664 [CSP-S2019] Emiya 家今天的饭

发布时间 2023-05-19 19:57:33作者: lhzawa

发现“每种主要食材至多在 \(\lfloor \frac{k}{2} \rfloor\) 个菜中被使用”有一个性质,在不合法的情况下绝对只有 \(1\) 个主要食材的个数 \(> \lfloor \frac{k}{2} \rfloor\),因为 \(k - \lfloor \frac{k}{2} \rfloor - 1\le \lfloor \frac{k}{2} \rfloor\)
然后就能发现算不合法的情况比算合法情况好多了啊,因为合法情况的需要考虑所有主要食材,但不合法只需要考虑不合法的这一个主要食材算出其他主要食材组成的方案数即可
所以考虑用总方案数 \(-\) 不合法方案数 \(=\) 合法方案数

考虑算出总方案数,因为每种烹饪方法只能用一次,考虑设 \(f_{i}\) 为前 \(i\) 种烹饪方法的方案数,则对于第 \(i\) 种烹饪方法有两种可能:

  1. 不选第 \(i\) 种烹饪方法,则直接 \(f_i = f_{i - 1}\)
  2. 选第 \(i\) 种方法,因为第 \(i\) 种烹饪方法有 \(h_i = \sum\limits_{j = 1}^m a_{i, j}\) 个菜且因为不关心是否合法任意一种都可以,所以 \(f_i = f_{i - 1}\times h_i\)

综合一下,\(f_i = f_{i - 1} + f_{i - 1}\times h_i\),最后 \(f_n\) 即为总方案数,记得减掉空集产生的 \(1\) 个方案

然后考虑不合法方案数,首先枚举不合法的那一个主要食材 \(1\le i\le m\),然后可以把选的菜的主要食材分为两类:“第 \(i\) 种主要食材”和“非第 \(i\) 种主要食材”
那就可以设 \(g_{j, fi, se}\) 为前 \(j\) 种烹饪方法选了 \(fi\) 个“第 \(i\) 种主要食材”的菜和 \(se\) 个“非第 \(i\) 种主要食材”的菜
那么对于第 \(j\) 种烹饪方法有 \(3\) 种可能:

  1. 不选第 \(j\) 种烹饪方法,\(g_{j, fi, se} = g_{j - 1, fi, se}\)
  2. 选第 \(j\) 种烹饪方法且是第 \(i\) 种主要食材,则有 \(a_{j, i}\) 种菜品可以选择,\(g_{j, fi, se} = g_{j, fi - 1, se}\times a_{j, i}\)
  3. 选第 \(j\) 种烹饪方法且但不第 \(i\) 种主要食材,则有 \(h_j - a_{j, i}\) 种菜品可以选择,\(g_{j, fi, se} = g_{j, fi, se - 1}\times (h_j - a_{j, i})\)

综合一下,\(g_{j, fi, se} = g_{j - 1, fi, se} + g_{j, fi - 1, se}\times a_{j, i} + g_{j, fi , se - 1}\times (h_j - a_{j, i})\),注意一下边界
不合法方案数即为 \(\sum\limits_{fi = 1}^{n}\sum\limits_{se = 0}^{fi - 1} g_{n, fi, se}\),减掉这部分贡献即可

时间复杂度 \(\mathcal{O}(n^3m)\),过不了最后一个 \(\text{Sub}\)

// by lhzawa
#include<bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
const int N = 1e2 + 10, M = 2e3 + 10;
long long a[N][M], h[N];
long long dpc[N], dph[N][N][N];
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%lld", &a[i][j]);
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            h[i] = (h[i] + a[i][j]) % mod;
        }
    }
    dpc[0] = 1;
    for (int i = 1; i <= n; i++) {
        dpc[i] = (dpc[i - 1] + dpc[i - 1] * h[i] % mod) % mod; 
    }
    long long ans = (dpc[n] - 1 + mod) % mod;
    for (int i = 1; i <= m; i++) {
        dph[0][0][0] = 1;
        for (int j = 1; j <= n; j++) {
            for (int fi = 0; fi <= j; fi++) {
                for (int se = 0; fi + se <= j; se++) {
                    dph[j][fi][se] = dph[j - 1][fi][se];
                    if (se) {
                        dph[j][fi][se] = (dph[j][fi][se] + dph[j - 1][fi][se - 1] * (h[j] - a[j][i] + mod) % mod) % mod;
                    }
                    if (fi) {
                        dph[j][fi][se] = (dph[j][fi][se] + dph[j - 1][fi - 1][se] * a[j][i] % mod) % mod;
                    }
                }
            }
        }
        for (int fi = 1; fi <= n; fi++) {
            for (int se = 0; se < fi; se++) {
                ans = (ans - dph[n][fi][se] + mod) % mod;
            }
        }
    }
    printf("%lld", ans);
    return 0;
}

发现算法的瓶颈在于转移的 \(n^3\) 很难处理,首先 \(j\) 这一维肯定不能去,于是考虑对 \(fi, se\) 这部分进行优化
发现最后的答案只取决于 \(fi > se\),即 \(fi - se > 0\),那就可以想到用 \(fi - se\) 代替 \(fi, se\)
具体的,设 \(g_{j, c}\) 为前 \(j\) 种烹饪方法中“第 \(i\) 种主要食材”的菜的个数 \(-\) “非第 \(i\) 种主要食材”的菜的个数 \(= c\) 的方案数
那么对于第 \(j\) 种烹饪方法有 \(3\) 种可能:

  1. 不选第 \(j\) 种烹饪方法,\(g_{j, c} = g_{j - 1, c}\)
  2. 选第 \(j\) 种烹饪方法且是第 \(i\) 种主要食材,则差值会 \(+1\)\(g_{j, c} = g_{j, c - 1}\times a_{j, i}\)
  3. 选第 \(j\) 种烹饪方法且但不第 \(i\) 种主要食材,则差值会 \(-1\)\(g_{j, c} = g_{j, c + 1}\times (h_j - a_{j, i})\)

时间复杂度 \(\mathcal{O}(n^2m)\)

// by lhzawa
#include<bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
const int N = 1e2 + 10, M = 2e3 + 10;
long long a[N][M], h[N];
long long dpc[N], dph[N][2 * N];
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%lld", &a[i][j]);
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            h[i] = (h[i] + a[i][j]) % mod;
        }
    }
    dpc[0] = 1;
    for (int i = 1; i <= n; i++) {
        dpc[i] = (dpc[i - 1] + dpc[i - 1] * h[i] % mod) % mod; 
    }
    long long ans = (dpc[n] - 1 + mod) % mod;
    for (int i = 1; i <= m; i++) {
        dph[0][N + 0] = 1;
        for (int j = 1; j <= n; j++) {
            for (int c = N - j; c <= N + j; c++) {
                dph[j][c] = dph[j - 1][c];
                dph[j][c] = (dph[j][c] + dph[j - 1][c + 1] * (h[j] - a[j][i] + mod) % mod) % mod;
                dph[j][c] = (dph[j][c] + dph[j - 1][c - 1] * a[j][i] % mod) % mod;
            }
        }
        for (int c = N + 1; c <= N + n; c++) {
            ans = (ans - dph[n][c] + mod) % mod;
        }
    }
    printf("%lld", ans);
    return 0;
}