好用的bitset——暴力的好帮手

发布时间 2023-06-02 15:07:46作者: jiangchenyangsong

会持续更新的。

bitset

C++ 的 bitset 在 bitset 头文件中,它是一种类似数组的结构,它的每一个元素只能是0或1,每个元素仅用 1 bit空间。
bitset的原理大概是将很多数压成一个,从而节省空间和时间, 一般来说bitset会让你的算法复杂度 /32

构造 bitset

一般有下五种:

#include<bits/stdc++.h>
#include<bitset> // 头文件 
using namespace std;
int main(){
	bitset<4> bit1; //无参构造,长度为4,默认每一位为0

	bitset<8> bit2(12); //长度为8,二进制保存,前面用0补充

	string s1 = "1010100";
	bitset<16> bit3(s1); //长度为 16,前面用0补充 

	bitset<10> bit4(string("1010100")); //与上一个相同,长度为 10 

	char s2[] = "101000"; 
	bitset<25> bit5(s2); //长度为 13,前面用0补充,好像要 C++14 以上才能用 
	
	cout << bit1 << endl; //0000
	cout << bit2 << endl; //00001100
	cout << bit3 << endl; //0000000001010100
	cout << bit4 << endl; //0001010100
	cout << bit5 << endl; //0000000000000000000101000 
	return 0;
} 

用字符串构造时,字符串只能包含 '0' 或 '1' ,否则会抛出异常。

构造时,需在<>中表明bitset 的大小(即size)。

在进行有参构造时,若参数的二进制表示比bitset的size小,则在前面用0补充(如上面的栗子);若比bitsize大,参数为整数时取后面部分,参数为字符串时取前面部分:

#include<bits/stdc++.h>
#include<bitset> // 头文件 
using namespace std;
int main(){
	bitset<2> bit1(12); // 12的二进制为1100(长度为4),但bit1的size=2,只取后面部分,即00
	
	string s1 = "100101";
	bitset<4> bit2(s1); // s的size=6,而bitset的size=4,只取前面部分,即1001
    
    char s2[] = "11101";
    bitset<4> bit3(s2); //与bit2同理,只取前面部分,即1110
    
    cout << bit1 << endl;    //00
    cout << bit2 << endl;    //1001
    cout << bit3 << endl;    //1110
	return 0;
} 

可用的操作符

bitset的运算就像一个普通的整数一样,可以进行与(&)、或(|)、异或(^)、左移(<<)、右移(>>)等操作。

#include<bits/stdc++.h>
#include<bitset> // 头文件 
using namespace std;
int main(){
	bitset<4> bit1(string("1001"));
	bitset<4> bit2(string("0011"));
	
	//注意这里赋值操作后值会改变 
	cout << (bit1 ^= bit2) << endl; // 1010 (1001 ^ 0011 bit1对bit2按位异或后赋值给bit1)
	cout << (bit1 &= bit2) << endl; // 0010 (1010 & 0011 按位与后赋值给bit1)
	cout << (bit1 |= bit2) << endl; // 0011 (0010 | 0011 按位或后赋值给bit1)
	//bit1=0011       bit2=0011
	
	cout << (bit1 <<= 2) << endl; //1100 (左移2位,低位补0,有自身赋值)
	cout << (bit1 >>= 1) << endl; //0110 (右移1位,高位补0,有自身赋值)
	//bit1=0110       bit2=0011
	
	cout << (~bit2) << endl; // 1100 (按位取反)
	cout << (bit2 << 1) << endl;  //0110 (左移,不赋值)
	cout << (bit2 >> 1) << endl;  // 0001 (右移,不赋值)
	//bit1=0110       bit2=0011
	
	cout << (bit1 == bit2) << endl; // false (0110==0011为false)
	cout << (bit1 != bit2) << endl; // true  (0110!=0011为true)
	
	cout << (bit1 & bit2) << endl; // 0010 (按位与,不赋值)
	cout << (bit1 | bit2) << endl; // 0111 (按位或,不赋值)
	cout << (bit1 ^ bit2) << endl; // 0101 (按位异或,不赋值)
	return 0;
} 

