买蛋糕

发布时间 2023-08-25 09:59:46作者: xinyimama

思路记录

给定三个序列,求所有的\(X_i+Y_i+Z_i\)中,求前K大的数字分别是多少?

原来练习堆的时候,做过类似的题序列求和,求前K小的数字,但是我不记得思路是什么了?

错误思路一

刚才有一个思路,枚举的时候,先枚举两个序列的最大值,让他们和第三个序列的每个值加起来,应该是前几个,但是这个思路是错的

例如 7 8
5 6
1 5
那么最大的加起来就是8+6+5,但是第二大的并不是8+6+1,而是7+5+5。

模仿序列合并

当时,我做序列合并的时候,为了防止相同的元素重复进队,我设定了方向信息来控制。

对了这个题,我想,相同的元素重复进队的话,我标记就可以了,所以前30%的数据,我使用了标记数组。但是当数据量巨大的时候,标记数组就会MLE

然后我重新打开序列合并的题解,发现,第一份题解写的非常好:

首先,把A和B两个序列分别从小到大排序,变成两个有序队列。这样,从A和B中各任取一个数相加得到N2个和,可以把这些和看成形成了n个有序表/队列:

A[1]+B[1] <= A[1]+B[2] <= … <= A[1]+B[N]

A[2]+B[1] <= A[2]+B[2] <= … <= A[2]+B[N]

……

A[N]+B[1] <= A[N]+B[2] <= … <= A[N]+B[N]

接下来,就相当于要将这N个有序队列进行合并排序:

首先,将这N个队列中的第一个元素放入一个堆中;

然后;每次取出堆中的最小值。若这个最小值来自于第k个队列,那么,就将第k个队列的下一个元素放入堆中。

时间复杂度:O(NlogN)。

这种方式不再使用方向来控制扩展,这样的扩展方式更加通用,基于这个思路,我找到了 这个题的做法

题解思路

上一个题解最重要的方式是固定一个数组A,让A数组的每一个元素都和B数组的第一个元素结合,最后扩展的时候只需要扩展B即可,对于这个题而言,我们首先将数组从大到小排序,需要先找到三个中最小的第一个元素,假设为A[1],那么让B[i]和C[j]自由组合,和A[1]组成\(N^2\)个元素,将这\(N^2\)个元素放入堆里,弹出一个元素的话,接着将他的第二个元素放进去。代码如下:

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1005;
int k, xx, yy, zz;
long long x[maxn], y[maxn], z[maxn];
struct node {
    long long val;
    int px, py, pz;
};
bool operator<(node a, node b) { return a.val < b.val; }
priority_queue<node> q;
bool cmp(long long a, long long b) { return a > b; }
int main() {
    scanf("%d%d%d%d", &xx, &yy, &zz, &k);
    for (int i = 1; i <= xx; i++) scanf("%lld", &x[i]);
    sort(x + 1, x + 1 + xx, cmp);
    for (int i = 1; i <= yy; i++) scanf("%lld", &y[i]);
    sort(y + 1, y + 1 + yy, cmp);
    for (int i = 1; i <= zz; i++) scanf("%lld", &z[i]);
    sort(z + 1, z + 1 + zz, cmp);
    if (y[1] >= x[1] && z[1] >= x[1]) {
        for (int i = 1; i <= yy; i++)
            for (int j = 1; j <= zz; j++) q.push(node{ x[1] + y[i] + z[j], 1, i, j });
        int dx, dy, dz;
        for (int i = 1; i <= k; i++) {
            node d = q.top();
            q.pop();
            cout << d.val << endl;
            dx = d.px;
            dy = d.py;
            dz = d.pz;
            if (dx + 1 <= xx)
                q.push(node{ x[dx + 1] + y[dy] + z[dz], dx + 1, dy, dz });
        }
        if (x[1] == y[1] && x[1] == z[1])
            return 0;
    }
    if (x[1] >= y[1] && z[1] >= y[1] && x[1] != y[1]) {
        for (int i = 1; i <= xx; i++)
            for (int j = 1; j <= zz; j++) q.push(node{ x[i] + y[1] + z[j], i, 1, j });
        int dx, dy, dz;
        for (int i = 1; i <= k; i++) {
            node d = q.top();
            q.pop();
            cout << d.val << endl;
            dx = d.px;
            dy = d.py;
            dz = d.pz;
            if (dy + 1 <= yy)
                q.push(node{ x[dx] + y[dy + 1] + z[dz], dx, dy + 1, dz });
        }
    }
    if (x[1] >= z[1] && y[1] >= z[1] && z[1] != x[1] && z[1] != y[1]) {
        for (int i = 1; i <= xx; i++)
            for (int j = 1; j <= yy; j++) q.push(node{ x[i] + y[j] + z[1], i, j, 1 });
        int dx, dy, dz;
        for (int i = 1; i <= k; i++) {
            node d = q.top();
            q.pop();
            cout << d.val << endl;
            dx = d.px;
            dy = d.py;
            dz = d.pz;
            if (dz + 1 <= zz)
                q.push(node{ x[dx] + y[dy] + z[dz + 1], dx, dy, dz + 1 });
        }
    }
    return 0;
}
/**
 3 3 3 5
 2 20 100
 1 10 100
 1 10 100
 */

