洛谷P1552 [APIO2012] 派遣 题解 左偏树

发布时间 2023-04-06 18:19:36作者: quanjun

题目链接:https://www.luogu.com.cn/problem/P1552

题目大意:

每次求子树中薪水和不超过 \(M\) 的最大节点数。

解题思路:

使用左偏树维护一个大根堆。

首先定义一个 Node 的结构体:

struct Node {
    int s[2], c, sz, dis;
    long long sum;
    Node() {};
    Node(int _c) { s[0] = s[1] = 0; dis = sz = 1; sum = c = _c; }
} tr[maxn];

其中:

  • s[0] 左儿子编号(不存在则为 \(0\)
  • s[1] 右儿子编号(不存在则为 \(0\)
  • c:即题目描述中的薪水 \(C_i\),表示该节点对应的薪水
  • sz:以当前节点为根的子树中的节点个数
  • dis:左偏树中当前节点到最近一个外节点的距离(方便起见,实际加了一个 \(1\)
  • sum:以当前节点为根的子树中所有节点对应的薪水 c 之和

我在实现的时候维护了一个并查集,并查集的作用是方便查找当前节点所在的堆的根节点(通过如下所述的 find 函数):

void init() {
    for (int i = 1; i <= n; i++) f[i] = i;
}

int find(int x) {
    return x == f[x] ? x : f[x] = find(f[x]);
}

然后就是涉及两个左偏树的合并操作了,在操作前我封装了一个 push_up 函数,用于更新当前节点的信息:

void push_up(int x) {
    int l = tr[x].s[0], r = tr[x].s[1];
    tr[x].dis = tr[r].dis + 1;
    tr[x].sz = tr[l].sz + tr[r].sz + 1;
    tr[x].sum = tr[l].sum + tr[r].sum + tr[x].c;
}

然后是 merge 操作,合并根节点为 x 和 y:

int merge(int x, int y) {
    if (!x || !y) return x + y;
    if (tr[x].c < tr[y].c) swap(x, y);
    int &l = tr[x].s[0], &r = tr[x].s[1];
    f[r = merge(r, y)] = x;
    if (tr[l].dis < tr[r].dis) swap(l, r);
    push_up(x);
    return x;
}

然后是 pop 操作,弹出根节点为 x 的那棵树的根节点(即堆顶元素):

void pop(int x) {
    int l = tr[x].s[0], r = tr[x].s[1];
    f[l] = l;
    f[r] = r;
    f[x] = merge(l, r);
}

这里的操作我的想法是:因为虽然 x 被 pop 掉了,但是之后可能还会需要找 x 所在集合的根节点,所以将 f[x] 更新为它的左右子树合并后的新的根节点。

然后处理的时候从 n 到 1 处理每一个节点,对于当前节点,只要其 sum > M,就 pop 堆顶元素(因为堆顶元素的 c 是最大的,所有优先 pop 掉 c 大的)

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;

struct Node {
    int s[2], c, sz, dis;
    long long sum;
    Node() {};
    Node(int _c) { s[0] = s[1] = 0; dis = sz = 1; sum = c = _c; }
} tr[maxn];

int n, f[maxn], b[maxn], L[maxn];
long long m, ans;

void init() {
    for (int i = 1; i <= n; i++) f[i] = i;
}

int find(int x) {
    return x == f[x] ? x : f[x] = find(f[x]);
}

void push_up(int x) {
    int l = tr[x].s[0], r = tr[x].s[1];
    tr[x].dis = tr[r].dis + 1;
    tr[x].sz = tr[l].sz + tr[r].sz + 1;
    tr[x].sum = tr[l].sum + tr[r].sum + tr[x].c;
}

int merge(int x, int y) {
    if (!x || !y) return x + y;
    if (tr[x].c < tr[y].c) swap(x, y);
    int &l = tr[x].s[0], &r = tr[x].s[1];
    f[r = merge(r, y)] = x;
    if (tr[l].dis < tr[r].dis) swap(l, r);
    push_up(x);
    return x;
}

void pop(int x) {
    int l = tr[x].s[0], r = tr[x].s[1];
    f[l] = l;
    f[r] = r;
    f[x] = merge(l, r);
}

int main() {
    scanf("%d%lld", &n, &m);
    init();
    for (int i = 1; i <= n; i++) {
        int c;
        scanf("%d%d%d", b+i, &c, L+i);
        tr[i] = Node(c);
    }
    for (int i = n; i >= 1; i--) {
        int x = find(i);
        while (tr[x].sum > m) pop(x), x = find(x);
        long long tmp = (long long) tr[x].sz * L[i];
        ans = max(ans, tmp);
        if (b[i] && find(b[i]) != x) merge(find(b[i]), x);
    }
    printf("%lld\n", ans);
    return 0;
}