CF1767C Count Binary Strings 题解

发布时间 2023-09-20 10:11:49作者: cn_ryh

CF1767C Count Binary Strings 题解

Foreword

感谢 @樱雪喵@swiftc 两位大佬的耐心指导。

洛谷

Codeforces

Description

有一个长度为 \(n\) 的 01 串 \(s\)(下标从 \(1\) 开始)和一些限制 \(a_{i,j}(1 \le i \le j \le n)\)

\(a_{i,j}\) 的含义如下:

\(a_{i,j}=\) 含义
\(0\) 没有限制
\(1\) 对于所有的 \(i \le p \le q \le j\) 均有 \(s_p=s_q\)
\(2\) 存在 \(i \le p \le q \le j\) 使得 \(s_p \neq s_q\)

求可能的 \(s\) 的个数。答案对 \(998244353\) 取模。

对于 \(100\%\) 的数据,\(2 \le n \le 100\)\(0 \le a_{i,j} \le 2\)

Solution

这种计数类问题大概率是组合数学或者 dp,然而这道题限制条件很多,组合数学大概是做不成了,那么考虑 dp。

首先一个很显然的情况是如果 \(a_{i,i} = 2\) 那么无解。接下来考虑对 \(a_{i,j} = 1\) 的限制,我们把这些需要区间内全相同的合并成一个块,对块内每个位置的限制等价于对整个块的限制,使用并查集维护即可,注意每个块合并的时候都合并到最前面的位置。

int fa[222];
void init()
{
    for (int i = 1; i <= n; i++)
        fa[i] = i;
}
int find(int u)
{
    if (fa[u] == u)
        return fa[u];
    else
        return fa[u] = find(fa[u]);
}
void merge(int u, int v)
{
    if (find(u) == find(v))
        return;
    if (find(u) < find(v))
        fa[find(v)] = find(u);
    else
        fa[find(u)] = find(v);
}

主函数中:

init();
for (int i = 1; i <= n; i++)
{
    for (int j = i; j <= n; j++)
    {
        read(nums[i][j]);
        if (nums[i][j] == 1)
        {
            for (int k = i + 1; k <= j; k++)
            {
                merge(i, k);
            }
        }
    }
}

之后,我们很容易想到,对于每个块 \(p\),我们找到对于 \(i \in p\)\(a_{i,j} = 2\) 这样的限制中最大的 \(j\),只要满足这条限制,在前面的限制也就都满足了,求出 \(mx_{i}\) 表示块 \(i\) 之前最后一个能满足从 \(mx_{i}\)\(i\)\(i\) 填的数全相同的块。

DP 的思路有两种,一种是一维的,另一种是二维的。我开始写的一维,然而没有考虑到一些问题写挂了,尝试了二维通过之后又回到了一维。

这里先从二维 DP 开始讲。

考虑某个位置和前一位是否相同,有:

\[\begin{cases} dp_{i,j} = dp_{i - 1,j} & \texttt{if } s_{i} = s_{i - 1} \\ dp_{i,i} = \sum_{j = mx_{i}}^{i - 1} dp_{i - 1,j} & \texttt{if } s_{i} \neq s_{i - 1} \end{cases}\]

即如果要求和前一位相同,不会有新的贡献,否则累加贡献。

直接判断是否可行并转移即可。

dp[1][1] = 1;
for (int i = 2; i <= n; i++)
{
    if (find(i) == i)
    {
        for (int j = 1; j < i; j++)
        {
            (dp[i][i] += dp[i - 1][j]) %= 998244353;
        }
    }
    for (int j = mx[find(i)]; j < i; j++)
    {
        (dp[i][j] += dp[i - 1][j]) %= 998244353;
    }
}

int res = 0;
for (int i = mmx; i <= n; i++)
{
    (res += dp[n][i]) %= 998244353;
}
writeln(2 * res % 998244353);

接下来考虑一维 DP 怎么做。

\[\begin{cases} dp_{i,j} = dp_{i - 1,j} & \texttt{if } s_{i} = s_{i - 1} \\ dp_{i,i} = \sum_{j = mx_{i}}^{i - 1} dp_{i - 1,j} & \texttt{if } s_{i} \neq s_{i - 1} \end{cases}\]

上面的方程实际上就是 相同的地方 复制了 前面第一个不同的地方 /kk

定义 \(k\)\(i\) 之前第一个不同的。有

\[dp_{i,i} = \sum_{k}\sum_{j = mx_{i}}^{i - 1} dp_{k,j} \ \texttt{ if } s_{i} \neq s_{i - 1} \]

也就是我们每次从 \(k\) 转移并且累加一下。