此外,可以通过 [ ] 访问元素(类似数组),注意最低位下标为0,如下:

#include<bits/stdc++.h>
#include<bitset> // 头文件 
using namespace std;
int main(){
	bitset<4> bit1(string("1011"));
	
	cout << bit1[0] << endl; //1
	cout << bit1[1] << endl; //1
	cout << bit1[2] << endl; //0
	return 0;
} 

当然,通过这种方式对某一位元素赋值也是可以的,例子就不放了。

可用函数

bit1.size() 返回大小(位数)
bit1.count() 返回1的个数
bit1.any() 返回是否有1
bit1.none() 返回是否没有1
bit1.set() 全都变成1
bit1.test(p) 用来检查第 p + 1 位是否为 1
bit1.set(p) 将第p + 1位变成1
bit1.set(p, x) 将第p + 1位变成x
bit1.reset() 全都变成0
bit1.reset(p) 将第p + 1位变成0
bit1.flip() 全都取反
bit1.flip(p) 将第p + 1位取反
bit1.to_ulong() 返回它转换为unsigned long的结果,如果超出范围则报错
bit1.to_ullong() 返回它转换为unsigned long long的结果,如果超出范围则报错
bit1.to_string() 返回它转换为string的结果

例题

P1537 弹珠

这是非常显然的一个布尔背包问题,因为数据规模比较小,可以直接把多重背包拆分成01背包做。

比较朴素的背包转移方程就是:对于第 \(i\) 个物品,枚举所有可能得到的价值 \(j\) 来进行转移:

\[f_{i,j}|=f_{i-1,j-w_i} \]

显然第一维可以滚掉,那么我们就可以使用 bitset 优化dp,假设 f 是一个bitset,它的第x位为1就表示x可以被表示出来:

\[f|=(f<<i) \]

\(f<<i\) 表示取了有这么一些情况,\(f|=\) 就是不取的一些情况,那么 \(f|=(f<<i)\) 就是取与不取两种情况。
同理,只要是bool型背包应该都能用 bitset 优化。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1.2e5;
inline int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int T, sum;
bitset<N> f;
int a[10];
int main(){
	while(++T){
		sum = 0;
		for(int i = 1; i <= 6; ++i) a[i] = read(), sum += i * a[i];
		if(sum == 0) break;
		f.reset();
		printf("Collection #%d:\n", T);
		if(sum & 1){
			printf("Can't be divided.\n\n");
			continue;
		}
		bool flag = 0;
		f.set(0);
		for(int i = 1; i <= 6; ++i){
			for(int j = 1; j <= a[i]; ++j){
				f |= (f << i);
				if(f[sum / 2]) flag = 1;
			}
		}
		if(!flag) printf("Can't be divided.\n\n");
		else printf("Can be divided.\n\n");
	}
	return 0;
} 

P3674 小清新人渣的本愿

