CF1761D Carry Bit

发布时间 2023-07-20 17:26:43作者: Ender_32k

Description

\(f(x,y)\)\(x+y\) 的二进制进位数(即 \(f(x,y)=g(x)+g(y)-g(x+y)\) ,其中 \(g(x)\)\(x\) 的二进制表示中 \(1\) 的个数)。

给定两个整数 \(n\)\(k\) ,求出满足\(0 \leq a,b < 2^n\)\(f(a,b) = k\) 的有序对 \((a,b)\) 的个数。注意,对于\(a\neq b\)\((a,b)\)\((b,a)\) 被认为是两个不同的对。

由于这个数可能很大,对它取模 \(10^9+7\) 输出。

\(0\le k<n\le 10^6\)

Solution

考虑什么时候两个数 \(x+y\)\(i\) 处有进位:

  • \(x_i=y_i=1\),则必然有进位
  • \(x_i\ \text{xor}\ y_i=1\),并且之前一位有进位,也会有进位。

但是之前一位是否有进位你是不知道的,所以以上分类讨论还不能发挥什么作用。

那我们就连着上一位一同分类:

  • 上一位没有进位,这一位也没有进位,那么 \(x,y\) 所对应的位可以是 \((0,0),(0,1),(1,0)\)
  • 上一位有进位,这一位没有进位:\(x,y\) 只能对应 \((0,0)\)
  • 上一位没有进位,这一位有进位:\(x,y\) 只能对应 \((1,1)\)
  • 上一位有进位,这一位也有进位:\(x,y\) 能够对应 \((1,0),(0,1),(1,1)\)

我们发现当且仅当上一位和这一位的进位状态相同时,这种方案对答案的贡献才能增加。为了方便,不妨使用一个长度 \(n+1\)\(01\) 序列 \(t\) 记录 \(0\)\(n\) 位的进位状态。显然有 \(t_{n}=0\) 并且这一位其实是不需要的,但是为了方便我们可以加入这一位,这样的话我们知道当 \(f_i\neq f_{i+1}\) 的个数一定时必然是 \(0\) 的连续段更多。

于是我们设有 \(i\)\(f_i\neq f_{i+1}\) 的地方,那么一共有 \(i+1\) 个段,那么 \(0\) 的连续段数为 \(\left\lceil\dfrac{i+1}{2}\right\rceil\)\(1\) 的连续段个数就是 \(\left\lfloor\dfrac{i+1}{2}\right\rfloor\)

由于有 \(k\) 次进位,所以有 \(k\)\(1\);由于一共 \(n+1\) 个位,所以 \(0\) 就有 \(n+1-k\) 个。

于是变成长度为 \(n\) 的序列分成 \(m\) 段的方案数,显然是 \(\dbinom{n-1}{m-1}\)

由于有 \(i\)\(f_i\neq f_{i+1}\),那么有 \(n-i\)\(f_i=f_{i+1}\),那么有 \(3^{n-i}\) 的贡献。

于是答案就是:

\[ans=\sum\limits_{i=0}^{n}3^{n-i}\dbinom{n-k}{\left\lceil\dfrac{i+1}{2}\right\rceil-1}\dbinom{k-1}{\left\lfloor\dfrac{i+1}{2}\right\rfloor-1} \]

直接 \(O(n\log n)\) 做即可,预处理 \(3\) 的幂优化至 \(O(n)\)

#include <bits/stdc++.h>
#define int long long
namespace mystd {
	inline int read() {
	    char c = getchar();
	    int x = 0, f = 1;
	    while (c < '0' || c > '9') f = (c == '-') ? -1 : f, c = getchar();
	    while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + c - '0', c = getchar();
	    return x * f;
	}
	inline void write(int x) {
	    if (x < 0) x = ~(x - 1), putchar('-');
	    if (x > 9) write(x / 10);
	    putchar(x % 10 + '0');
	}
}
using namespace std;
using namespace mystd;

const int maxn = 1e6 + 100;
const int mod = 1e9 + 7;
int n, k, fac[maxn], ifac[maxn], inv[maxn];;

int qpow(int p, int q) {
	int res = 1;
	while (q) {
		if (q & 1) res = res * p % mod;
		p = p * p % mod, q >>= 1;
	}
	return res;
}

int C(int n, int m) { return n < m ? 0 : fac[n] * ifac[m] % mod * ifac[n - m] % mod; }

signed main() { 
	n = read(), k = read();
	fac[0] = ifac[0] = inv[1] = 1;
	for (int i = 1; i <= n; i++) {
		if (i > 1) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
		fac[i] = fac[i - 1] * i % mod, ifac[i] = ifac[i - 1] * inv[i] % mod;
	}
	if (k == 0) return write(qpow(3, n)), 0;
	int ans = 0;
	for (int i = 0; i <= n; i++) {
		(ans += qpow(3, n - i) * C(n - k, (i + 2) / 2 - 1) % mod * C(k - 1, i + 1 - ((i + 2) / 2) - 1) % mod) %= mod;
	}
	write(ans);
	return 0;
}