Educational Codeforces Round 155 D (CF1879_D)

发布时间 2023-09-25 23:44:42作者: 过_路_人

题目大意

给一个长度为 \(n\) 的数组,求 \(\Sigma_{i=1}^{n} \Sigma_{j=i}^{n} 区间异或和 \times (j-i+1)\)
其中 \(n\leq 3e5,~a[i]\leq 1e9\)

分析

首先注意到由 \(l\)\(r\) 的区间异或和可以转化为 \(sum_{l-1}~XOR~sum_r\)
因此,对于每一个点 \(x\) ,无论它作为上述的 \(sum_{i-1}\) 还是 \(sum_j\) ,如果它的某个二进制位(假设为第 \(k\) 位)为 \(1\) 而另一个值(\(sum_{i-1}\)\(sum_r\) ) 的这一位为 \(0\) ,那么点 \(i\) 的第 \(j\) 位会对整体答案产生 \((1<<j)\times 某个值\) 的贡献

现在考虑这个“某个值”应该怎么算:

  • 对于点 \(x\) 的第 \(k\) 位,这个值是所有第 \(k\) 位为 \(0\) 的点的距离之和

因此设 \(f_{i,j}\) 表示所有第 \(j\) 位为 \(0\) 的点与第 \(i\) 个点的距离之和,考虑从 \(f_{i-1,j}\)\(f_{i,j}\) 的转移:

  • 对于 \(i-1\) 以前的点(包含):每个点的距离会增加 \(1\)
  • 对于 \(i\) 及以后得点:每个点的距离会减少 \(1\)

那么考虑用一个 \(cnt_{i,j}\) 记录 \(i\) 及以前有多少个点的第 \(j\) 位是 \(0\) (这个很好记录),再记录一下所有数中第 \(j\) 位为 \(0\) 的数有多少个(记为 \(sumbit_j\)

\(f_{i,j}=f_{i-1,j}+cnt_{i-1,j}-(sumbit_j-cnt{i-1,j})\)

计算完 \(f_{i,j}\) 后,如果第 \(i\) 个数的第 \(j\) 位是 \(1\),答案加上 \((1<<j)\times f_{i,j}\) 即可

Code:

#include<bits/stdc++.h>
#define IO ios::sync_with_stdio(false); cin.tie(0)
using namespace std;
using ll = long long;
using ull = unsigned long long;
const ll mod = 998244353;
const int N = 3e5 + 5, M = 32;
ll f[N][M], cnt[N][M], a[N], sum[N], ans, sumbit[N];
int n, maxbit;

int main() {
    // freopen("d.txt", "r", stdin);
    IO;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        sum[i] = sum[i-1] ^ a[i];
        maxbit = max(maxbit, int(log2(sum[i])));
    }
    maxbit = 30;
    for (int i = 0; i <= n; i++) {
        for (int j = 0; j <= maxbit; j++) {
            if (i != 0) 
                cnt[i][j] = cnt[i-1][j];
            if (!(sum[i] & (1 << j))) {
                sumbit[j]++, cnt[i][j]++;
                f[0][j] = (f[0][j] + i) % mod;
            }   
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j <= maxbit; j++) {
            f[i][j] = ((f[i-1][j] + 2 * cnt[i-1][j]) % mod - sumbit[j] + mod) % mod;
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 0; (1 << j) <= sum[i]; j++) {
            if (sum[i] & (1 << j)) {
                ans += (1ll << j) * f[i][j];
                ans = (ans % mod + mod) % mod;
            }
        }
    }
    cout << ans;
    return 0;
}