【算法】【线性表】两个排序数组的中位数

发布时间 2023-12-11 08:26:31作者: 酷酷-

1  题目

两个排序的数组AB分别含有mn个数,找到两个排序数组的中位数,要求时间复杂度应为 O(log(m + n))

中位数的定义:

  • 这里的中位数等同于数学定义里的中位数
  • 中位数是排序后数组的中间值。
  • 如果有数组中有n个数且n是奇数,则中位数为 A((n-1)/2)。
  • 如果有数组中有n个数且n是偶数,则中位数为 (A((n-1)/2) + A((n-1)/2+1))/2
  • 比如:数组A=[1,2,3]的中位数是2,数组A=[1,19]的中位数是10。

样例 1:

输入:

A = [1,2,3,4,5,6]
B = [2,3,4,5]

输出:

3.50

解释:合并后的数组为[1,2,2,3,3,4,4,5,5,6],中位数为(3 + 4) / 2。

样例 2:

输入:

A = [1,2,3]
B = [4,5]

输出:

3.00

解释:合并后的数组为[1,2,3,4,5],中位数为3。

2  解答

2.1  复杂度 O(m+n)

两个有序数组,定义两个指针分别指向两个数组,然后遍历,找到中间位置的两个数,第一遍这样的:

public class Solution {
    /**
     * @param a: An integer array
     * @param b: An integer array
     * @return: a double whose format is *.5 or *.0
     */
    public double findMedianSortedArrays(int[] a, int[] b) {
        // write your code here
        double res = 0.0;
        // 1、如果两个数组都为空 返回0
        if (a == null && b == null) {
            return formatDouble(res);
        }
        // 2、如果 a 为空那么就直接计算 b 的中位数
        if (a == null || a.length == 0) {
            return findMedianSortedArraysWithOne(b);
        }
        // 3、如果 b 为空,那么就直接计算 a 的中位数
        if (b == null || b.length == 0) {
            return findMedianSortedArraysWithOne(a);
        }
        // 4、就剩a b 都不为空的情况了
        // 
        int target = (a.length + b.length - 1) / 2;
        int start = 0;
        int aIndex = 0;
        int bIndex = 0;
        // true 取 a false 取 b
        boolean flag = false;
        while (start < target && aIndex < a.length && bIndex < b.length) {
            if (a[aIndex] <= b[bIndex]) {
                aIndex ++;
                start ++;
                flag = true;
                continue;
            }
            bIndex ++;
            start ++;
            flag = false;
        }
        System.out.println(String.format("%s-%s-%s-%s-%s", aIndex, bIndex, start, target, flag));
        if (aIndex >= a.length) {
            flag = false;
        }
        while (start < target && aIndex < a.length) {
            aIndex ++;
            start ++;
            flag = true;
        }
        while (start < target && bIndex < b.length) {
            bIndex ++;
            start ++;
            flag = false;
        }
        System.out.println(String.format("%s-%s-%s-%s-%s", aIndex, bIndex, start, target, flag));
        if (flag) {
            res += a[aIndex];
        } else {
            res += b[bIndex];
        }

        if ((a.length + b.length) % 2 != 0) {
            return formatDouble(res);
        } else {
            if (flag) {
                if (aIndex + 1 < a.length && a[aIndex+1] < b[bIndex]) {
                    res += a[aIndex + 1];
                } else {
                    res += b[bIndex];
                }
               
            } else {
                if (bIndex + 1 < b.length && b[bIndex+1] < a[aIndex]) {
                    res += b[bIndex + 1];
                } else {
                    res += a[aIndex];
                }
            }
            return formatDouble(res / 2.0);
        }
    }

    // 格式化 double
    public double formatDouble(double a) {
        return Double.valueOf(String.format("%.1f", a));
    }

    // 当只有一个数组的情况时,返回某个数组的中位数
    // 比如 a 为空 或者 b为空
    public double findMedianSortedArraysWithOne(int[] arr) {
        // 默认返回结果
        double res = 0.0;
        if (arr == null || arr.length == 0) {
            return formatDouble(res);
        }
        if (arr.length <= 1) {
            return arr[0];
        }
        int len = arr.length;
        int middle = (len-1) / 2;
        // 说明有偶数个
        if (len % 2 == 0) {
            return formatDouble((arr[middle] + arr[middle + 1]) / 2.0);
        } else {
            return formatDouble(arr[middle]);
        }
    }
}

