CF1295E Permutation Separation 题解 线段树优化dp

发布时间 2023-03-30 12:07:07作者: quanjun

题目链接:https://codeforces.com/problemset/problem/1295/E

题目大意:

将排列 \(p_1, p_2, \ldots, p_n\) 先分成 \(p_1, \ldots, p_k\)\(p_{k+1}, \ldots, p_n\) 两个集合。

然后可以将元素从左边的集合移动到右边的集合,也可以将元素从右边的集合移动到左边的集合,移动 \(p_i\) 的代价是 \(a_i\)

用最小的总代价使左边的集合的元素都小于右边的集合。

解题思路:

"All elements in the left set smaller than all elements in the right set" 在左边的集合中的所有元素都要小于在右边的集合中的所有元素

”在左边集合中的所有元素都要小于在右边的集合中的所有元素” 意味着存在一个数值 \(val\) 使得:

  • 所有在第一个集合中的元素数值都 \(\lt val\),同时
  • 所有在第二个集合中的元素数值都 \(\ge val\)

所以可以考虑从 \(1\)\(n+1\) 开一个扫描线,开扫描线的同时维护每一个前缀 \(pos\) 的信息。

对于每一个下标 \(pos\),我们用 \(t[pos]\) 表示:

初始的时候将 \([p_1, p_2, \ldots, p_{pos}]\) 划分到第一个集合,将 \([p_{pos+1}, \ldots, p_n]\) 划分到第二个集合,然后将第一个集合中所有小于 \(val\) 的元素移动到第二个集合,同时将第二个集合中所有大于等于 \(val\) 的元素移动到第一个集合所需的最小花费。

可以发现,总的花费(\(t[pos]\))等于所有满足:

  • \(i \le pos\)\(p_i \ge val\),或者
  • \(i \gt pos\)\(p_i \lt val\)

\(a_i\) 之和。

那么问题来了:如果我们将 \(val\) 的数值增加 \(1\) 会发生什么事情呢?

我们定义排列 \(p\) 中数值为 \(val\) 的元素下标为 \(k\)(即 \(p_k = val\))。则:

  • 对于每一个下标 \(pos \ge k\) 我们不需要将 \(p_k\) 从第一个集合移动到第二个集合,所以我们需要将 \(t[pos] -= a_k\)
  • 对于每一个下标 \(pos \lt k\) 我们也不需要将 \(p_k\) 从第二个集合移动到第一个集合了,所以我们需要将 \(t[pos] += a_k\)

答案是 \(\min\limits_{1 \le pos < n}(t[pos])\)

这意味着我们需要对 \(t\) 数组执行两种类型的操作:

  1. 区间更新(加上一个值);
  2. 区间查询(查询最小值)。

我们可以使用线段树在扫描 \(val\) 的同时执行这些区间操作,时间复杂度为 \(O(n \log n)\)

示例程序:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 5;

int n, p[maxn], q[maxn], a[maxn];
long long t[maxn], ans = 1ll<<60;
long long tr[maxn<<2], lazy[maxn<<2];

void push_up(int rt) {
    tr[rt] = min(tr[rt<<1], tr[rt<<1|1]);
}

void push_down(int rt) {
    if (lazy[rt]) {
        tr[rt<<1] += lazy[rt];
        tr[rt<<1|1] += lazy[rt];
        lazy[rt<<1] += lazy[rt];
        lazy[rt<<1|1] += lazy[rt];
        lazy[rt] = 0;
    }
}

#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1

void build(int l, int r, int rt) {
    if (l == r) {
        tr[rt] = t[l];
        return;
    }
    int mid = (l + r) / 2;
    build(lson), build(rson), push_up(rt);
}

void update(int L, int R, int val, int l, int r, int rt) {
    if (L <= l && r <= R) {
        tr[rt] += val;
        lazy[rt] += val;
        return;
    }
    push_down(rt);
    int mid = (l + r) / 2;
    if (L <= mid) update(L, R, val, lson);
    if (R > mid) update(L, R, val, rson);
    push_up(rt);
}

long long query(int L, int R, int l, int r, int rt) {
    if (L <= l && r <= R) return tr[rt];
    push_down(rt);
    int mid = (l + r) / 2;
    long long res = 1ll<<60;
    if (L <= mid) res = min(res, query(L, R, lson));
    if (R > mid) res = min(res, query(L, R, rson));
    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", p+i), q[p[i]] = i;
    for (int i = 1; i <= n; i++)
        scanf("%d", a+i);

    for (int i = 1; i <= n; i++)
        t[i] = a[i] + t[i-1], ans = min(ans, t[i]);

    build(1, n, 1);

    for (int val = 1; val <= n; val++) {
        int k = q[val];
        update(k, n, -a[k], 1, n, 1);
        if (k > 1) update(1, k-1, a[k], 1, n, 1);
        ans = min(ans, query(1, n-1, 1, n, 1));
    }
    printf("%lld\n", ans);
    return 0;
}