注意到由于 \(s_{k} \neq s_{k - 1}\),我们之前计算 \(dp_{k,j'}\) 的时候应该只更新了 \(dp_{k,k}\)

因此实际上我们得到的是 \(dp_{i,i} = \sum_{k}dp_{k,k} \ \texttt{ if } mx_{i} \leq k \leq i - 1\)

优化掉第二维,有 $dp_{i} = \sum_{k = mx_{i}}^{i - 1}dp_{k} $。

注意:这里有个问题,假设有按顺序 \(A,B,C,D\) 四个块,如果限制 \(B,C\) 不能相同,那么显然我们无法从 \(A\)\(D\) 全部相同,因此我们 \(mx\) 还要取一个前缀 \(\operatorname{max}\)

由于我们合并块的编号并不连续,这不利于我们 dp,因此将块的编号离散化一下即可。

// 由于 fa 更新的时候不是 1,2,3,4 这样,而是每块第一个的编号
// 我们把 fa[i] 离散化,pos 表示是第几个,rea 表示第 i 个的实际 fa
for (int i = 1; i <= n; i++)
{
    if (find(i) == i)
    {
        rea[++rea[0]] = i;
        pos[i] = rea[0];
    }
}

之后枚举上一个不同的点转移就可以了,方程:

for (int i = 1; i <= n; i++)
{
    if (find(i) != i)
    {
        continue;
    }
    ++cnt;
    // 现在只有无限制和要求出现不同了
    // 我们可以枚举上一个不同的位置

    for (int j = pos[mx[rea[cnt]]]; j < cnt; j++)
    {
        (dp[cnt] += dp[j]) %= 998244353;
    }
}

把所有 \(0\)\(1\) 交换不会违反限制,因此答案要乘 \(2\)

Codes

一维完整代码。

// Problem: C. Count Binary Strings
// Contest: Educational Codeforces Round 140 (Rated for Div. 2)
// URL: https://codeforces.com/contest/1767/problem/C
// Memory Limit: 512 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
using namespace std;
#define int long long
void read(int &p)
{
    p = 0;
    int k = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
        {
            k = -1;
        }
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        p = p * 10 + c - '0';
        c = getchar();
    }
    p *= k;
    return;
}
void write_(int x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x > 9)
    {
        write_(x / 10);
    }
    putchar(x % 10 + '0');
}
void writesp(int x)
{
    write_(x);
    putchar(' ');
}
void writeln(int x)
{
    write_(x);
    puts("");
}
int n, nums[200][200];
int fa[222];
void init()
{
    for (int i = 1; i <= n; i++)
    {
        fa[i] = i;
    }
}
int find(int u)
{
    if (fa[u] == u)
    {
        return fa[u];
    }
    else
    {
        return fa[u] = find(fa[u]);
    }
}
void merge(int u, int v)
{
    if (find(u) == find(v))
    {
        return;
    }
    if (find(u) < find(v))
    {
        fa[find(v)] = find(u);
    }
    else
    {
        fa[find(u)] = find(v);
    }
}
int dp[222];
int mx[222];
int rea[222];
int pos[222];
signed main()
{
    read(n);
    init();
    for (int i = 1; i <= n; i++)
    {
        for (int j = i; j <= n; j++)
        {
            read(nums[i][j]);
            if (nums[i][j] == 1)
            {
                for (int k = i + 1; k <= j; k++)
                {
                    merge(i, k);
                }
            }
        }
    }
    for (int i = 1; i <= n; i++)
    {
        if (nums[i][i] == 2)
        {
            puts("0");
            return 0;
        }
    }
    // 对于每一个需要找到最晚需要不同的,这样能满足前面的所有条件
    for (int i = 1; i <= n; i++)
    {
        mx[i] = 0;
    }
    for (int i = 2; i <= n; i++)
    {
        for (int j = i - 1; j; j--)
        {
            if (nums[j][i] == 2)
            {
                if (find(i) == find(j))
                {
                    puts("0");
                    exit(0);
                }
                mx[find(i)] = max(mx[find(i)], find(j));
                break;
            }
        }
    }
    for (int i = 1; i <= n; i++)
    {
        mx[i] = max(mx[i], mx[i - 1]);
    }

    int cnt = 0;
    dp[0] = 1;
    for (int i = 1; i <= n; i++)
    {
        if (find(i) == i)
        {
            rea[++rea[0]] = i;
            pos[i] = rea[0];
        }
    }

    for (int i = 1; i <= n; i++)
    {
        if (find(i) != i)
        {
            continue;
        }
        ++cnt;
        // 现在只有无限制和要求出现不同了
        // 我们可以枚举上一个不同的位置
        // 由于 fa 更新的时候不是 1,2,3,4 这样,而是每块第一个的编号
        // 我们把 fa[i] 离散化,pos 表示是第几个,rea 表示第 i 个的实际 fa
        for (int j = pos[mx[rea[cnt]]]; j < cnt; j++)
        {
            (dp[cnt] += dp[j]) %= 998244353;
        }
    }
    //   cout << cnt << endl;
    writeln(2 * dp[cnt] % 998244353);
    return 0;
}