【题解】CF1616H Keep XOR Low

发布时间 2023-08-01 02:44:18作者: kymru

很好计数题,爱来自汐斯塔。

思路

01Trie 上 dp.

首先根据两两异或想到 01Trie,既然是计数自然考虑在 01Trie 上 dp.

先将 \(a\) 中的所有数插入 01Trie.

最直观的想法是按位 dp,也就是令 \(f[u]\) 表示 01Trie 上在 \(u\) 的子树内选取的合法方案数。

考虑 \(x\)\(u\) 所代表的二进制位上的取值,发现如果是 \(1\) 的话下面的两棵子树之间会相互限制,所以很难转移。

考虑改改状态,设成 \(f[u1][u2]\) 表示在 01Trie 上的两棵子树 \(u1, u2\) 内选取的方案数。

注意这里的 \(u1, u2\) 是同一层的结点,也就是代表的二进制位相同。并且我们钦定 \(u1, u2\) 在更高的二进制位上 异或的结果小于 \(x\).

至于这是怎么想到的,我不知道,懂哥麻烦教教。

然后考虑根据 \(x\)\(u1\) 子结点 代表的二进制位上的取值分类讨论:

先记 01Trie 中结点 \(u\) 的左子结点为 \(\operatorname{ls}(u)\),右子结点为 \(\operatorname{rs}(u)\),子树中存在数的个数为 \(\operatorname{sz}(u)\).

  • \(x\) 在这一位上是 \(0\).

    那么 \(\operatorname{ls}(u1)\)\(\operatorname{rs}(u2)\)\(\operatorname{ls}(u2)\)\(\operatorname{rs}(u1)\) 的子树不能同时选取。

    • 只在 \(u1\)\(u2\) 的子树内选,因为钦定的条件可以随意选,方案数为 \((2^{\operatorname{sz}(\operatorname{ls}(u1))} - 1) \cdot (2^{\operatorname{sz}(\operatorname{rs}(u1))} - 1) + (2^{\operatorname{sz}(\operatorname{ls}(u2))} - 1) \cdot (2^{\operatorname{sz}(\operatorname{rs}(u2))} - 1)\).(排除为空的情况)

    • \(u1, u2\) 的同向子结点的子树内选,方案数为 \(f[\operatorname{ls}(u1)][\operatorname{ls}(u2)] + f[\operatorname{rs}(u1)][\operatorname{rs}(u2)]\).

  • \(x\) 在这一位上是 \(1\).

    注意到 \(\operatorname{ls}(u1)\)\(\operatorname{ls}(u2)\),以及 \(\operatorname{rs}(u1)\)\(\operatorname{rs}(u2)\) 之间异或的值一定小于等于 \(x\)

    所以现在相互之间存在约束的只有 \(\operatorname{ls}(u1)\)\(\operatorname{rs}(u2)\),以及 \(\operatorname{ls}(u2)\)\(\operatorname{rs}(u1)\).

    这里有一个相当巧妙的结论。我们发现钦定其中一组约束成立的时候,另一组约束相较于这组约束是独立的。例如当前选取的方案满足 \(\operatorname{ls}(u2)\)\(\operatorname{rs}(u1)\) 的约束时(只在其中选取),已经选取的数不会影响到 \(\operatorname{ls}(u1)\)\(\operatorname{rs}(u2)\) 的选取方式。

    这意味着我们只需要分别求出满足这两种约束的方案总数,然后相乘就行。这里的意义实际上是在 \(\operatorname{ls}(u1), \operatorname{rs}(u1), \operatorname{ls}(u2), \operatorname{rs}(u2)\) 四棵子树内在满足两组约束的前提下选数。

    方案总数为 \(f[\operatorname{ls}(u1)][\operatorname{rs}(u2)] \cdot f[\operatorname{ls}(u2)][\operatorname{rs}(u1)]\).

边界条件是递归完所有的二进制位后可以任意选取(不存在约束条件)。

注意当 \(u1 = u2\) 的时候需要特判。

时间复杂度是 \(O(n \log |V|)\).

代码

#include <cstdio>
using namespace std;

typedef long long ll;

const int maxn = 1.5e5 + 5;
const int tr_sz = maxn * 40;
const int mod = 998244353;

int n, x;
int cnt_nd = 1;
int a[maxn], pw[maxn];
int son[tr_sz][2], sz[tr_sz];

void insert(int x)
{
	int u = 1;
	for (int i = 30; i >= 0; i--)
	{
		int c = (x >> i) & 1;
		if (!son[u][c]) son[u][c] = ++cnt_nd;
		sz[u]++, u = son[u][c];
	}
	sz[u]++;
}

int dfs(int u1, int u2, int dep)
{
	if ((!u1) || (!u2)) return pw[sz[u1 | u2]];
	if (u1 == u2)
	{
		if (dep < 0) return pw[sz[u1]];
		int ls = son[u1][0], rs = son[u1][1];
		if ((x >> dep) & 1) return dfs(ls, rs, dep - 1);
		return ((ll)dfs(ls, ls, dep - 1) + dfs(rs, rs, dep - 1) - 1 + mod) % mod;
	}
	if (dep < 0) return pw[sz[u1] + sz[u2]];
	int ls1 = son[u1][0], rs1 = son[u1][1], ls2 = son[u2][0], rs2 = son[u2][1];
	if ((x >> dep) & 1) return (ll)dfs(ls1, rs2, dep - 1) * dfs(rs1, ls2, dep - 1) % mod;
	int res = ((ll)dfs(ls1, ls2, dep - 1) + dfs(rs1, rs2, dep - 1) - 1 + mod) % mod;
	res = (res + ((ll)pw[sz[ls1]] - 1 + mod) * ((ll)pw[sz[rs1]] - 1 + mod) % mod) % mod;
	res = (res + ((ll)pw[sz[ls2]] - 1 + mod) * ((ll)pw[sz[rs2]] - 1 + mod) % mod) % mod;
	return res;
}

int main()
{
	scanf("%d%d", &n, &x);
	pw[0] = 1;
	for (int i = 1; i <= n; i++)
	{
		scanf("%d", &a[i]);
		insert(a[i]);
		pw[i] = (pw[i - 1] << 1) % mod;
	}
	int ans = (dfs(1, 1, 30) - 1 + mod) % mod;
	printf("%d\n", ans);
	return 0;
}