LOJ #6040「雅礼集训 2017 Day5」矩阵

发布时间 2023-08-10 22:02:40作者: came11ia

给定 \(01\) 矩阵 \(C\),求有多少个 \(01\) 矩阵的有序对 \((A,B)\) 满足 \(A \times B \equiv C \pmod 2\)

\(n \leq 2 \times 10^3\)


先考虑如果知道了 \(A\) 怎么做。考虑把 \(C\)\(A\) 写成若干行向量的组合 \(c_1 \sim c_n\)\(a_1 \sim a_n\),容易发现 \(c_i\) 是由 \(a_1 \sim a_n\) 线性组合得到的结果,每个 \(a_i\) 的系数由 \(B\) 给出。那么我们可以求出 \(A\) 的极大线性无关组,然后对每个 \(c_i\) 判断是否都能被线性表出,如果是,设 \(A\) 的秩为 \(r_A\),由于每个 \(c_i\) 独立,且恰有 \(n-r_A\) 个自由元,于是答案就是 \(2^{(n-r_A)n}\)

上面对每个 \(A\) 求极大线性无关组再判断能否表出 \(c_1 \sim c_n\) 的过程实在是太麻烦了,考虑能否规避掉这个过程。重要的结论是:对 \(A,C\),同秩矩阵都是等价的。对 \(C\) 的一个简单证明:考虑同秩矩阵 \(C,D\),由于同秩矩阵一定可以通过初等变换得到,并且初等变换恰好对应可逆矩阵,则可设 \(D = PCQ\),那么对于 \((A,B)\) 满足 \(A \times B \equiv C \pmod 2\)\((PA,BQ)\) 就满足 \((PA) \times BQ \equiv P \times (AB) \times Q \equiv PCQ \equiv D \pmod 2\),这是一个双射。从线性变换的角度也容易知道对 \(A\) 同秩矩阵等价。

于是我们可以先高斯消元求出 \(C\) 的秩 \(r_C\),求出秩为 \(r_C\) 的矩阵的答案,再除以这样的矩阵的个数即可。

考虑把问题稍作转化,变成考虑有序对 \((A,C)\)\(C\) 的秩为 \(r_C\)\(A\) 的秩为 \(r_A\),满足 \(r_C \leq r_A\),这样一对的贡献是 \(2^{(n-r_A)n}\),求所有对的贡献和。设 \(f_{i,j}\) 表示 \(i \times n\) 的矩阵,秩为 \(j\) 的方案数,考虑每次新加一个大小为 \(n\) 的向量,并讨论其是否能被之前的东西表出,则有 \(f_{i,j} \gets f_{i-1,j} \times 2^j + f_{i-1,j-1} \times (2^n - 2^{j-1})\)。考虑枚举 \(A\) 的秩 \(r_A\),这样的 \(A\) 的个数为 \(f_{n,r_A}\),考虑如何算 \(C\) 的个数。由于每个 \(c_i\) 都能被 \(A\) 表出,我们可以把 \(C\) 改写成 \(n\) 个大小为 \(r_A\) 的向量,又由于矩阵行向量的秩和列向量的秩相等,于是 \(C\) 的个数为 \(f_{r_A,r_C}\)

于是就可以算答案了。求 \(r_C\) 可以使用 bitset 加速,总时间复杂度 \(\mathcal{O}(n^2 + \frac{n^3}{\omega})\)

code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef vector <LL> vi;
constexpr int N = 5e3 + 5, mod = 1e9 + 7;
int ksm(int a, int b) {
	int ret = 1;
	for (; b; b >>= 1, a = 1LL * a * a % mod) if (b & 1) ret = 1LL * ret * a % mod;
	return ret;
}
int n, r, f[N][N], pw[N];
bitset <N> a[N];
int main() {
	ios :: sync_with_stdio(false);
	cin.tie(nullptr);
	freopen("mat.in", "r", stdin);
	freopen("mat.out", "w", stdout);
	cin >> n;
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= n; j++) {
			int x;
			cin >> x, a[i][j] = x;
		}
	}
	r = 0;
	for (int i = 1; i <= n; i++) {
		int p = r + 1, k = p;
		while (a[k][i] == 0 && k <= n) k++;
		if (k == n + 1) continue;
		if (k > p) swap(a[k], a[p]);
		for (int j = p + 1; j <= n; j++) if (a[j][i]) a[j] ^= a[p];
		r++; 
	}
	pw[0] = 1;
	for (int i = 1; i <= n; i++) pw[i] = 2LL * pw[i - 1] % mod;
	f[0][0] = 1;
	for (int i = 1; i <= n; i++) {
		for (int j = 0; j <= i; j++) {
			f[i][j] = 1LL * f[i - 1][j] * pw[j] % mod;
			if (j >= 1) f[i][j] = (f[i][j] + 1LL * f[i - 1][j - 1] * (pw[n] + mod - pw[j - 1]) % mod) % mod;
		}
	}
	int ans = 0;
	for (int i = r; i <= n; i++) ans = (ans + 1LL * f[n][i] * f[i][r] % mod * ksm(pw[n - i], n) % mod) % mod;
	ans = 1LL * ans * ksm(f[n][r], mod - 2) % mod;
	cout << ans << "\n"; 
	return 0;
}