考虑用莫队离线操作。
然后用 bitset 维护,每一位x表示 x 这个数字是否存在,记为 b1。
在记录一个 b2 每一位x表示 N-x 这个数字是否存在。
对于 1 操作 ,如果存在\(z-y=x\),可以同时转化成同时存在 \(z\)\(z-x\),那么将 b1 左移 x(表示将每一个数都加上\(x\))与 b1做 & 预算,看是否有 1,有的话就表示存在。
对于 2 操作,b2 中的\(y′\)表示 b1 中的 \(N-y\),要求 \(z + y = x\),那么可以转化成 \(z + N - y'=x\)\(z - y' = x - N\) ,就又变成操作 1 的形式了。那么就只要将 b2 右移 \(N-x\)(这里相当于是将 b2 中的每一个数都加了 \(N-x\)) 与 b1 做 & 操作。
对于 3 操作,直接 \(O(\sqrt{n})\) 枚举约数。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5;
inline int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, m, len;
bitset<N + 51> b1, b2;
int a[N + 51], ans[N + 51], c[N + 51], pos[N + 51];
struct Query{
	int id, opt, l, r, x;
	bool operator < (const Query &A) const{return pos[l] == pos[A.l] ? (pos[l] & 1 ? r < A.r : r > A.r) : pos[l] < pos[A.l];}
}q[N + 51];
inline void add(int x){
	if(c[x]++ == 0) b1[x] = 1, b2[N - x] = 1;
}
inline void del(int x){
	if(--c[x] == 0) b1[x] = 0, b2[N - x] = 0;
}
int main(){
	n = read(), m = read(); len = sqrt(n);
	for(int i = 1; i <= n; ++i) a[i] = read(), pos[i] = (i - 1) / len + 1;
	for(int i = 1; i <= m; ++i) q[i] = (Query){i, read(), read(), read(), read()};
	sort(q + 1, q + 1 + m);
	for(int i = 1, l = 1, r = 0; i <= m; ++i){
		while(r < q[i].r) add(a[++r]);
		while(l > q[i].l) add(a[--l]);
		while(l < q[i].l) del(a[l++]);
		while(r > q[i].r) del(a[r--]);
		if(q[i].opt == 1) ans[q[i].id] = (b1 & (b1 << q[i].x)).any();
		else if(q[i].opt == 2) ans[q[i].id] = (b1 & (b2 >> (N - q[i].x))).any();
		else{
			for(int j = 1; j * j <= q[i].x; ++j){
				if(q[i].x % j) continue;
				if(b1[j] && b1[q[i].x / j]){
					ans[q[i].id] = 1; break;
				} 
			}
		}
	}
	for(int i = 1; i <= n; ++i) printf("%s\n", ans[i] ? "hana" : "bi");
	return 0;
} 

P5355 [Ynoi2017] 由乃的玉米田

加减乘三个操作上面都已讲。这里主要讲除法。

  • 如果 \(x \ge \sqrt{n}\)​,那么我们可以暴力枚举商,然后判断有没有出现即可。因为这个商 \(\le \sqrt{n}\)​,所以复杂度是正确的。
  • 如果 \(x < \sqrt{n}\),我们可以预处理出 \(x \in [1,\sqrt{n})\) 的答案。对于每个 \(x\) 遍历一遍序列,找出每个 \(1 \le i \le n\) 的离 \(i\) 最近且 \(\le i\)\(res_i​\),满足 \(a_i\)​ 和 $aa_{res_i}​​ 的商为 \(x\),于是每个询问的答案为 \(l\) 是否 \(\le res_{r}\)​。这一部分的时间复杂度为 \(O(n\sqrt{n})\)
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5;
inline int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, m, len, qcnt;
bitset<N + 51> b1, b2;
int a[N + 51], ans[N + 51], c[N + 51], pos[N + 51];
int pre[N + 51], res[N + 51];
vector<int> ql[N], qr[N], qi[N];
struct Query{
	int id, opt, l, r, x;
	bool operator < (const Query &A) const{return pos[l] == pos[A.l] ? (pos[l] & 1 ? r < A.r : r > A.r) : pos[l] < pos[A.l];}
}q[N + 51];
inline void add(int x){
	if(c[x]++ == 0) b1[x] = 1, b2[N - x] = 1;
}
inline void del(int x){
	if(--c[x] == 0) b1[x] = 0, b2[N - x] = 0;
}
inline void solve(){
	for(int x = 1; x <= len; ++x){
		if(qi[x].empty()) continue;
		int l = 0;
		for(int i = 1; i <= n; ++i){
			int y = a[i];
			pre[y] = i;
			if(x * y <= N) l = max(l, pre[x * y]);
			if(y % x == 0) l = max(l, pre[y / x]);
			res[i] = l;
		}
		for(int i = 0; i < qi[x].size(); ++i) ans[qi[x][i]] = (ql[x][i] <= res[qr[x][i]]);
		memset(pre, 0, sizeof(pre)), memset(res, 0, sizeof(res));
	}
}
int main(){
	n = read(), m = read(); len = sqrt(n);
	for(int i = 1; i <= n; ++i) a[i] = read(), pos[i] = (i - 1) / len + 1;
	for(int i = 1; i <= m; ++i){
		int opt = read(), l = read(), r = read(), x = read();
		if(opt == 4 && x <= 300) qi[x].push_back(i), ql[x].push_back(l), qr[x].push_back(r);
		else q[++qcnt] = (Query){i, opt, l, r, x};
	} 
	sort(q + 1, q + 1 + qcnt);
	for(int i = 1, l = 1, r = 0; i <= qcnt; ++i){
		while(r < q[i].r) add(a[++r]);
		while(l > q[i].l) add(a[--l]);
		while(l < q[i].l) del(a[l++]);
		while(r > q[i].r) del(a[r--]);
		if(q[i].opt == 1) ans[q[i].id] = (b1 & (b1 << q[i].x)).any();
		else if(q[i].opt == 2) ans[q[i].id] = (b1 & (b2 >> (N - q[i].x))).any();
		else if(q[i].opt == 3){
			for(int j = 1; j * j <= q[i].x; ++j){
				if(q[i].x % j) continue;
				if(b1[j] && b1[q[i].x / j]){
					ans[q[i].id] = 1; break;
				} 
			}
		}else{
			for(int j = 1; j * q[i].x <= N; ++j)
				if(b1[j] && b1[j * q[i].x]){
					ans[q[i].id] = 1; break;
				} 
		}
	}
	solve();
	for(int i = 1; i <= m; ++i) printf("%s\n", ans[i] ? "yuno" : "yumi");
	return 0;
} 

