2019 CCPC Harbin E. Exchanging Gifts
题意
已知序列 \(g\),将序列 \(g\) 以某种方式乱序后的结果为序列 \(h\),定义序列 \(g\) 的开心值为:在相同下标情况下,序列 \(g\) 对应下标的值和序列 \(h\) 对应下标的值不相同的下标的个数的最大值。
输入给出 \(n(1\leq n \leq 10^6)\) 个序列,记为 \(s_1,s_2,\dots,s_n\),对于序列 \(s_i\) 可能以两种方式给出:
-
1 k q[1...k]
为 \(s_i = [q_1,q_2,\dots,q_k]\),满足 \(1\leq k\leq 10^6,1\leq q_j\leq 10^9\),满足所有 \(k\) 之和小于 \(10^6\) -
2 x y
为 \(s_i = s_x + s_y\),满足 \(1\leq x,y \leq i-1\)
需要求 \(s_n\) 的开心值。
思路
容易发现,对于长度为 \(n\) 的某个序列:
若存在某个元素 \(a\) 其长度 \(m > \lfloor\frac{n}{2}\rfloor\),则开心值为 \(2(n-m)\)
若对于任意元素其长度 \(m \leq \lfloor\frac{n}{2}\rfloor\),则开心值为 \(n\)
所以我们对于一个长度为 \(n\) 的序列,我们只需关心出现超过 \(\lfloor\frac{n}{2}\rfloor\) 的数,及其个数。
假设我们知道序列 \(s_n\) 中哪个数可能超过半数,这个数的个数是很容易在 \(O(n + \sum k)\) 的时间内求出的。
问题转化为如何快速知道可能超过半数的数。
考虑以下操作:
假设有一个 \(n\) 个数的序列,每次我们都在这 \(n\) 个数中 随意 找出 \(2\) 个不相同的数,将其移出序列,反复操作,直到无法操作为止。
显然对于该操作,要么没留下数,要么留下的数是都是同一个数,并且如果存在某个数的个数超过半数,则留下来的数一定都是这个数。
模拟该操作的代价是 \(O(n)\) 的。
回到本题
对于以方式1
给出的序列,经过以上的操作,我们可以得到这个序列的“可能”超过半数的数及其剩余的数量(如果有,则一定是超过半数的数,如果没有,是哪个数都无所谓,我们不关心它的数量),实现这个操作的复杂度是 \(O(k)\) 的。
对于以方式2
给出的序列,相当于将两个序列合并,由于上述操作并没有操作顺序的要求,所以基于两个序列已经计算好的“可能”超过半数的数及其剩余的数量,继续上述操作,得到合并后“可能”超过半数的数及其剩余的数量,实现这个操作的复杂度是 \(O(1)\) 的。
最终我们可以在复杂度 \(O(n + \sum k)\) 下,得到对 \(s_n\) 进行上述操作的“可能”超过半数的数及其剩余的数量。
接下来我们只要关心该数的个数,上面分析过,这件事可以在 \(O(n + \sum k)\) 实现。值得注意的是,若该数不超过半数,那么不可能会有其他数超过半数(原因在上文中是显然的)。然后可以很容易地算出结果。
代码
#include <algorithm>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;
const int maxn = 1e6 + 10;
struct Node {
int num;
int64_t cnt;
};
void solve() {
int n;
cin >> n;
vector<Node> nodes(n + 1, {0, 0});
vector<vector<int>> que(n + 1);
for (int i = 1; i <= n; i++) {
int op;
cin >> op;
que[i].push_back(op);
if (op == 1) {
int k;
cin >> k;
que[i].push_back(k);
for (int j = 0; j < k; j++) {
int a;
cin >> a;
que[i].push_back(a);
if (nodes[i].cnt == 0 || nodes[i].num == a) {
nodes[i].num = a;
nodes[i].cnt++;
} else {
nodes[i].cnt--;
}
}
} else {
int x, y;
cin >> x >> y;
que[i].push_back(x);
que[i].push_back(y);
if (nodes[x].num == nodes[y].num) {
nodes[i].num = nodes[x].num;
nodes[i].cnt = nodes[x].cnt + nodes[y].cnt;
} else {
if (nodes[x].cnt < nodes[y].cnt) {
nodes[i].num = nodes[y].num;
nodes[i].cnt = nodes[y].cnt - nodes[x].cnt;
} else {
nodes[i].num = nodes[x].num;
nodes[i].cnt = nodes[x].cnt - nodes[y].cnt;
}
}
}
}
int num = nodes[n].num;
vector<int64_t> sum(n + 1), mx(n + 1);
for (int i = 1; i <= n; i++) {
auto p = que[i].begin();
if (*(p++) == 1) {
int k;
k = *(p++);
for (int j = 0; j < k; j++) {
if (*(p++) == num) {
mx[i]++;
}
}
sum[i] = k;
} else {
int x, y;
x = *(p++);
y = *(p++);
sum[i] = sum[x] + sum[y];
mx[i] = mx[x] + mx[y];
}
}
if (mx[n] <= sum[n] / 2) {
cout << sum[n] << '\n';
} else {
cout << (sum[n] - mx[n]) * 2 << '\n';
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
int T;
cin >> T;
while (T--) solve();
return 0;
}