CF1264D Beautiful Bracket Sequence

发布时间 2023-07-20 18:24:45作者: Ender_32k

这里是加强版,\(n\le 10^6\)

考虑到最后删剩下括号序列形如 (((...(()))...)),想到枚举分界点。

\(p\) 为当前枚举的分界点,\(l\)\([1,p]\)( 的个数,\(r\)\([p+1,n]\)) 的个数,\(x\)\([1,p]\)? 的个数,\(y\)\([p+1,n]\)? 的个数。

那么 \(p\) 这个位置作为分界线的贡献为 \(\sum\limits_{i=0}^x(l+i)\dbinom{x}{i}\dbinom{y}{l+i-r}\),即枚举左边的 ? 有多少个是填 (,然后和右边匹配。

注意到这里的组合数有组合意义,那么当 \(l+i-r>y\) 或者 \(<0\) 时,\(\dbinom{y}{l+i-r}\) 应为 \(0\),因为右边匹配不完左边。

考虑化简:

\[\begin{aligned}&\sum\limits_{i=0}^x(l+i)\dbinom{x}{i}\dbinom{y}{l+i-r}\\=&l\sum\limits_{i=0}^{x}\dbinom{x}{i}\dbinom{y}{l+i-r}+\sum\limits_{i=0}^xi\dbinom{x}{i}\dbinom{y}{l+i-r}\end{aligned} \]

利用吸收恒等式和范德蒙德卷积不难推出:

\[\begin{aligned}&l\sum\limits_{i=0}^{x}\dbinom{x}{i}\dbinom{y}{l+i-r}+\sum\limits_{i=0}^xi\dbinom{x}{i}\dbinom{y}{l+i-r}\\=&l\sum\limits_{i=0}^x\dbinom{x}{i}\dbinom{y}{y+r-l-i}+x\sum\limits_{i=0}^{x}\dbinom{x-1}{i-1}\dbinom{y}{y+r-l-i}\\=&l\dbinom{x+y}{y+r-l}+x\dbinom{x+y-1}{y+r-l-1}\end{aligned} \]

\(O(n)\) 枚举 \(p\),计算贡献即可。

#include <bits/stdc++.h>
using namespace std;

namespace vbzIO {
	char ibuf[(1 << 20) + 1], *iS, *iT;
//	#if ONLINE_JUDGE
//	#define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
//	#else
	#define gh() getchar()
//	#endif
	#define rd read
	#define wr write
	#define pc putchar
	#define pi pair<int, int>
	#define mp make_pair
	#define fi first
	#define se second
	#define pb push_back
	#define ins insert
	#define era erase
	inline int read () {
		char ch = gh();
		int x = 0;
		bool t = 0;
		while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
		while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
		return t ? ~(x - 1) : x;
	}
	inline void write(int x) {
		if (x < 0) {
			x = ~(x - 1);
			putchar('-');
		}
		if (x > 9)
			write(x / 10);
		putchar(x % 10 + '0');
	}
}
using vbzIO::read;
using vbzIO::write;

const int N = 1e6 + 100;
const int mod = 998244353;
int n, pr[N][3], fac[N], ifac[N], inv[N];
char s[N];

void init(int n) {
	fac[0] = ifac[0] = inv[1] = 1;
	for (int i = 1; i <= n; i++) {
		if (i > 1) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
		fac[i] = 1ll * fac[i - 1] * i % mod, ifac[i] = 1ll * ifac[i - 1] * inv[i] % mod;
	}
}

int C(int n, int m) {
	if (m < 0 || m > n) return 0;
	return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

int main() {
	scanf("%s", s + 1), n = strlen(s + 1), init(n);
	for (int i = 1; i <= n; i++) {
		pr[i][0] = pr[i - 1][0] + (s[i] == '?');
		pr[i][1] = pr[i - 1][1] + (s[i] == '(');
		pr[i][2] = pr[i - 1][2] + (s[i] == ')');
	}
	int res = 0;
	for (int i = 1; i <= n - 1; i++) {
		(res += 1ll * pr[i][1] * C(pr[n][0], pr[n][0] - pr[i][0] + pr[n][2] - pr[i][2] - pr[i][1]) % mod) %= mod;
		(res += 1ll * pr[i][0] * C(pr[n][0] - 1, pr[n][0] - pr[i][0] + pr[n][2] - pr[i][2] - pr[i][1] - 1) % mod) %= mod;
	}
	wr(res);
	return 0;
}