题目链接:https://www.luogu.com.cn/problem/P1552
题目大意:
每次求子树中薪水和不超过 \(M\) 的最大节点数。
解题思路:
使用左偏树维护一个大根堆。
首先定义一个 Node 的结构体:
struct Node {
int s[2], c, sz, dis;
long long sum;
Node() {};
Node(int _c) { s[0] = s[1] = 0; dis = sz = 1; sum = c = _c; }
} tr[maxn];
其中:
s[0]
左儿子编号(不存在则为 \(0\))s[1]
右儿子编号(不存在则为 \(0\))c
:即题目描述中的薪水 \(C_i\),表示该节点对应的薪水sz
:以当前节点为根的子树中的节点个数dis
:左偏树中当前节点到最近一个外节点的距离(方便起见,实际加了一个 \(1\))sum
:以当前节点为根的子树中所有节点对应的薪水c
之和
我在实现的时候维护了一个并查集,并查集的作用是方便查找当前节点所在的堆的根节点(通过如下所述的 find
函数):
void init() {
for (int i = 1; i <= n; i++) f[i] = i;
}
int find(int x) {
return x == f[x] ? x : f[x] = find(f[x]);
}
然后就是涉及两个左偏树的合并操作了,在操作前我封装了一个 push_up
函数,用于更新当前节点的信息:
void push_up(int x) {
int l = tr[x].s[0], r = tr[x].s[1];
tr[x].dis = tr[r].dis + 1;
tr[x].sz = tr[l].sz + tr[r].sz + 1;
tr[x].sum = tr[l].sum + tr[r].sum + tr[x].c;
}
然后是 merge 操作,合并根节点为 x 和 y:
int merge(int x, int y) {
if (!x || !y) return x + y;
if (tr[x].c < tr[y].c) swap(x, y);
int &l = tr[x].s[0], &r = tr[x].s[1];
f[r = merge(r, y)] = x;
if (tr[l].dis < tr[r].dis) swap(l, r);
push_up(x);
return x;
}
然后是 pop 操作,弹出根节点为 x 的那棵树的根节点(即堆顶元素):
void pop(int x) {
int l = tr[x].s[0], r = tr[x].s[1];
f[l] = l;
f[r] = r;
f[x] = merge(l, r);
}
这里的操作我的想法是:因为虽然 x 被 pop 掉了,但是之后可能还会需要找 x 所在集合的根节点,所以将 f[x] 更新为它的左右子树合并后的新的根节点。
然后处理的时候从 n 到 1 处理每一个节点,对于当前节点,只要其 sum > M,就 pop 堆顶元素(因为堆顶元素的 c 是最大的,所有优先 pop 掉 c 大的)
示例程序:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
struct Node {
int s[2], c, sz, dis;
long long sum;
Node() {};
Node(int _c) { s[0] = s[1] = 0; dis = sz = 1; sum = c = _c; }
} tr[maxn];
int n, f[maxn], b[maxn], L[maxn];
long long m, ans;
void init() {
for (int i = 1; i <= n; i++) f[i] = i;
}
int find(int x) {
return x == f[x] ? x : f[x] = find(f[x]);
}
void push_up(int x) {
int l = tr[x].s[0], r = tr[x].s[1];
tr[x].dis = tr[r].dis + 1;
tr[x].sz = tr[l].sz + tr[r].sz + 1;
tr[x].sum = tr[l].sum + tr[r].sum + tr[x].c;
}
int merge(int x, int y) {
if (!x || !y) return x + y;
if (tr[x].c < tr[y].c) swap(x, y);
int &l = tr[x].s[0], &r = tr[x].s[1];
f[r = merge(r, y)] = x;
if (tr[l].dis < tr[r].dis) swap(l, r);
push_up(x);
return x;
}
void pop(int x) {
int l = tr[x].s[0], r = tr[x].s[1];
f[l] = l;
f[r] = r;
f[x] = merge(l, r);
}
int main() {
scanf("%d%lld", &n, &m);
init();
for (int i = 1; i <= n; i++) {
int c;
scanf("%d%d%d", b+i, &c, L+i);
tr[i] = Node(c);
}
for (int i = n; i >= 1; i--) {
int x = find(i);
while (tr[x].sum > m) pop(x), x = find(x);
long long tmp = (long long) tr[x].sz * L[i];
ans = max(ans, tmp);
if (b[i] && find(b[i]) != x) merge(find(b[i]), x);
}
printf("%lld\n", ans);
return 0;
}