上边这个有点啰嗦,优化后的:

public class Solution {
    /**
     * @param a: An integer array
     * @param b: An integer array
     * @return: a double whose format is *.5 or *.0
     */
    public double findMedianSortedArrays(int[] a, int[] b) {
        // write your code here
        int count = a.length + b.length;
        // 因为要多找一次 所以 + 1  
        // 长度5 (5-1)/ 2 = 2 索引为2的位置  找三次 0 1 2
        // 长度6 (6-1)/ 2 = 2 索引为2、3的位置  找四次 0 1 2 3
        int target = (count - 1) / 2 + 1;
        int start = 0;
        // 分别定位 a b 两个数组的索引位置
        int aIndex = 0, bIndex = 0;
        // num 记录中位数 moreNum 表示多找一个
        int num = 0, moreNum = 0;
        while (start <= target) {
            System.out.println(String.format("%s-%s-%s-%s", num, moreNum, aIndex, bIndex));
            num = moreNum;
            // 当两个数组加起来,只有一个元素的时候,多找一次的话会数组越界,所以这里直接退出
            if (bIndex >= b.length && aIndex >= a.length) break;
            if (bIndex >= b.length || (aIndex < a.length && a[aIndex] <= b[bIndex])) {
                moreNum = a[aIndex];
                aIndex ++;
                start ++;
                continue;
            }
            moreNum = b[bIndex];
            bIndex ++;
            start ++;
        }
        if (count % 2 != 0) {
            return num;
        } else {
            return (num + moreNum) / 2.0;
        }
    }
}

2.2  复杂度 O(log(m+n))

上边的复杂度大,是因为我们一次遍历就相当于去掉不可能是中位数的一个值,也就是一个一个排除。由于数列是有序的,其实我们完全可以一半儿一半儿的排除。假设我们要找第 k 小数,我们可以每次循环排除掉 k/2 个数。看下边一个例子。

假设我们要找第 7 小的数字:

我们比较两个数组的第 k/2 个数字,如果 k 是奇数,向下取整。也就是比较第 333 个数字,上边数组中的 444 和下边数组中的 333,如果哪个小,就表明该数组的前 k/2 个数字都不是第 k 小数字,所以可以排除。也就是 111,222,333 这三个数字不可能是第 777 小的数字,我们可以把它排除掉。将 134913491349 和 456789104567891045678910 两个数组作为新的数组进行比较。

更一般的情况 A[1] ,A[2] ,A[3],A[k/2] ... ,B[1],B[2],B[3],B[k/2] ... ,如果 A[k/2]<B[k/2] ,那么A[1],A[2],A[3],A[k/2]都不可能是第 k 小的数字。

A 数组中比 A[k/2] 小的数有 k/2-1 个,B 数组中,B[k/2] 比 A[k/2] 小,假设 B[k/2] 前边的数字都比 A[k/2] 小,也只有 k/2-1 个,所以比 A[k/2] 小的数字最多有 k/1-1+k/2-1=k-2个,所以 A[k/2] 最多是第 k-1 小的数。而比 A[k/2] 小的数更不可能是第 k 小的数了,所以可以把它们排除。

橙色的部分表示已经去掉的数字。

由于我们已经排除掉了 3 个数字,就是这 3 个数字一定在最前边,所以在两个新数组中,我们只需要找第 7 - 3 = 4 小的数字就可以了,也就是 k = 4。此时两个数组,比较第 2 个数字,3 < 5,所以我们可以把小的那个数组中的 1 ,3 排除掉了。

我们又排除掉 2 个数字,所以现在找第 4 - 2 = 2 小的数字就可以了。此时比较两个数组中的第 k / 2 = 1 个数,4 == 4,怎么办呢?由于两个数相等,所以我们无论去掉哪个数组中的都行,因为去掉 1 个总会保留 1 个的,所以没有影响。为了统一,我们就假设 4 > 4 吧,所以此时将下边的 4 去掉。

由于又去掉 1 个数字,此时我们要找第 1 小的数字,所以只需判断两个数组中第一个数字哪个小就可以了,也就是 4。

所以第 7 小的数字是 4。

我们每次都是取 k/2 的数进行比较,有时候可能会遇到数组长度小于 k/2的时候。

此时 k / 2 等于 3,而上边的数组长度是 2,我们此时将箭头指向它的末尾就可以了。这样的话,由于 2 < 3,所以就会导致上边的数组 1,2 都被排除。造成下边的情况。

