CF1398E Two Types of Spells 题解 set

发布时间 2023-05-30 15:52:40作者: quanjun

题目链接:https://codeforces.com/problemset/problem/1398/E

题目大意

你有一个集合,初始为空。

有两种类型的元素,一种是普通元素,一种是强化元素,两种元素都有一个数值。

\(n\) 次操作,每次操作会往集合中加入一个元素或者删除一个元素。

每次操作后,你都需要确定集合中元素的一个排列,使得排列的价值最大。

排列价值的计算规则是:

每个元素的价值就是它对应的数字,但是强化元素能够使它在排列中后一个位置的元素的价值翻倍(即:对于排列中某一个元素来说,如果它前一个位置的元素是一个强化元素,则这个元素的价值是它本身的数值 \(\times 2\))。

举个例子,假设现在有三个元素:

  • \(1\) 个元素是一个数值为 \(5\) 的普通元素;
  • \(2\) 个元素是一个数值为 \(1\) 的强化元素;
  • \(3\) 个元素是一个数值为 \(8\) 的强化元素。

则:

  • 如果按照第 \(1\) 个元素,第 \(2\) 个元素,第 \(3\) 个元素排列,对应的价值为 \(5 + 1 + 2 \cdot 8 = 22\)
  • 如果按照第 \(1\) 个元素,第 \(3\) 个元素,第 \(2\) 个元素排列,对应的价值为 \(5 + 8 + 2 \cdot 1 = 15\)
  • 如果按照第 \(2\) 个元素,第 \(1\) 个元素,第 \(3\) 个元素排列,对应的价值为 \(1 + 2 \cdot 5 + 8 = 19\)
  • 如果按照第 \(2\) 个元素,第 \(3\) 个元素,第 \(1\) 个元素排列,对应的价值为 \(1 + 2 \cdot 8 + 2 \cdot 5 = 27\)
  • 如果按照第 \(3\) 个元素,第 \(1\) 个元素,第 \(2\) 个元素排列,对应的价值为 \(8 + 2 \cdot 5 + 1 = 19\)
  • 如果按照第 \(3\) 个元素,第 \(2\) 个元素,第 \(1\) 个元素排列,对应的价值为 \(8 + 2 \cdot 1 + 2 \cdot 5 = 20\)

解题思路

假设任意次操作之后,集合中有 \(m\) 个普通元素,和 \(k\) 个强化元素。则:

  • 若数值最大的 \(k\) 个元素 不全是 强化元素,则答案为:所有元素的数值和 + 数值最大的 \(k\) 个元素的数值和
  • 若数值最大的 \(k\) 个元素 全是 强化元素,则答案为:所有元素的数值和 + 数值最大的 \(k-1\) 个元素(当然它们都是强化元素)的数值和 + 数值最大的 \(1\) 个非强化元素的数值

示例程序

实现时,用:

  • \(st[0]\) 保存普通元素(\(m\) 对应普通元素个数);
  • \(st[1]\) 保存强化元素(\(k\) 对应强化元素个数);
  • \(st2[1]\) 保存数值最大的 \(k\) 个元素;
  • \(st2[0]\) 保存除了数值最大的 \(k\) 个元素以外其余的元素;
  • \(sum\) 对应 \(st2[1]\) 中数值最大的 \(k\) 个元素之和;
  • \(sum2\) 对应目前所有元素之和

(命名比较随意,主要是因为写的时候发现少了又加了一个;包括 \(st2[0]\)\(st2[1]\) 一开始我是开了两个堆,结果发现堆不支持随机删除囧)

#include <bits/stdc++.h>
using namespace std;

int n, p, d, m, k;
set<int> st[2], st2[2];
long long sum, sum2;    // sum 记录前k大值之和,sum2 记录所有值之和

int main() {
    scanf("%d", &n);
    while (n--) {
        scanf("%d%d", &p, &d);
        if (d > 0) {
            st[p].insert(d);
            p ? k++ : m++;
            st2[0].insert(d);
            sum2 += d;
        }
        else {  // d < 0
            d = -d;
            p ? k-- : m--;
            st[p].erase(d);
            if (st2[0].count(d)) st2[0].erase(d);
            else {
                st2[1].erase(d);
                sum -= d;
            }
            sum2 -= d;
        }
        while (st2[0].size() && st2[1].size()) {
            int x = *st2[0].rbegin();
            int y = *st2[1].begin();
            if (x > y) {
                st2[0].erase(x);
                st2[1].erase(y);
                st2[0].insert(y);
                st2[1].insert(x);
                sum += x - y;
            }
            else
                break;
        }
        while (st2[1].size() < k) {
            assert(!st2[0].empty());
            int x = *st2[0].rbegin();
            st2[0].erase(x);
            st2[1].insert(x);
            sum += x;
        }
        while (st2[1].size() > k) {
            assert(!st2[1].empty());
            int x = *st2[1].begin();
            st2[1].erase(x);
            sum -= x;
            st2[0].insert(x);
        }
        long long ans;
        if (st[0].empty())
            ans = sum2 + sum - (*st[1].begin());
        else if (st[1].empty())
            ans = sum2;
        else if (*st[0].rbegin() < *st[1].begin())
            ans = sum2 + sum - (*st[1].begin()) + (*st[0].rbegin());
        else
            ans = sum2 + sum;
        printf("%lld\n", ans);
    }
    return 0;
}