CF1860D 题解

发布时间 2023-09-05 21:59:33作者: Illumina

在 Codeforces 上看到了这题的 \(\mathcal{O}(n ^ 4 / \omega)\) 做法,和大家分享一下。

原版 Solution 链接

\(d\) 为原字符串 \(s\)\(01\)\(10\) 的个数差。

观察题目可以得到以下性质:

  • 不可能交换两个 \(0\) 或两个 \(1\),不可能在同一个位置上进行两次操作。

  • 当交换位置 \((p, q)\) 上的 \(0\)\(1\) 时,\(d\) 会增加或减少 \((p - q) \times 2\)

现在我们需要选出两个长度都为 \(l\)\(s\) 的序列 \(a\)(选择交换的 \(0\) 的下标)和 \(b\)(选择交换的 \(1\) 的下标),在满足 \(\sum_{i=1}^{k}{b_i - a_i} = d / 2\) 的条件下希望 \(l\) 最小。

考虑对下标 DP,设 \(f_{0 / 1,i , j, k} = 0 / 1\) 表示目前转移到 \(0\)\(1\) 的第 \(i\) 个位置,已经选择的下标之和为 \(j\),选了 \(k\) 个数是否可行。转移为:

f[s][i + 1][j + i][k + 1] |= f[s][i][j][k]
f[s][i + 1][j][k] |= f[s][i][j][k]

最后求出最小的 \(l\) 满足存在 \(x\) 使 \(f_{0, n, x, l} =1\)
\(f_{1, n, x + d, l} = 1\)

这个 DP 是 \(\mathcal{O}(n ^ 4)\) 的,注意到 \(f\) 数组的取值为 \(0\)\(1\),可以使用 bitset 将第三维压掉。设 bitset 数组 \(g_{0/1, i, k}\) 为目前转移到 \(0\)\(1\) 的第 \(i\) 个位置,选了 \(k\) 个数的下标之和的状态。转移为:

g[s][i + 1][k + 1] |= g[s][i][k] << i
g[s][i + 1][k] |= g[s][i][k]

最后利用倒序枚举 \(k\) 滚动第二位得转移方程为:

g[s][k + 1] |= g[s][k] << i

复杂度为 \(\mathcal{O}(n ^ 4/w)\)。附上代码:

#include <iostream>
#include <bitset>
#include <string.h>
#define gc getchar 
#define pc putchar
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
template <class Typ> Typ &read(Typ &x) {
    char ch = gc(), sgn = 0; 
    for(; !isdigit(ch); ch = gc()) sgn |= ch == '-';
    for(x = 0; isdigit(ch); ch=gc()) x = x * 10 + (ch ^ 48);
    return sgn && (x = -x), x;
}
template <class Typ> void write(Typ x) {
    if(x < 0) pc('-'), x = -x;
    if(x > 9) write(x / 10);
    pc(x % 10 ^ 48);
}
const int N = 110;
bitset<N * N / 2> pack[2][N];
char s[N];
int n, diff; 
int main() {
    scanf("%s", s + 1), n = strlen(s + 1);
    for(int i = 1; i <= n; i++)
        for(int j = i + 1; j <= n; j++) {
            if(s[i] == '0' && s[j] == '1') diff++;
            if(s[i] == '1' && s[j] == '0') diff--;
        }
    diff >>= 1;
    pack[0][0][0] = 1, pack[1][0][0] = 1;
    for(int i = 1; i <= n; i++) for(int j = i - 1; j >= 0; j--)
        pack[s[i] - '0'][j + 1] |= pack[s[i] - '0'][j] << i;
    for(int i = 0; i <= n; i++) 
        for(int j = 0; j <= n * (n-1) >> 1; j++) if(j + diff >= 0) 
            if(pack[0][i][j] && pack[1][i][j + diff]) { 
                write(i), pc('\n');
                return 0;
            }
    return 0;
}