后续思考

写完了这道题之后,我又看了别人的代码,刚才提到的思路里,我使用数组来避免重复的问题,但是数组太大了,会MLE,所以题解中有很多使用set来模拟堆来做的,当然,set里不止存在数值,还可以存储结构体,这样做的话,这个题瞬间简单了

点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
struct chs {
    int i, j, k;
    chs(int x = 0, int y = 0, int z = 0) {
        i = x;
        j = y;
        k = z;
    }
};
ll a[1005];
ll b[1005];
ll c[1005];
bool operator<(chs p, chs q) {
    if (a[p.i] + b[p.j] + c[p.k] != a[q.i] + b[q.j] + c[q.k]) {
        return a[p.i] + b[p.j] + c[p.k] > a[q.i] + b[q.j] + c[q.k];
    }
    if (p.i != q.i) {
        return p.i > q.i;
    } else if (p.j != q.j) {
        return p.j > q.j;
    } else {
        return p.k > q.k;
    }
}
set<chs> st;
int main() {
    int X, Y, Z, K;
    scanf("%d %d %d %d", &X, &Y, &Z, &K);
    for (int i = 0; i < X; i++) {
        scanf("%lld", &a[i]);
    }
    for (int i = 0; i < Y; i++) {
        scanf("%lld", &b[i]);
    }
    for (int i = 0; i < Z; i++) {
        scanf("%lld", &c[i]);
    }
    sort(a, a + X, greater<ll>());
    sort(b, b + Y, greater<ll>());
    sort(c, c + Z, greater<ll>());
    st.insert(chs(0, 0, 0));
    while (K--) {
        chs x = *st.begin();
        st.erase(st.begin());
        printf("%lld\n", a[x.i] + b[x.j] + c[x.k]);
        if (x.i < X - 1)
            st.insert(chs(x.i + 1, x.j, x.k));
        if (x.j < Y - 1)
            st.insert(chs(x.i, x.j + 1, x.k));
        if (x.k < Z - 1)
            st.insert(chs(x.i, x.j, x.k + 1));
    }
    return 0;
}

还有一个办法避免重复,使用map存储。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1005;
int k, xx, yy, zz;
long long x[maxn], y[maxn], z[maxn];
struct node {
    long long val;
    int px, py, pz;
};
struct ps {
    int px, py, pz;
    bool operator<(const ps nod) const {  //为什么要这样做?因为map和set都是有序容器
        if (px == nod.px && py == nod.py)
            return pz < nod.pz;
        if (px == nod.px)
            return py < nod.py;
        return px < nod.px;
    }
};
bool operator<(node a, node b) { return a.val < b.val; }
priority_queue<node> q;
map<ps, int> mp;
bool cmp(long long a, long long b) { return a > b; }
int main() {
    scanf("%d%d%d%d", &xx, &yy, &zz, &k);
    for (int i = 1; i <= xx; i++) scanf("%lld", &x[i]);
    sort(x + 1, x + 1 + xx, cmp);
    for (int i = 1; i <= yy; i++) scanf("%lld", &y[i]);
    sort(y + 1, y + 1 + yy, cmp);
    for (int i = 1; i <= zz; i++) scanf("%lld", &z[i]);
    sort(z + 1, z + 1 + zz, cmp);
    q.push(node{ x[1] + y[1] + z[1], 1, 1, 1 });
    int dx, dy, dz;
    for (int i = 1; i <= k; i++) {
        while (1) {
            node d = q.top();
            dx = d.px;
            dy = d.py;
            dz = d.pz;
            if (mp[ps{ dx, dy, dz }] == 1)
                q.pop();
            else
                break;
        }
        node d = q.top();
        q.pop();
        cout << d.val << endl;
        mp[ps{ dx, dy, dz }] = 1;
        if (dz + 1 <= zz)
            q.push(node{ x[dx] + y[dy] + z[dz + 1], dx, dy, dz + 1 });
        if (dx + 1 <= xx)
            q.push(node{ x[dx + 1] + y[dy] + z[dz], dx + 1, dy, dz });
        if (dy + 1 <= yy)
            q.push(node{ x[dx] + y[dy + 1] + z[dz], dx, dy + 1, dz });
    }
}

扩展

当然,这个题可以扩展,可以扩展到N个序列,如果是这样,该怎么做呢?
首先固定一个序列的办法就不可取了,因为前面的时间复杂度为\(N^{(n-1)}\),而且空间上也不可取,同时map和set的方法都不可取。