CF1815D XOR Counting 题解

发布时间 2023-08-26 06:57:27作者: User-Unauthorized

题意

给定 \(n, m\),对于所有满足 \(\displaystyle \left(\sum\limits_{i = 1}^{m}a_i\right) = n\) 的非负整数序列 \(a_m\),求所有可能的 \(\displaystyle \bigoplus\limits_{i = 1}^{m} a_i\) 的值的和,相同的异或和值只计算一次。

\(1 \le n \le 10^{18}, 1 \le m \le 10^5, 1 \le T \le 10^4\))。

题解

\(m\) 分类讨论。


\(m = 1\)

直接输出 \(n\) 即可。

\(m \ge 3\)

考虑将 \(n\) 划为形如 \(x, \dfrac{n - x}{2}, \dfrac{n - x}{2}, 0, 0, \cdots\) 的若干整数,可以发现 \(x\) 这个序列的和为 \(n\),异或和为 \(x\),同时 \(x\) 可以取到所有与 \(n\) 奇偶性相同且不大于 \(n\) 的非负整数。最终答案为所有可能的 \(x\) 的和,即

\[\sum\limits_{x \le n} \left[x \equiv n \bmod 2\right] x \]

下面论证不存在与 \(n\) 奇偶性不同且符合条件的 \(x\)

观察到

  • 奇数与奇数相加为偶数,奇数与奇数相异或也为偶数;
  • 偶数与奇数相加为奇数,偶数与奇数相异或也为奇数;
  • 偶数与偶数相加为偶数,偶数与偶数相异或也为偶数;

也就是加法运算和异或运算奇偶性相同,故最终异或和不可能与 \(n\) 奇偶性不同。

\(m = 2\)

考虑减小问题规模,设 \(x, y\) 为符合要求的一组解, \(f(n)\) 表示 \(n\) 的答案,\(g(n)\) 表示 \(n\) 的答案数,即可能的异或值数量。因为 \(n\) 的奇偶性确定后,两个数的奇偶性也会确定,故按 \(n\) 的奇偶性分类讨论。

\(n\) 为奇数

\(x, y\) 中必定有一个奇数和一个偶数,假定 \(x\) 为奇数,\(y\) 为偶数,设 \(x^{\prime} = \dfrac{x - 1}{2}, y^{\prime} = \dfrac{y}{2}\),那么有

\[\begin{aligned} x^{\prime} + y^{\prime} &= \dfrac{n - 1}{2}\\ x \oplus y &= 2 \times \left(x^{\prime} \oplus y^{\prime}\right) + 1 \end{aligned}\]

那么有转移

\[\begin{aligned} f(n) &= 2 \times f(\dfrac{n - 1}{2}) + g(\dfrac{n - 1}{2})\\ g(n) &= g(\dfrac{n - 1}{2}) \end{aligned}\]

\(n\) 为偶数

\(x, y\) 奇偶性一定相同。

\(x, y\) 均为偶数,那么设 \(x^{\prime} = \dfrac{x}{2}, y^{\prime} = \dfrac{y}{2}\),有

\[\begin{aligned} x^{\prime} + y^{\prime} &= \dfrac{n}{2}\\ x \oplus y &= 2 \times \left(x^{\prime} \oplus y^{\prime}\right) \end{aligned}\]

\(x, y\) 均为偶数,那么设 \(x^{\prime} = \dfrac{x - 1}{2}, y^{\prime} = \dfrac{y - 1}{2}\),有

\[\begin{aligned} x^{\prime} + y^{\prime} &= \dfrac{n}{2} - 1\\ x \oplus y &= 2 \times \left(x^{\prime} \oplus y^{\prime}\right) \end{aligned}\]

综合两种情况,有转移

\[\begin{aligned} f(n) &= 2 \times f(\dfrac{n}{2}) + 2 \times f(\dfrac{n}{2} - 1)\\ g(n) &= 2 \times g(\dfrac{n}{2}) + 2 \times g(\dfrac{n}{2} - 1) \end{aligned}\]

递归处理即可,总复杂度 \(\mathcal{O}(\log n)\),可以通过本题。

Code

//Codeforces - 1815D
#include <bits/stdc++.h>

typedef long long valueType;
typedef std::pair<valueType, valueType> ValuePair;
typedef std::map<valueType, ValuePair> Memory;

constexpr valueType MOD = 998244353;

template<typename T1, typename T2, typename T3 = valueType>
void Inc(T1 &a, T2 b, const T3 &mod = MOD) {
    a = a + b;

    if (a >= mod)
        a -= mod;
}

template<typename T1, typename T2, typename T3 = valueType>
void Dec(T1 &a, T2 b, const T3 &mod = MOD) {
    a = a - b;

    if (a < 0)
        a += mod;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 sum(T1 a, T2 b, const T3 &mod = MOD) {
    return a + b >= mod ? a + b - mod : a + b;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 sub(T1 a, T2 b, const T3 &mod = MOD) {
    return a - b < 0 ? a - b + mod : a - b;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 mul(T1 a, T2 b, const T3 &mod = MOD) {
    return (long long) a * b % mod;
}

template<typename T1, typename T2, typename T3 = valueType>
void Mul(T1 &a, T2 b, const T3 &mod = MOD) {
    a = (long long) a * b % mod;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 pow(T1 a, T2 b, const T3 &mod = MOD) {
    T1 result = 1;

    while (b > 0) {
        if (b & 1)
            Mul(result, a, mod);

        Mul(a, a, mod);
        b = b >> 1;
    }

    return result;
}

Memory memory;

ValuePair solve(valueType n) {
    if (memory.count(n))
        return memory[n];

    if (n == 0)
        return memory[n] = std::make_pair(0, 1);

    if (n == 1)
        return memory[n] = std::make_pair(1, 1);

    if (n & 1) {
        auto const result = solve(n >> 1);

        return memory[n] = std::make_pair(sum(mul(result.first, 2), result.second), result.second);
    } else {
        auto const A = solve(n / 2), B = solve(n / 2 - 1);

        return memory[n] = std::make_pair(mul(sum(A.first, B.first), 2), sum(A.second, B.second));
    }
}

constexpr valueType Inv2 = 499122177;

int main() {
    valueType T;

    std::cin >> T;

    for (valueType testcase = 0; testcase < T; ++testcase) {
        valueType N, M;

        std::cin >> N >> M;

        if (N == 0) {
            std::cout << 0 << '\n';
        } else if (M == 1) {
            std::cout << (N % MOD) << '\n';
        } else if (M == 2) {
            memory.clear();

            std::cout << (solve(N).first % MOD) << '\n';
        } else {
            if (N & 1) {
                std::cout << (mul((N + 1) % MOD, mul(((N + 1) / 2) % MOD, Inv2))) << '\n';
            } else {
                std::cout << (mul((N + 2) % MOD, mul((N / 2) % MOD, Inv2))) << '\n';
            }
        }
    }

    std::cout << std::flush;

    return 0;
}