2104. 子数组范围和 (Medium)

发布时间 2023-06-13 16:46:20作者: zwyyy456

问题描述

2104. 子数组范围和 (Medium) 给你一个整数数组 numsnums 中,子数组的 范围 是子 数组中最大元素和最小元素的差值。

返回 nums所有 子数组范围的

子数组是数组中一个连续 非空 的元素序列。

示例 1:

输入:nums = [1,2,3]
输出:4
解释:nums 的 6 个子数组如下所示:
[1],范围 = 最大 - 最小 = 1 - 1 = 0
[2],范围 = 2 - 2 = 0
[3],范围 = 3 - 3 = 0
[1,2],范围 = 2 - 1 = 1
[2,3],范围 = 3 - 2 = 1
[1,2,3],范围 = 3 - 1 = 2
所有范围的和是 0 + 0 + 0 + 1 + 1 + 2 = 4

示例 2:

输入:nums = [1,3,3]
输出:4
解释:nums 的 6 个子数组如下所示:
[1],范围 = 最大 - 最小 = 1 - 1 = 0
[3],范围 = 3 - 3 = 0
[3],范围 = 3 - 3 = 0
[1,3],范围 = 3 - 1 = 2
[3,3],范围 = 3 - 3 = 0
[1,3,3],范围 = 3 - 1 = 2
所有范围的和是 0 + 0 + 0 + 2 + 0 + 2 = 4

示例 3:

输入:nums = [4,-2,-3,4,1]
输出:59
解释:nums 中所有子数组范围的和是 59

提示:

  • 1 <= nums.length <= 1000
  • -10⁹ <= nums[i] <= 10⁹

进阶: 你可以设计一种时间复杂度为 O(n) 的解决方案吗?

解题思路

暴力方案,即枚举左端点,子循环再枚举右端点,找到子数组的最大值和最小值,时间复杂度为 $O(n^2)$。

本题所求的结果,可以转化为求所有子数组的最大值之和减去所有子数组的最小值之和。

所有子数组的最大值之和为 $sum_{max} = \sum\limits_{i=0}^{n-1} nums[i] * cnt_i$,其中 $cnt_i$ 为以 nums[i] 为最大值的子数组的数量,到这里题目就转化成了 795.区间子数组个数 (Medium)

所有子数组的最小值之和为 $sum_{min} = \sum\limits_{i=0}^{n-1} nums[i] *cnt_i$,这里 $cnt_i$ 则是以 nums[i] 为最小值的子数组的数量。

代码

class Solution {
  public:
    long long subArrayRanges(vector<int> &nums) {
        // 结果可以表示为所有子数组的最大值的和减去最小值的和
        // 对 nums[i],假设有 nums[idx] <= nums[i] 小且 idx <= i 的最大 idx 为 l,
        // 同理最小 idx 为 r,则以 nums[idx] 为最大值的子数组共有 (r - idx + 1) * (idx - l + 1);
        // 可以用单调递减栈来解决,从栈底到栈顶单调递减
        stack<int> stk;
        int n = nums.size();
        long sum_max = 0;
        for (int i = 0; i < n; ++i) {
            while (!stk.empty() && nums[i] > nums[stk.top()]) {
                int len1 = i - stk.top();
                int idx = stk.top();
                long val = nums[idx];
                stk.pop();
                int len2 = stk.empty() ? idx + 1 : idx - stk.top();
                sum_max += val * len1 * len2;
            }
            stk.push(i);
        }
        while (!stk.empty()) {
            int idx = stk.top();
            int len1 = n - idx;
            long val = nums[idx];
            stk.pop();
            int len2 = stk.empty() ? idx + 1 : idx - stk.top();
            sum_max += val * len1 * len2;
        }
        long sum_min = 0;
        stack<int> stk_min;
        for (int i = 0; i < n; ++i) {
            while (!stk_min.empty() && nums[i] < nums[stk_min.top()]) {
                int idx = stk_min.top();
                int len1 = i - idx;
                long val = nums[idx];
                stk_min.pop();
                int len2 = stk_min.empty() ? idx + 1 : idx - stk_min.top();
                sum_min += val * len1 * len2;
            }
            stk_min.push(i);
        }
        while (!stk_min.empty()) {
            int idx = stk_min.top();
            int len1 = n - idx;
            long val = nums[idx];
            stk_min.pop();
            int len2 = stk_min.empty() ? idx + 1 : idx - stk_min.top();
            sum_min += val * len1 * len2;
        }
        return sum_max - sum_min;
    }
};