字节笔试题-区间异或

发布时间 2023-11-03 12:02:08作者: 完美二叉树

题目

给两个长度为n的数组a,b,请你计算出有多少个区间[l, r],满足 \(a_{l}\oplus a_{l+1}\oplus a_{l+2}\oplus \ldots \oplus a_{r} = b_{l}\oplus b_{l+2}\oplus b_{l+2}\oplus \ldots \oplus b_{r}\)\(\oplus\)表示按位异或)

输入描述

第一行输入一个整数n,表示数组长度。
第二行输入n个整数\(a_i\)
第三行输入n个整数\(b_i\)

\[\begin{aligned}1\leq n\leq 10^{5}\\ 1\leq a_{i},bi\leq 10^9\end{aligned} \]

示例

input

5
1 2 3 4 5
5 4 3 2 1

output

3

满足条件的区间有[1,5],[2,4],[3,3]

思路

这道题如果直接暴力的话,需要两层循环来确定区间边界,确定区间之后还要对区间求异或和,这样时间复杂度就是\(O(n^3)\)

首先能想到的优化方法是,对区间进行异或和是不是可以像求和一样,使用前缀和?

异或运算满足交换律和结合律,这样通过递推可以算出[0,i)区间的异或和,其中\(0<i<=n\)

那么通过[0,i)[0,j)能否算出[i,j)呢?

其实也是可以的,因为异或运算中\(x\oplus x=0\),所以有\([i, j) = [0, i) \oplus [0,j)\)

这样就节省了求异或和的时间,时间复杂度为\(O(n^2)\)

def f1(a: list, b: list, n: int) -> int:
    s = 0
    a_s = [0] + [s := s ^ i for i in a]
    s = 0
    b_s = [0] + [s := s ^ i for i in b]

    res = 0
    for i in range(n):
        for j in range(i + 1, n + 1):
            if a_s[j] ^ a_s[i] == b_s[j] ^ b_s[i]:
                res += 1
    return res

本来以为这样够优化了,结果提交代码只能过30%,时间复杂度还是太高。最后时间快到的时候才想到怎么优化

在上面的代码中,我用a_s[j] ^ a_s[i] == b_s[j] ^ b_s[i]这个式子来判断 a b 两个数组在[i, j)区间上的异或和是否相等,其实这个式子可以变形一下,变成a_s[i] ^ b_s[i] == a_s[j] ^ b_s[j]和前面的式子是等效的

所以这个问题可以转化为:已知边界 j ,求有多少个 i ,满足a_s[i] ^ b_s[i] == a_s[j] ^ b_s[j]。这样,我们就可以用哈希表记录有多少个a_s[i] ^ b_s[i],key为a_s[i] ^ b_s[i],value为个数,为了保证 i < j 我们在计算前缀和的同时,一边算个数,一边用哈希表统计。

def f2(a: list, b: list, n: int) -> int:
    # a_s[j] ^ a_s[i] == b_s[j] ^ b_s[i] 等价为 a_s[i] ^ b_s[i] == a_s[j] ^ b_s[j]
    a_s = [0 for i in range(n + 1)]
    b_s = [0 for i in range(n + 1)]
    d = {0: 1}
    res = 0
    for i in range(n):
        a_s[i + 1] = a_s[i] ^ a[i]
        b_s[i + 1] = b_s[i] ^ b[i]
        s = a_s[i + 1] ^ b_s[i + 1]
        res += d.get(s, 0)
        d[s] = d.get(s, 0) + 1
    return res

这样时间复杂度为\(O(n)\),空间复杂度为\(O(n)\)
其实不需要用一个数组来存前缀和了,进一步优化

def f3(a: list, b: list, n: int) -> int:
    d = {0: 1}
    s = res = 0
    for i in range(n):
        s ^= a[i] ^ b[i]
        res += d.get(s, 0)
        d[s] = d.get(s, 0) + 1
    return res

为了验证算法的正确性,我写了一个程序对随机构造的数据进行验证

import random


def f1(a: list, b: list, n: int) -> int:
    s = 0
    a_s = [0] + [s := s ^ i for i in a]
    s = 0
    b_s = [0] + [s := s ^ i for i in b]

    res = 0
    for i in range(n):
        for j in range(i + 1, n + 1):
            if a_s[j] ^ a_s[i] == b_s[j] ^ b_s[i]:
                res += 1
    return res


def f2(a: list, b: list, n: int) -> int:
    # a_s[j] ^ a_s[i] == b_s[j] ^ b_s[i] 等价为 a_s[i] ^ b_s[i] == a_s[j] ^ b_s[j]
    a_s = [0 for i in range(n + 1)]
    b_s = [0 for i in range(n + 1)]
    d = {0: 1}
    res = 0
    for i in range(n):
        a_s[i + 1] = a_s[i] ^ a[i]
        b_s[i + 1] = b_s[i] ^ b[i]
        s = a_s[i + 1] ^ b_s[i + 1]
        res += d.get(s, 0)
        d[s] = d.get(s, 0) + 1
    return res


def f3(a: list, b: list, n: int) -> int:
    d = {0: 1}
    s = res = 0
    for i in range(n):
        s ^= a[i] ^ b[i]
        res += d.get(s, 0)
        d[s] = d.get(s, 0) + 1
    return res


def main():
    for i in range(100):
        n = random.randint(1, 10 ** 3)
        a = [random.randint(1, 10 ** 5) for i in range(n)]
        b = [random.randint(1, 10 ** 5) for i in range(n)]
        if f1(a, b, n) != f3(a, b, n):
            print('error')
            print(n)
            print(a)
            print(b)
            break


main()

经过验证,优化后的算法结果上和优化前的一致,说明算法是正确的

对三种方法测时,可以看到我们优化后对时间的提升是非常大的