Sum of XOR Functions 题解

发布时间 2023-12-19 12:14:48作者: Creeper_l

题意

给定一个数 \(n\) 和一个包含 \(n\) 个数的序列 \(a\),求出以下式子模 \(998244353\) 的值:

\(\sum_{i=1}^{n}\sum_{j=i}^{n} f(i,j)\times (j-i+1)\)

其中 \(f(i,j)\) 的值为 \(a_{i}\oplus a_{i+1}\oplus a_{i+2}\oplus...\oplus a_{j}\)

思路

首先我们可以考虑这道题的弱化版 P3917 异或序列,我们可以设 \(dp_{i,j}\) 表示第 \(i\) 个数的第 \(j\) 位对答案的贡献,也就是在 \(i\) 之前的每一个数的第 \(j\) 位有多少个数满足区间异或的值为一。那么我们考虑将每个 \(a_{i}\) 按照二进制位拆开,如果当前位置是 \(1\) 的话,那么 \(dp_{i,j}=i-dp_{i-1,j}\),因为只有前面的异或和为 \(0\) 时(所有数的数量减去异或和为 \(1\) 的数量),异或上当前位置的 \(1\) 答案才为 \(1\);否则如果当前位置的值是 \(0\),那么 \(dp_{i,j}=dp_{i-1,j}\),因为只有前面的异或和为 \(1\) 时,异或上当前位置的 \(0\) 答案才为 \(1\)

因为这里的每一个 \(dp_{i,j}\) 表示的是贡献的数量,所以还要乘上对应的值,也就是 \(2^{j}\),这样就可以得出答案了。注意 \(dp\) 数组可以滚掉一维,节省空间。

然后我们考虑这道题怎么做,其实比较好想了。我们可以对每一位计算贡献,对于每个数的每一位,可以多维护一个 \(s_{0,1}\) 表示 \(a_{1}\oplus a_{2}\oplus a_{3}\oplus...\oplus a_{i}\) 的值,\(cnt_{0/1}\) 表示有多少个数的 \(s_{j}\)\(1/0\)\(sum_{1/0}\) 表示所有 \(s_{j}\)\(0/1\) 的数到 \(1\) 的距离(也就是 \(j\))。然后每次加上相应的距离就行了。

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf 0x3f
#define inf_db 127
#define ls id << 1
#define rs id << 1 | 1
#define re register
#define endl '\n'
typedef pair <int,int> pii;
const int MAXN = 3e5 + 10;
const int mod = 998244353;
int n,a[MAXN],s[MAXN],ans; 
signed main()
{
	cin >> n;
	for(int i = 1;i <= n;i++) cin >> a[i];
	for(int i = 0;i <= 30;i++)
	{
		s[1] = 0;
		for(int j = 1;j <= n;j++) s[j + 1] = s[j] ^ (a[j] >> i & 1);
		vector <int> cnt(2, 0);
    	vector <int> sum(2, 0);
    	for(int j = 1;j <= n + 1;j++)
    	{
    		ans = (ans + cnt[s[j] ^ 1] * j % mod * (1ll << i)) % mod;
			ans = (ans - sum[s[j] ^ 1] * (1ll << i) + mod) % mod;
    		cnt[s[j]]++;
    		sum[s[j]] = (sum[s[j]] + j) % mod;
		}
		ans = ans % mod;
	} 
	cout << ans << endl;
	return 0;
}