P6665 Forget You

发布时间 2023-08-07 23:39:57作者: zhouyuhang

补完番后来做一下这道题。

首先考虑 \(n=1\) 怎么做。一个很直观的感觉是,如果将一组集合进行首尾配对,即 \((1,a_i),(2,a_i-1),\cdots\),那么每一对中的两个数地位均等(即在所有方案中的出现次数均等)。证明可以考虑将所有方案进行配对,\((p_1,p_2,\cdots,p_l)\) 对应 \((a_i-p_l+1,\cdots,a_i-p_2+1,a_i-p_1+1)\) 即可。于是我们只需对每个 \(j\) 统计长度为 \(j\) 的序列个数,将其乘上 \(\frac{j(1+a_1)}{2}\) 并求和即为 \(n=1\) 的答案。

现在来看看怎么统计一个集合形成的长度为 \(j\) 序列的个数。用插板法很容易得到答案为 \(\tbinom{a_1+j-1}{j}\)。于是我们解决了 \(n=1\) 时的问题。怎么将其扩展至更大的 \(n\) 呢?

我们先来考虑求总的方案数。套路地,写出每个集合的 \(\operatorname{EGF}\)(设其为 \(F_i(x)=\sum_{j=0}^{b_i}\frac{\tbinom{a_i+j-1}{j}}{j!}x^j\)),将它们卷起来得到 \(M(x)=\prod F_i(x)\),并求出 \(\sum_i[i!x^i]M(x)\) 即为答案。回到原问题,同样套路地拆出求和的贡献,答案就是将每个集合的和乘上它与剩下集合的方案数并求和。对于这样的形式,我们可以进一步设出 \(G_i(x)=\sum_{j=0}^{b_i}\left([\frac{j(2s+a_i+1)}{2}x^j]F_i(x)\right)x^i\),其中 \(s=\sum_{j=1}^{i-1}a_j\)。显然 \(G_i(x)\) 是第 \(i\) 个集合的和的 \(\operatorname{EGF}\)。那么我们只需修改 \(M(x)\) 的定义为 \(M(x)=\sum_{i=1}^nG_i(x)\prod_{j\neq i}F_j(x)\),同样有 \(\sum_i[i!x^i]M(x)\) 即为答案。而维护 \(M(x)\) 则是简单的分治 NTT,在此不多赘述。最终复杂度为 \(\Theta(k\log^2 k)\)

代码:

const int N = 1e5 + 10;

int n, maxa = 0, maxb = 0;
int a[N], b[N];

vint fac, ifac;
void init(int lim) {
	fac = ifac = vint(lim + 1, 1);
	for (int i = 1; i <= lim; ++i) fac[i] = mul(fac[i - 1], i);
	ifac[lim] = Pow(fac[lim], P - 2);
	for (int i = lim; i >= 1; --i) ifac[i - 1] = mul(ifac[i], i);
}
int c(int n, int m) { return (n < m || m < 0) ? 0 : mul(fac[n], mul(ifac[m], ifac[n - m]));}

using Node = array<Poly, 2>;
queue<Node> q;

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	
	prework();
	
	cin >> n;
	for (int i = 1; i <= n; ++i) cin >> a[i] >> b[i], maxa = max(maxa, a[i]), maxb = max(maxb, b[i]);
	init(maxa + maxb);

	for (int i = 1, s = 0; i <= n; s = add(s, a[i]), ++i) {
		Poly x(b[i] + 1), y(b[i] + 1);
		int t = add(s, mul(a[i] + 1, (P + 1) / 2));
		for (int j = 0; j <= b[i]; ++j) x[j] = mul(c(j + a[i] - 1, j), ifac[j]), y[j] = mul(t, mul(j, x[j]));
		q.push({x, y});
	}
	
	while (q.size() > 1) {
		Node u = q.front(); q.pop();
		Node v = q.front(); q.pop();
		q.push({u[0] * v[0], u[0] * v[1] + u[1] * v[0]});
	}
	
	Poly x = q.front()[1];
	int ans = 0;
	for (int i = 0, t = 1; i < x.size(); ++i, t = mul(t, i)) ans = add(ans, mul(x[i], t));
	
	cout << ans << endl;

	return 0;
}

隐去了多项式模板。