P6134 [JSOI2015]最小表示

对于每一条边 \((u,v)\),如果从 \(u\)\(v\) 仅存在这条路径,那么这条边一定要保存,否则一定可以删除,因为若存在另一条从 \(u\)\(v\) 的路线,那么一定存在一个不同于 \(u,v\) 的点 \(x\) ,可以使 \(u\)\(x\), \(x\)\(v\),那么显然边 \((u,v)\) 是可以删除的。为什么这样一定正确呢?因为我们保证了原图联通的点在删边操作仍是联通的。

那么接下来就要想办法对每条边判断是否存在点 \(x\) ,我们可以对每个点,处理出这个点可以到的点,以及有哪些点可以到这个点,那么就是可达性统计了。

接下来就是对 \(u\)\(v\),看 \(u\) 可到达的点和可以到 \(v\) 的点是否有相同的就可以了。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 3e4 + 51, M = 2e5 + 51;
inline int read(){
	int x = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') f = -f; ch = getchar();}
	while(ch >= '0' && ch <= '9'){x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
	return x * f;
}
int n, m, dfn, ans; 
int u[M], v[M], in[N], id[N];
vector<int> e1[N], e2[N];
bitset<N> to[N], co[N];
inline void topu(){
	queue<int> q;
	for(int i = 1; i <= n; ++i) if(!in[i]) q.push(i);
	while(!q.empty()){
		int x = q.front(); q.pop();
		id[++dfn] = x;
		for(int y : e1[x])
			if(--in[y] == 0) q.push(y);
	}
} 
int main(){
	n = read(), m = read();
	for(int i = 1; i <= m; ++i){
		u[i] = read(), v[i] = read();
		++in[v[i]];
		e1[u[i]].push_back(v[i]), e2[v[i]].push_back(u[i]);
	}
	topu();
	for(int i = n; i; --i){
		int x = id[i];
		for(int y : e1[x]) to[x][y] = 1, to[x] |= to[y];
	}
	for(int i = 1; i <= n; ++i){
		int x = id[i];
		for(int y : e2[x]) co[x][y] = 1, co[x] |= co[y];
	} 
	for(int i = 1; i <= m; ++i)
		if((to[u[i]] & co[v[i]]) != 0) ans++;
	printf("%d\n", ans);
	return 0;
}