AGC046C

发布时间 2023-11-02 21:57:54作者: lalaouye

这是一种 dp 状态不那么抽象的组合数做法。但是很复杂,仅供参考。

经过思考后发现,我们可以将字符串串按零的位置割开并分成若干个子串,设 \(a_i\) 表示第 \(i\) 个子串中 \(1\) 的个数(子串长度),这样就能转化为每一次操作将后面的一个数减 \(1\),前面的一个数加 \(1\),求操作数小于等于 \(k\) 能产生的本质不同的数组个数,我们发现这个可以用 dp 解决。

具体应该怎么做的?容易想到一个不管重复的 dp,也就是记 \(f_{i,j}\) 表示考虑到第 \(i\) 个数,操作 \(j\) 次可以得到的不同操作序列,显然这个肯定是错的。该怎么判重?通过观察发现,每一个 \(a_i\) 要么只加要么只减,否则所需步数还多,最终还有可能重复。我们统计的数即为数组的加减情况(废话),但是这还远远不够,因为我们发现对于当前状态下的两种情况 6 25 3,现在考虑加一个 \(1\),当我们加到第一个数时,可以得到 7 26 3,当我们加到第二个数时,可以得到 6 35 4,很遗憾,又有重复情况出现,考虑继续增添限制。观察这个过程,我们如果在一个方案中,先加完需要增加的第一个,再加需要增加的第二个,然后第一个就不再去管了,于是这个做法就不会被操作顺序等影响了,这样就可以进行一个 dp,记 \(f_{i,j,k}\) 表示考虑到第 \(i\) 位,前面有 \(j\) 个位置可以加,共花了 \(k\) 次操作的方案数,这个 dp 有两个转移:

\[f_{i,j,k}\longleftarrow f_{i-1,j-1,k} \]

\[f_{i,j,k}=\sum_{l=0}^{j-1}\sum_{o=1}^{\min(k,a_i)} f_{i-1,j,k-o}\times\binom{o+l-1}{l} \]

我们将只能被加的位叫做接收位,只能减的位叫做输出位。

第一个转移表示让当前这一位作为接收位。

第二个转移就有点复杂了,首先这个是这一位作为输出位的转移,\(l\) 表示这一位加完了前面的多少接收位(意思就是说,有 \(l\) 位是只有我能加的,其它的输出位只能加这 \(l\) 位以后的了),显然 \(l\) 是可以为 \(0\) 的,因为这就表示这个输出位全部就加到当前的接收位,并且下一个输出位也可以加到这里。

\(o\) 则表示一共输出了 \(o\),后面的组合数则表示给前面的 \(l+1\) 的接收位的分配方案数,用插板法计算,至于上面为啥要减个 \(1\),是因为如果不保证最后一位一定会被分配就又会有重复的情况,所以先钦定一个输出量给最后一个接收位。

因为 \(o\) 最多就会枚举 \(n\) 次,所以这个 dp 的时间复杂度为 \(\mathcal{O}(n^4)\),但是估计出题人没想到会有这么愚蠢的方法出现,于是没有故意卡,最慢的点也只有四百多毫秒,而且它跑不满,所以直接过了

代码:

#include <bits/stdc++.h>
#define int long long
#define rep(i, l, r) for (int i = l; i <= r; ++ i)
#define rrp(i, l, r) for (int i = r; i >= l; -- i)
#define pii pair <int, int>
#define eb emplace_back
#define inf 1000000000
#define linf 1000000000000000000
using namespace std;
typedef long long ll;
constexpr int N = 305, P = 998244353;
inline int rd ()
{
    int x = 0, f = 1;
    char ch = getchar ();
    while (! isdigit (ch))
    {
        if (ch == '-') f = -1;
        ch = getchar ();
    }
    while (isdigit (ch))
    {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar ();
    }
    return x * f;
}
void add (int &x, int y)
{
	(x += y) >= P && (x -= P);
}
char s[N];
int m;
int f[2][N][N];
int a[N];
int C[N << 1][N << 1];
signed main ()
{
    // freopen ("1.in", "r", stdin);
    // freopen ("1.out", "w", stdout);
	C[0][0] = 1;
	rep (i, 1, 600) rep (j, 0, i) C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % P;
    scanf ("%s", s + 1); m = rd ();
	int n = strlen (s + 1); s[++ n] = '0';
	m = min (m, n); int cnt = 0, tot = 0, sum = 0;
	rep (i, 1, n) if (s[i] == '0') a[++ tot] = cnt, cnt = 0; else ++ cnt, ++ sum;
	int p = 0; f[0][0][0] = 1;
	rep (i, 1, tot)
	{
		int v = ! p; rep (j, 0, i) rep (k, 0, m)
			{ if (j) f[v][j][k] = f[p][j - 1][k];
				rep (l, 0, j - 1)
					rep (o, 1, min (k, a[i]))
						add (f[v][j - l][k], f[p][j][k - o] * C[o + l - 1][l] % P);
			} memset (f[p], 0, sizeof f[p]);
	}
	int ret = 0; rep (i, 0, tot) rep (j, 0, m)
	add (ret, f[p][i][j]); printf ("%lld\n", ret);
}