[ABC313F] Flip Machines 题解

发布时间 2023-09-15 15:35:18作者: User-Unauthorized

题意

\(N\) 张卡片,第 \(i\) 张卡片正面印着一个数 \(A_i\),反面印着一个数 \(B_i\)。一开始所有数正面朝上。

\(M\) 种操作,第 \(i\) 种操作表示为:

  • \(50\%\) 的概率将卡片 \(X_i\) 翻转,否则将 \(Y_i\) 翻转。

求一个集合 \(S\subseteq \mathbb{N} \bigcap \left[1,M\right]\),使得进行了集合中所有的编号的操作之后正面朝上的所有数的和的期望最大。输出这个最大值。

\(1 \le N \le 40, 1 \le M \le 10^5\))。

题解

首先可以发现若对于卡片 \(i\),若有操作 \(k\) 满足 \(X_k = i\)\(Y_k = i\),那么其最终状态正反面概率相等,无论之前的状态如何。所以我们可以发现每个卡片只有被选择和不被选择两种可能,而每次操作为选择两个卡片,故我们可以尝试转化为图论问题,将操作视为边,将卡片视为节点。

现在我们的问题转化为了:给定一个无向图,每个点有点权,选择若干条边,使得覆盖的点权和最大。但是发现这样还不够,我们继续挖掘题目性质。可以发现

  • 若有一条边连接了两个非负权的点,那么这条边一定入选;
  • 如果一条边连接了两个负权的点,那么这条边一定不入选。

所以我们要处理的边还剩下两种情况:

  • 连接了一个非负权和一个负权的点的边;
  • 自环。

后者是好处理的,发现若一个节点被选择,那么其贡献期望为 \(\dfrac{A_i + B_i}{2}\),若其不被选择,那么其贡献为 \(A_i\),而形如自环的边可以让我们以百分百的概率翻转一条边,可以发现,若 \(A_i < B_i\),那么翻转该卡片一定不劣,否则翻转一定不优。

现在我们只需要处理连接了一个非负权和一个负权的点的边,发现其形如二分图的形式,而两个部分的节点数之和一定不超过 \(N\),设 \(P\) 表示非负权点数量,\(Q\) 表示负权点数量,那么我们可以得出 \(\min\left\{P, Q\right\} \le \dfrac{N}{2} \le 20\),这启发我们去枚举较小的点集来解决问题。

\(P \ge Q\)

我们考虑枚举负权点数量,由于在删点和删边操作后,图中只剩下了连接负权点和非负权点的边,也就是说与负权点相连的一定为非负权点,所以若有一个负权点被覆盖,覆盖所有与它相连的非负权点一定更优。得到负权点的策略后,通过枚举每个负权点是否被覆盖,我们可以得到最优解。

\(P < Q\)

考虑状压,设 \(f_S\) 代表覆盖点集 \(S\) 的最优解,既然我们得到了负权点的策略,那么我们就可以快速的进行转移了。

两种算法的复杂度均为 \(\mathcal{O}(2^{\frac{N}{2}} + N + M)\),可以通过本题。

Code

#include <bits/stdc++.h>

typedef long long valueType;
typedef long double realType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::pair<valueType, valueType> ValuePair;
typedef std::vector<ValuePair> PairVector;
typedef std::vector<bool> bitset;

ValueVector count;

valueType dfs(valueType x, ValueVector const &C, ValueVector const &set, ValueMatrix const &G) {
    if (x == set.size())
        return 0;

    valueType ans = dfs(x + 1, C, set, G);

    valueType sum = C[set[x]];

    for (auto const &iter: G[set[x]]) {
        if (count[iter] == 0)
            sum += C[iter];

        ++count[iter];
    }

    sum += dfs(x + 1, C, set, G);

    for (auto const &iter: G[set[x]])
        --count[iter];

    return std::max(ans, sum);
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    valueType N, M;

    std::cin >> N >> M;

    PairVector edge;

    ValueVector A(N), B(N);

    for (valueType i = 0; i < N; ++i)
        std::cin >> A[i] >> B[i];

    for (valueType i = 0; i < M; ++i) {
        valueType X, Y;

        std::cin >> X >> Y;

        --X;
        --Y;

        if (X == Y) {
            if (A[X] < B[X])
                std::swap(A[X], B[X]);

            continue;
        }

        edge.emplace_back(X, Y);
    }

    valueType const base = std::accumulate(A.begin(), A.end(), (valueType) 0);

    ValueVector C(N);

    for (valueType i = 0; i < N; ++i)
        C[i] = B[i] - A[i];

    bitset exist(N, true);
    ValueMatrix G(N);

    for (auto const &iter: edge) {
        if (C[iter.first] < 0 && C[iter.second] < 0)
            continue;

        if (C[iter.first] >= 0 && C[iter.second] >= 0) {
            exist[iter.first] = false;
            exist[iter.second] = false;

            continue;
        }

        if (!exist[iter.first] || !exist[iter.second])
            continue;

        if (C[iter.first] < 0)
            G[iter.first].push_back(iter.second);
        else
            G[iter.second].push_back(iter.first);
    }

    valueType ans = 0;
    valueType leftCount = 0, rightCount = 0;
    ValueVector rightSet, id(N), node(N);

    for (valueType i = 0; i < N; ++i) {
        if (!exist[i]) {
            ans += C[i];

            continue;
        }

        if (C[i] >= 0) {
            id[i] = leftCount;
            node[leftCount] = i;

            ++leftCount;
        } else {
            for (auto iter = G[i].begin(); iter != G[i].end();)
                if (!exist[*iter])
                    iter = G[i].erase(iter);
                else
                    ++iter;

            if (G[i].empty())
                continue;

            ++rightCount;

            rightSet.push_back(i);
        }
    }

    if (leftCount <= rightCount) {
        valueType const S = 1 << leftCount;

        ValueVector V(S, 0);

        for (valueType j = 0; j < S; ++j)
            for (valueType i = 0; i < leftCount; ++i)
                if (j & (1 << i))
                    V[j] += C[node[i]];

        ValueVector F(S, std::numeric_limits<valueType>::min() >> 1);

        F[0] = 0;

        for (auto const i: rightSet) {
            valueType bit = 0;

            for (auto const &iter: G[i])
                bit |= 1 << id[iter];

            for (valueType j = 0; j < S; ++j)
                F[j | bit] = std::max(F[j | bit], F[j] + C[i]);
        }

        valueType max = std::numeric_limits<valueType>::min();

        for (valueType i = 0; i < S; ++i)
            max = std::max(max, V[i] + F[i]);

        std::cout << std::fixed << std::setprecision(10) << (realType) (2 * base + ans + max) / 2.0 << std::endl;
    } else {
        count.resize(N, 0);

        std::cout << std::fixed << std::setprecision(10) << (realType) (2 * base + ans + dfs(0, C, rightSet, G)) / 2.0 << std::endl;
    }
}