由于 2 个元素被排除,所以此时 k = 5,又由于上边的数组已经空了,我们只需要返回下边的数组的第 5 个数字就可以了。

从上边可以看到,无论是找第奇数个还是第偶数个数字,对我们的算法并没有影响,而且在算法进行中,k 的值都有可能从奇数变为偶数,最终都会变为 1 或者由于一个数组空了,直接返回结果。

所以我们采用递归的思路,为了防止数组长度小于 k/2,所以每次比较 min(k/2,len(数组) 对应的数字,把小的那个对应的数组的数字排除,将两个新数组进入递归,并且 k 要减去排除的数字的个数。递归出口就是当 k=1 或者其中一个数字长度是 0 了。

嘿嘿,上边的例子,是看的一些题解然后了解二分查找的作用思路后,我自己写的哈:

public class Solution {
    /**
     * @param a: An integer array
     * @param b: An integer array
     * @return: a double whose format is *.5 or *.0
     */
    public double findMedianSortedArrays(int[] a, int[] b) {
        // 长度之和
        int count = a.length + b.length;
        if (count % 2 != 0) {
            // 奇数情况,返回中间位置
            return getKth(a, b, 0, a.length-1, 0, b.length-1, (count-1) / 2 + 1);
        } else {
            // 偶数情况,中间两个数求和/2.0
            int num1 = getKth(a, b, 0, a.length-1, 0, b.length-1, (count-1) / 2 + 1);
            int num2 = getKth(a, b, 0, a.length-1, 0, b.length-1, (count-1) / 2 + 2);
            System.out.println(num1);
            System.out.println(num2);
            return (num1 + num2) / 2.0;
        }
    }

    /**
     * 找到数组中第 K 个大的数
     * K从1开始
     */
    public int getKth(int[] a, int[] b, int aStart, int aEnd, int bStart, int bEnd, int k) {
        // 当 k <= 1 也就是也就是取一个了
        if (k <= 1) {
            return getValue(a, b, aStart, aEnd, bStart, bEnd);
        }
        // 如果 a 数组遍历完了,直接返回 b 数组的
        if (aStart > aEnd) {
            return b[bStart + k-1];
        }
        // 一样,b 数组遍历完了,直接返回 a 数组的
        if (bStart > bEnd) {
            return a[aStart + k-1];
        }
        // 当 k 大于1 的情况下 折半
        int middleIndex = k / 2 - 1;
        int aMax = a[aEnd];
        int bMax = b[bEnd];
        // a 在范围内,就取范围内的
        if (aStart + middleIndex < aEnd) {
            aMax = a[aStart + middleIndex];
        }
        // b 在范围内,就取范围内的
        if (bStart + middleIndex < bEnd) {
            bMax = b[bStart + middleIndex];
        }
        // 如果aMax较小,那么a这边的都可以排除掉
        if (aMax < bMax) {
            // a 的长度判断
            if (aStart + middleIndex < aEnd) {
                return getKth(a, b, aStart + middleIndex + 1, aEnd, bStart, bEnd, k - (middleIndex+1));
            } else {
                return getKth(a, b, aEnd, -1, bStart, bEnd, k - (aEnd - aStart + 1));
            }
        } else {
            // b 的长度判断
            if (bStart + middleIndex < bEnd) {
                return getKth(a, b, aStart, aEnd, bStart + middleIndex + 1, bEnd, k - (middleIndex+1));
            } else {
                return getKth(a, b, aStart, aEnd, bEnd, -1, k - (bEnd - bStart + 1));
            }
        }
    }

    public int getValue(int[] a, int[] b, int aStart, int aEnd, int bStart, int bEnd) {
        // 如果 a 数组遍历完了,直接返回 b 数组的
        if (aStart > aEnd) {
            return b[bStart];
        }
        // 一样,b 数组遍历完了,直接返回 a 数组的
        if (bStart > bEnd) {
            return a[aStart];
        }
        // 返回 a、b中较小的那个
        return a[aStart] > b[bStart] ? b[bStart] : a[aStart];
    }
}

3  小结

二分查找的作用点就在折半的去剔除一些元素,比如要第 5 个 那么可以找出每个数组的 第2 个元素,进行比较,小的一方就可以剔除2个,然后继续在剩下的找第3个,依次递归,直到找1个的时候,退出递归哈,加油。