Luogu 6097 【模板】子集卷积

发布时间 2023-07-21 08:35:03作者: Ender_32k

upd 2023/3/16:更改了时间复杂度的错误。


其实是暴力。

因为这是模板题,所以模板的前置知识也要讲。

  • 前置知识:FWT 计算或卷积。

这里只需要掌握快速计算或卷积的方法,所以内容较少。如果向了解更多(比如异或卷积)的话可以去 P4717 看看。

例题:给定长度为 \(2^n\) 的序列 \(a,b\),求 \(c_k=\sum\limits_{i|j=k}a_ib_j\) 序列中每一项的值。我们需要一个 \(O(n2^n)\) 的解法。

考虑根据 \(a,b\) 构造两个序列 \(A,B\),若 \(a\to A,b\to B\) 均为 \(O(n2^n)\) 并且这个过程可逆,令 \(C_i=A_i\times B_i\),若 \(C\) 能还原回 \(c\),那么我们就可以 \(O(n2^n)\) 计算 \(c\)

在这里,我们令 \(A_i=\sum\limits_{j|i=i}a_j\)(也就是 \(j\)\(i\) 的子集)。那么有 \(C_i=\sum\limits_{(j|k)|i=i}a_jb_k\)

那么有:

\[\begin{aligned}A_i\times B_i&=\left(\sum\limits_{j|i=i}a_j\right)\left(\sum\limits_{j|i=i}b_j\right)\\&=\sum\limits_{j|i=i,k|i=i}a_jb_k\\&=\sum\limits_{(j|k)|i=i}a_jb_k\\&=C_i\end{aligned} \]

也就是说我们证明了这个 \(A,B\) 是的确能映射到一个正确的 \(C\) 的。现在考虑如何快速求 \(A_i=\sum\limits_{j|i}a_j\)

显然可以从低到高枚举每个二进制位,然后当前位为 \(0\) 的是右边对应位置为 \(1\) 的子集,从左依次贡献到右即可。

\(C\to c\) 相当于求一个逆过程,右边依次减去左边的贡献即可。

参考了这里的代码。

void fwt(int *s, int op) {
	op = (op + mod) % mod;
	for (int o = 2, k = 1; o <= S + 1; o <<= 1, k <<= 1) 
		for (int i = 0; i <= S; i += o)
			for (int j = 0; j < k; j++)
				(s[i + j + k] += 1ll * s[i + j] * op % mod) %= mod;
}
  • 回到原题

你发现这东西就多加了一个限制 \(i\&j=0\),也就是说 \(i,j\) 无交。

考虑一个充要条件,\(i\&j=0\) 并且 \(i|j=k\) 其实就相当于 \(|i|+|j|=|k|\) 并且 \(i|j=k\)\(|i|\) 表示 \(i\) 集合的大小,即 \(i\)\(1\) 的个数。

所以可以预处理 \(f_{i,j}\) 表示满足 \(|j|=i\)\(a_j\) 的值,\(g_{i,j}\)\(b\) 同理。

那么 \(c_{k}=\sum\limits_{i=0}^n\sum\limits_{j=0}^{2^n-1}\sum\limits_{j|l=k}f_{i,j}g_{k-i,l}\)

\(h_{i}=\sum\limits_{k=0}^nf_{k}*g_{i-k}\)\(*\) 表示进行或卷积,那么 \(c_i=h_{|i|,i}\)

FWT 预处理每个 \(f_k\)\(g_k\) 的子集和,枚举这个 \(i,k\) 求出 \(h\) 的子集和,然后做逆的 FWT 就做完了。复杂度 \(O(n^22^n)\)

const int maxs = (1 << 20) + 100;
const int mod = 1e9 + 9;
int n, S, a[maxs], b[maxs], c[21][maxs], f[21][maxs], g[21][maxs];
int p[maxs];

void fwt(int *s, int op) {
	op = (op + mod) % mod;
	for (int o = 2, k = 1; o <= S + 1; o <<= 1, k <<= 1) 
		for (int i = 0; i <= S; i += o)
			for (int j = 0; j < k; j++)
				(s[i + j + k] += 1ll * s[i + j] * op % mod) %= mod;
}

int main() {
	n = read(), S = (1 << n) - 1;
	for (int i = 1; i <= S; i++) p[i] = p[i >> 1] + (i & 1);
    for (int i = 0; i <= S; i++) f[p[i]][i] = read();
    for (int i = 0; i <= S; i++) g[p[i]][i] = read();
	for (int i = 0; i <= n; i++) fwt(f[i], 1), fwt(g[i], 1);
	for (int i = 0; i <= n; i++)
		for (int j = 0; j <= i; j++)
			for (int k = 0; k <= S; k++)
				(c[i][k] += 1ll * f[j][k] * g[i - j][k] % mod) %= mod;
	for (int i = 0; i <= n; i++) fwt(c[i], -1);
	for (int i = 0; i <= S; i++) write(c[p[i]][i]), pc(' ');
	return 0;
}