1775.equal sum arrays with minimum number of operations

发布时间 2023-07-26 16:04:21作者: zwyyy456

Description

1775.equal-sum-arrays-with-minimum-number-of-operations

Solution

hash table + greedy algorithm

The general idea of this problem is hash + greedy algorithm.

Let diff = sum1 - sum2 and assume diff > 0. To close the gap in as few operations as possible, each operation should apply the largest change still available in either array: either lower the largest remaining number in nums1 to 1, or raise the smallest remaining number in nums2 to 6. We subtract that change from diff, decrement the count of the number that was used, and repeat until diff <= 0.

It's similar when diff < 0.

optimization

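Assume sum1 < sum2; otherwise swap(nums1, nums2) and swap(sum1, sum2). Let diff = sum2 - sum1. A single operation can reduce diff by at most 5: changing a number x in nums1 to 6 reduces it by 6 - x, and changing a number y in nums2 to 1 reduces it by y - 1. How many numbers can provide a reduction of 1, 2, 3, 4 or 5 depends on nums1 and nums2, so we use a hash table to record the count for each reduction.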

Finally, we traverse i = 5, 4, 3, 2, 1, greedily using the largest reductions first until diff is covered.

Code

hash table + greedy algorithm

class Solution {
public:
    // Returns the smallest index i >= 1 whose bucket v[i] is non-zero,
    // or 6 when every bucket from index 1 upward is zero.
    int find_min(vector<int> &v) {
        const int limit = static_cast<int>(v.size());
        int idx = 1;
        while (idx < limit && v[idx] == 0) {
            ++idx;
        }
        return idx < limit ? idx : 6;
    }

    // Returns the largest index i >= 1 whose bucket v[i] is non-zero,
    // or 1 when every bucket from index 1 upward is zero.
    int find_max(vector<int> &v) {
        int idx = static_cast<int>(v.size()) - 1;
        while (idx >= 1 && v[idx] == 0) {
            --idx;
        }
        return idx >= 1 ? idx : 1;
    }

    // Greedy + counting solution.
    //
    // Returns the minimum number of single-element changes needed to make the
    // two sums equal, or -1 when one array is more than six times longer than
    // the other.
    //
    // Each value is used as an index into a 7-slot table, so values must lie
    // in [0, 6] (presumably 1..6 as in the problem statement -- confirm
    // against the caller).
    int minOperations(vector<int>& nums1, vector<int>& nums2) {
        // Even all-1s against all-6s cannot balance when the lengths differ
        // by more than a factor of six.
        if (nums1.size() > 6 * nums2.size() || 6 * nums1.size() < nums2.size()) {
            return -1;
        }

        // Histogram of values and running total for each array.
        vector<int> hist1(7, 0);
        vector<int> hist2(7, 0);
        int total1 = 0;
        int total2 = 0;
        for (int value : nums1) {
            ++hist1[value];
            total1 += value;
        }
        for (int value : nums2) {
            ++hist2[value];
            total2 += value;
        }

        const int gap = total1 - total2;
        if (gap == 0) {
            return 0;
        }

        // Orient the problem: `high` is the histogram of the array with the
        // larger sum (its numbers get lowered to 1), `low` the histogram of
        // the array with the smaller sum (its numbers get raised to 6).
        vector<int> &high = gap > 0 ? hist1 : hist2;
        vector<int> &low = gap > 0 ? hist2 : hist1;
        int remaining = gap > 0 ? gap : -gap;

        int ops = 0;
        while (remaining > 0) {
            ++ops;
            const int largest = find_max(high);   // best candidate to lower to 1
            const int smallest = find_min(low);   // best candidate to raise to 6
            const int drop = largest - 1;
            const int rise = 6 - smallest;
            // Spend the operation on whichever change shrinks the gap more;
            // ties go to lowering a number in the larger-sum array.
            if (drop >= rise) {
                remaining -= drop;
                --high[largest];
            } else {
                remaining -= rise;
                --low[smallest];
            }
        }
        return ops;
    }
};

better code

class Solution {
  public:
    // Returns the minimum number of single-element changes (each element may
    // be set to any value in 1..6) needed to make the sums of nums1 and nums2
    // equal, or -1 when that is impossible.
    //
    // The input vectors are only read; they are never reordered or swapped.
    //
    // Precondition: every element lies in 1..6, because each element is used
    // to index a 6-slot table of possible gains.
    int minOperations(vector<int> &nums1, vector<int> &nums2) {
        // Rule out the impossible case first: even all-1s against all-6s
        // cannot balance when one array is more than six times longer.
        if (nums1.size() > 6 * nums2.size() || nums2.size() > 6 * nums1.size()) {
            return -1;
        }

        // Sums of the two arrays.
        int sum1 = 0, sum2 = 0;
        for (int num : nums1) {
            sum1 += num;
        }
        for (int num : nums2) {
            sum2 += num;
        }
        if (sum1 == sum2) {
            return 0;
        }

        // Pick the smaller-sum and larger-sum arrays by reference instead of
        // swapping the caller's vectors.
        const bool firstIsSmaller = sum1 < sum2;
        const vector<int> &low = firstIsSmaller ? nums1 : nums2;
        const vector<int> &high = firstIsSmaller ? nums2 : nums1;
        int diff = firstIsSmaller ? sum2 - sum1 : sum1 - sum2;

        // gain[g] = number of elements whose single change can shrink diff by
        // up to g: raising x in `low` to 6 gives 6 - x, lowering y in `high`
        // to 1 gives y - 1.
        vector<int> gain(6, 0);
        for (int num : low) {
            ++gain[6 - num];
        }
        for (int num : high) {
            ++gain[num - 1];
        }

        // Spend the largest gains first.
        int cnt = 0;
        for (int i = 5; i >= 1; --i) {
            if (diff <= i * gain[i]) {
                // ceil(diff / i) elements of gain i finish the job.
                return cnt + (diff + i - 1) / i;
            }
            cnt += gain[i];
            diff -= i * gain[i];
        }
        // Not reachable for inputs in 1..6: the size check above guarantees
        // the total available gain covers diff.
        return -1;
    }
};