Atcoder Beginner Contest 321 G - Electric Circuit 题解 - 状压dp | 指定最低位

发布时间 2023-11-10 01:59:45作者: bringlu

为了更好的阅读体验,请点击这里

题目链接:G - Electric Circuit

看到了 \(N\) 的数据范围,因此是显然的状压 dp。

不妨设 \(f_S\) 为仅使用 \(S\) 集合中的所有点,能够连成恰好 \(1\) 个连通块的方案数。\(g_S\) 为仅使用 \(S\) 集合中的所有点的方案数,其中 \(cntr(S)\)\(S\) 中为 red 的个数,\(cntb(S)\) 为在 \(S\) 中 blue 的个数。

不难发现对于某一集合 \(S\) 而言,只有在 \(cntr(S) = cntb(S)\) 时才能连成恰好 \(1\) 个连通块,对于答案才有贡献。因此最终答案为:

\[ans = \sum_S \frac{f_S \times cntr(\overline{S})!}{m!} \]

且容易观察到 \(g_S = cntr(S)!\)

再想一下 \(f_S\)\(g_S\) 的关系,如何求得 \(f_S\) 呢?枚举 \(S\) 的子集 \(T\),以 \(f_T\) 加权和求得 \(g_S\),即恰好用 \(T\) 这个集合构成 \(1\) 个连通块,而剩下的随意排布,方案数即为排列数。(下式是个错误式子)

\[g_S = \sum_{T \in S} f_T \times cntr(S \setminus T)! \]

上式的问题之处在于,如果 \(T\)\(S \setminus T\) 同时可以构成恰好 \(1\) 个连通块,那么这种方案数就被算了两遍。因此,可以指定最低位的数 \(a\),钦定它在集合 \(T\) 中,再推导一下,有:

\[f_S = g_S - \sum_{T \subset S, a \in T} f_T \times cntr(S \setminus T)! \]

这个题就做完了,最后我们证明一下为什么指定最低位的数 \(a\) 转移能不重不漏,将下列四种情况代入回上面式子有:

  1. \(f_T=0, f_{S\setminus T}=0\)时,无影响
  2. \(f_T \not =0, f_{S\setminus T}=0\)时,无影响,且这种情况不可能出现
  3. \(f_T=0, f_{S\setminus T} \not =0\)时,这种情况不可能出现
  4. \(f_T \not =0, f_{S\setminus T} \not =0\)时,无影响

唯一能影响到答案的情况 3 在当前 \(f_S \not = 0\) 的情况下不可能出现,因此成立。

应用这种指定最低位的数 \(a\) 的方法(泛化一下是任意指定某个数的方法)应当满足如下几个要素:

  1. 求方案数(也许求别的也可以应用)
  2. 对于某个集合 \(S\),将其分割为两个集合 \(T\)\(S\setminus T\) 时,满足都同 \(0\) 或都不同 \(0\),形式化地为以下两个条件中的一个:
    • \(f_T=0且f_{S\setminus T}=0\)
    • \(f_T \not =0且f_{S\setminus T} \not =0\)
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef double db;
typedef long double ld;

#define IL inline
#define fi first
#define se second
#define mk make_pair
#define pb push_back
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(), (x).end()
#define dbg1(x) cout << #x << " = " << x << ", "
#define dbg2(x) cout << #x << " = " << x << endl

template <typename T>
void _debug(const char* format, T t) {
    cerr << format << '=' << t << endl;
}

template <class First, class... Rest>
void _debug(const char* format, First first, Rest... rest) {
    while (*format != ',') cerr << *format++;
    cerr << '=' << first << ',';
    _debug(format + 1, rest...);
}

template <typename T>
ostream& operator<<(ostream& os, const vector<T>& V) {
    os << "[ ";
    for (const auto& vv : V) os << vv << ", ";
    os << ']';
    return os;
}
#ifdef LOCAL
    #define dbg(...) _debug(#__VA_ARGS__, __VA_ARGS__)
#else
    #define dbg(...) 
#endif

template<typename Tp> IL void read(Tp &x) {
    x=0; int f=1; char ch=getchar();
    while(!isdigit(ch)) {if(ch == '-') f=-1; ch=getchar();}
    while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();}
    x *= f;
}
template<typename First, typename... Rest> IL void read(First &first, Rest&... rest) {
    read(first); read(rest...);
}
int buf[42];
template<typename Tp> IL void write(Tp x) {
    int p = 0;
    if(x < 0) { putchar('-'); x=-x;}
    if(x == 0) { putchar('0'); return;}
    while(x) {
        buf[++p] = x % 10;
        x /= 10;
    }
    for(int i=p;i;i--) putchar('0' + buf[i]);
}
template<typename First, typename... Rest> IL void write(const First& first, const Rest&... rest) {
    write(first); putchar(32); write(rest...);
}

#include <atcoder/modint.hpp>
using mint = atcoder::modint998244353;

void solve() {
    int n, m; read(n, m);
    vector<int> cntr(1 << n), cntb(1 << n);
    for (int i = 0; i < m; i++) {
        int r; read(r); r--;
        cntr[1 << r]++;
    }
    for (int i = 0; i < m; i++) {
        int b; read(b); b--;
        cntb[1 << b]++;
    }
    for (int S = 2; S < (1 << n); S++) {
        if (__builtin_popcount(S) < 2) continue;
        for (int i = 0; i < n; i++) if (S >> i & 1) {
            cntr[S] += cntr[1 << i];
            cntb[S] += cntb[1 << i];
        }
    }
    vector<mint> f(1 << n), g(1 << n);
    vector<mint> J(m + 1);
    J[0] = 1;
    for (int i = 1; i <= m; i++) J[i] = J[i-1] * i;
    mint ans = 0;
    for (int S = 1; S < (1 << n); S++) {
        if (cntr[S] != cntb[S]) continue;
        f[S] = g[S] = J[cntr[S]];
        for (int T = (S - 1) & S; T > S - T; T = (T - 1) & S) {
            f[S] -= f[T] * g[S - T];
        }
        ans += f[S] * J[m - cntr[S]];
    }
    ans /= J[m];
    write(ans.val()); putchar(10);
}

int main() {
#ifdef LOCAL
    freopen("test.in", "r", stdin);
    // freopen("test.out", "w", stdout);
#endif
    int T = 1;
    // read(T);
    while(T--) solve();
    return 0;
}