1439. 有序矩阵中的第 k 个最小数组和

发布时间 2023-05-31 11:18:05 作者: Tianyiya

给你一个 m * n 的矩阵 mat，以及一个整数 k，矩阵中的每一行都以非递减的顺序排列。

你可以从每一行中选出 1 个元素形成一个数组。返回所有可能数组中的第 k 个最小数组和。

来源：力扣（LeetCode）
链接：https://leetcode.cn/problems/find-the-kth-smallest-sum-of-a-matrix-with-sorted-rows
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Priority-queue approach: fold the rows of the matrix together, keeping only the k smallest
 * partial sums after each step.
 */
class Solution {

    /**
     * Produces the smallest pairwise sums {@code x + y} (x from one array, y from the other),
     * in non-decreasing order, stopping after {@code k} of them.
     *
     * <p>Both inputs are expected to be sorted in non-decreasing order. One cursor is created per
     * (capped) element of the shorter array, pointing at index 0 of the longer array; every pop
     * emits the current minimum and moves that cursor one step along the longer array.
     *
     * @param a first sorted array
     * @param b second sorted array
     * @param k maximum number of sums to produce
     * @return at most {@code k} sums, ascending
     */
    private int[] merge(int[] a, int[] b, int k) {
        // Walk along the longer array; seed one cursor per element of the shorter one.
        int[] longer = a.length < b.length ? b : a;
        int[] shorter = a.length < b.length ? a : b;

        PriorityQueue<Node> heap = new PriorityQueue<>(Comparator.comparingInt((Node node) -> node.sum));
        int seeds = Math.min(shorter.length, k);
        for (int j = 0; j < seeds; j++) {
            heap.offer(new Node(0, j, longer[0] + shorter[j]));
        }

        List<Integer> sums = new ArrayList<>();
        while (!heap.isEmpty() && sums.size() < k) {
            Node top = heap.poll();
            sums.add(top.sum);
            int next = top.aIndex + 1;
            if (next < longer.length) {
                // Reuse the polled cursor: advance it and re-insert with its new sum.
                top.aIndex = next;
                top.sum = longer[next] + shorter[top.bIndex];
                heap.offer(top);
            }
        }
        return sums.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Returns the k-th smallest sum obtainable by picking one element from every row of
     * {@code mat}, whose rows are sorted in non-decreasing order.
     *
     * @param mat matrix with sorted rows
     * @param k 1-based rank of the wanted sum
     * @return the k-th smallest row-selection sum
     */
    public int kthSmallest(int[][] mat, int k) {
        int[] best = mat[0];
        int row = 1;
        while (row < mat.length) {
            best = merge(best, mat[row++], k);
        }
        return best[k - 1];
    }
}

/**
 * Mutable heap entry for merging two sorted arrays: an index into each array plus the cached
 * sum of the two elements it points at. Fields are left non-final because the merge advances
 * {@code aIndex} and refreshes {@code sum} in place.
 */
class Node {
    // Cached value of the pair sum, used as the heap ordering key.
    int sum;
    // Position in the array the cursor walks along.
    int aIndex;
    // Position in the array the cursor is pinned to.
    int bIndex;

    public Node(int first, int second, int total) {
        aIndex = first;
        bIndex = second;
        sum = total;
    }
}

解法二：二分答案 + 双指针计数

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Binary-search approach: fold the rows of the matrix together, keeping only the k smallest
 * partial sums after each step. Each fold binary-searches the k-th smallest pair sum and then
 * materializes the k smallest sums from it.
 */
class Solution {

    /**
     * Counts the pairs (i, j) with {@code a[i] + b[j] <= limit}.
     *
     * <p>Both arrays must be sorted in non-decreasing order. As {@code a[i]} grows, the largest
     * admissible {@code j} can only shrink, so a single pointer moving backwards through
     * {@code b} yields the count in O(|a| + |b|).
     */
    private int countAtMost(int[] a, int[] b, int limit) {
        int count = 0;
        int j = b.length - 1;
        for (int x : a) {
            while (j >= 0 && x + b[j] > limit) {
                j--;
            }
            count += j + 1;
        }
        return count;
    }

    /**
     * Returns the {@code min(k, |a|·|b|)} smallest pairwise sums {@code a[i] + b[j]}, sorted
     * ascending.
     *
     * <p>Both inputs must be non-empty and sorted in non-decreasing order.
     *
     * @param a first sorted array
     * @param b second sorted array
     * @param k number of smallest sums wanted
     * @return the smallest sums in non-decreasing order
     */
    private int[] merge(int[] a, int[] b, int k) {
        if (a.length < b.length) {
            return merge(b, a, k);
        }
        int want = Math.min(k, a.length * b.length);

        // Smallest s with countAtMost(s) >= want is exactly the want-th smallest pair sum.
        int left = a[0] + b[0];
        int right = a[a.length - 1] + b[b.length - 1];
        int kth = right;
        while (left <= right) {
            // Written this way to avoid int overflow of (left + right) for large sums.
            int mid = left + ((right - left) >> 1);
            if (countAtMost(a, b, mid) >= want) {
                kth = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }

        // Fewer than `want` pairs lie strictly below kth (otherwise a smaller s would have
        // qualified), so collect all of them and pad the remainder with kth itself.
        List<Integer> sums = new ArrayList<>();
        for (int x : a) {
            if (x + b[0] >= kth) {
                break; // a is sorted: no later element of a can produce a smaller sum
            }
            for (int y : b) {
                if (x + y >= kth) {
                    break; // b is sorted: the rest of this row is >= kth as well
                }
                sums.add(x + y);
            }
        }
        while (sums.size() < want) {
            sums.add(kth);
        }

        Collections.sort(sums);
        int[] ans = new int[sums.size()];
        for (int i = 0; i < ans.length; i++) {
            ans[i] = sums.get(i);
        }
        return ans;
    }

    /**
     * Returns the k-th smallest sum obtainable by picking one element from every row of
     * {@code mat}, whose rows are sorted in non-decreasing order.
     *
     * @param mat matrix with sorted rows
     * @param k 1-based rank of the wanted sum
     * @return the k-th smallest row-selection sum
     */
    public int kthSmallest(int[][] mat, int k) {
        int[] arr = mat[0];
        for (int i = 1; i < mat.length; i++) {
            arr = merge(arr, mat[i], k);
        }
        return arr[k - 1];
    }
}