LY1129 [ 20230308 CQYC省选模拟赛 T1 ] 或

发布时间 2024-01-08 21:57:43作者: cxqghzj

题意

给定 \(n\) 个数,你需要把每个数分成两组,使得:

  • 没有组为空。
  • 两个组内所有数的按位或相等。

Sol

不难发现,当某一位上全是 \(1\) 的数被分在了一个集合内时,方案一定不合法。

考虑容斥计算这个东西。

对于每一位,建一个点。考虑把所有数抽象成边,对于当前 \(a_i\) 如果第 \(j\) 位和第 \(k\) 位都为 \(1\),那么有一条 \(j \to k\) 的边。

注意到当前容斥的状态钦定点的方案为:\(2 ^ {联通块数量} - 2\)

如何计算剩下点有多少个?不难发现剩下点一定是当前状态的补集的子集,直接高维前缀和即可。

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <cmath>
#include <vector>
#define int long long
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
	int p = 0, flg = 1;
	char c = getchar();
	while (c < '0' || c > '9') {
		if (c == '-') flg = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		p = p * 10 + c - '0';
		c = getchar();
	}
	return p * flg;
}
void write(int x) {
	if (x < 0) {
		x = -x;
		putchar('-');
	}
	if (x > 9) {
		write(x / 10);
	}
	putchar(x % 10 + '0');
}
const int N = 2e5 + 5, M = 4e6 + 5, mod = 998244353;

int pow_(int x, int k, int p) {
	int ans = 1;
	while (k) {
		if (k & 1) ans = ans * x % p;
		x = x * x % p;
		k >>= 1;
	}
	return ans;
}

array <int, N> s, lg;

array <array <int, 23>, 23> g;
array <int, M> f;

namespace Uni {

array <int, N> fa;

int find(int x) {
	if (x == fa[x]) return x;
	return fa[x] = find(fa[x]);
}

void merge(int x, int y) {
	int fx = find(x),
		fy = find(y);
	if (fx == fy) return;
	fa[fx] = fy;
}

}

int lowbit(int x) {
	return x & -x;
}

array <array <int, 23>, M> arc;

void Mod(int &x) {
	if (x >= mod) x -= mod;
	if (x < 0) x += mod;
}

signed main() {
	freopen("or.in", "r", stdin);
	freopen("or.out", "w", stdout);
	int n = read(), tp = 0;
	for (int i = 1; i <= n; i++)
		s[i] = read(), f[s[i]]++, tp |= s[i];
	int m = log2(tp) + 1;
	for (int i = 0; i < m; i++)
		for (int j = 0; j < 1 << m; j++)
			if (j & (1 << i)) f[j] += f[j ^ (1 << i)];
	for (int i = 1; i <= n; i++)
		for (int j = 0; j < m; j++)
			for (int k = j + 1; k < m; k++)
				if ((s[i] & (1 << j)) && (s[i] & (1 << k)))
					g[j + 1][k + 1] = 1;
	for (int i = 1; i <= m; i++)
		arc[0][i] = i;
	int ans = 0;
	for (int T = 0; T < (1 << m); T++) {
		if ((T & tp) != T) continue;
		int sum = 0, x = log2(lowbit(T)) + 1, k = 0;
		for (int i = 1; i <= m; i++)
			if (T & (1 << (i - 1))) k++;
		/* _.clear(); */
		for (int i = 1; i <= m; i++)
			if (k) Uni::fa[i] = arc[T ^ lowbit(T)][i];
		for (int i = 1; i <= m; i++)
			if (k && g[x][i] && (T & (1 << (i - 1)))) Uni::merge(x, i);

		for (int i = 1; i <= m; i++) {
			if (T & (1 << (i - 1)) && Uni::fa[i] == i) sum++;
			if (k) arc[T][i] = Uni::fa[i];
		}
		/* puts(""); */
		/* if (k) sum++; */
		ans += pow_(-1, k, mod) * (pow_(2, sum + f[tp ^ T], mod) - 2) % mod, Mod(ans);
		/* write(T), putchar(32); */
		/* write(sum), putchar(32); */
		/* write(f[tp ^ T]), puts(""); */
	}
	write(ans), puts("");
	cerr << clock() << endl;
	return 0;
}