星空 (Easy version & Hard Version) 题解

发布时间 2023-09-23 22:26:23作者: 霜木_Atomic

星空 (Easy version & Hard Version) 题解

不知道简单版有没有单独的做法,反正我不会

很明显如果 \(a\) 中有大于 \(x\) 的数直接无解,输出 \(0\)

发现每个 \(a_i\) 都是 \(2\) 的整数次幂,这告诉我们每个 \(a_i\) 在二进制表示下只会有一位上是 \(1\),那么,相邻的两个数相加,最多就是进一个位。

然后我们来考虑 \(x\)。假如 \(x\) 的最高位 \(1\) 和次高位 \(1\) 分别在 \(i\) 位和 \(j\) 位上。由于没有大于 \(x\) 的数,所以现在 \(a\) 中最大的数也不会超过 \(2^i\)。我们来考虑这些数怎么放是合法的:

  • 对于等于 \(2^i\) 的数,必须分开;
  • 其他数彼此可以相邻,因为即使是 \(2^{i-1} + 2^{i-1} = 2^i\) 也小于等于 \(x\)
  • 然后来考虑 \(2^i\) 和其他数的关系,发现只需要 \(2^i\) 与其他数中大于 \(2^j\) 的数不相邻即可。

这样就可以利用插板法来解决问题了。设等于 \(2^i\) 的数有 \(s_1\) 个,小于 \(2^i\),大于 \(2^j\) 的数有 \(s^2\) 个,剩下的数有 \(s_3\) 个。以下分别简称位第一、二、三类数。

首先这些数内部可以随便排列,所以先有 \(s_1 !s_2 ! s_3 !\)

然后考虑把第一类数向第三类数里插入,可以插入的位置有 \(s_3+1\) 个,所以方案数为 \(s_3+1 \choose s_1\)

最后考虑把第二类数往序列里插入。因为第二类数彼此可以相邻,而不可以与第一类数相邻,所以只需要把第一类数占据的位置去掉,考虑往剩下的位置插入。这时问题就变成了向 \(s_3 + 1 - s_1\) 个有编号的桶中放入 \(s_2\) 个无编号的球,桶可以空,求方案数。插板法经典问题,方案为 \(s_2 + s_3 - s_1 \choose s_3 - s_1\)

所以最后的答案就是 \(s_1 !s_2 ! s_3 ! {s_3+1 \choose s_1} {s_2 + s_3 - s_1 \choose s_3 - s_1}\)

注意求组合数的时候,即使 \(n\)\(m\) 都小于 \(0\),只要 \(n = m\),结果也要是 \(1\)(因为这个吃了一发罚时 xwx)。

代码:

#include<bits/stdc++.h>

namespace IO2{
	int read(){
		int x = 0, f = 1; char ch = getchar();
		while(ch<'0' || ch>'9') {if(ch == '-') f = -1; ch = getchar();}
		while(ch>='0'&&ch<='9') {x = x * 10 + ch - 48, ch = getchar();}
		return x * f;
	}
}
using IO2::read;
#define ll long long
using namespace std;

const int N = 1e5+100;

int fac[N], inv[N];
ll n, X;
ll a[N];
const int mod = 1e9+7;

int C(int tn, int tm) {
	if(tn == tm) return 1;
	if(tn < 0 || tm < 0) return 0;
	if(tn < tm) return 0;
	return 1ll*fac[tn] * inv[tm]%mod * inv[tn-tm]%mod;
}
int fpow(int a, int b) {
	a%=mod;
	int ret = 1;
	while(b) {
		if(b & 1) {
			ret = 1ll*ret*a%mod;
		}
		b>>=1;
		a = 1ll*a*a%mod; 
	}
	return ret;
}

int fir, sec;
int hi[N];

int main() {
	scanf("%lld%lld", &n, &X);
	for(int i = 1; i<=n; ++i) {
		scanf("%lld", &a[i]);
		for(int j = 63; j>=1; --j) {
			if((a[i] >> (j-1)) & 1) {
				hi[i] = j;
				break;
			}
		}
	}
	fac[0] = 1;
	for(int i = 1; i<=n + 2; ++i) {
		fac[i] = 1ll*fac[i-1] * i%mod;
	}
	inv[n + 2] = fpow(fac[n + 2], mod-2);
	for(int i = n+1; i>=0; --i) {
		inv[i] = 1ll*inv[i+1]*(i+1)%mod;
	}
	for(int i = 1; i<=n; ++i) {
		if(a[i] > X) {
			puts("0");
			return 0;
		}
	} 
	for(int i = 63; i>=1; --i) {
		if((X >> (i-1)) & 1) {
			if(!fir) {
				fir = i;
			} else if(!sec) {
				sec = i;
			}
		} 
	}
	int cnt1 = 0, cnt2 = 0, cnt3 = 0;
	for(int i = 1; i<=n; ++i){
		if(hi[i] == fir) {
			++cnt1;
		} else if(hi[i] > sec) {
			++cnt2;
		} else ++cnt3;
	}
	int ans = 1ll*fac[cnt1] * fac[cnt2]%mod * fac[cnt3]%mod * C(cnt3+1, cnt1) %mod * C(cnt2 + cnt3 - cnt1, cnt3 - cnt1)%mod;
	printf("%d\n", ans);
	return 0;
}