P9338 [JOISC 2023 Day3] Chorus

发布时间 2024-01-02 16:07:26作者: Chy12321

套路地把题意转到走棋盘上,即给定一个 \(n \times n\) 的网格图,\(\texttt A\) 代表向右走一步,\(\texttt B\) 代表向上走一步,那么每次操作就是把右上和左下互换。
\(f_i\) 表示从 \((1, 1)\)\((i, i)\) 的最小开销,\(w(l, r)\) 表示 \((l, l)\)\((r, r)\) 拐一次弯的开销。

\[f_i = \min\limits_{j = 1}^{i - 1}\{f_j + w(j, i)\} \]

\(r_i\) 表示第 \(i\) 次向上走之前向右走的次数。

显然有解的必要条件是 \(r_i \ge i\),这可以在初始时就统计进答案,然后我们默认 \(r_i \ge i\)

推了好久才推出来:

\[w(l, r) = (\sum_{i = 1}^n[r_i \le r] - l) \times r - (\sum_{i = 1}^nr_i[r_i \le r] - \sum_{i = 1}^lr_i) \]

于是 \(\mathcal O(n)\) 预处理后求每个 \(w(l, r)\) 的时间复杂度是 \(\mathcal O(1)\)

其实猜猜就好了 打表每个点 \((i, f_i)\),发现其具有凸性,然后可以 wqs 二分。

二分内的 DP 的时间复杂度是 \(\mathcal O(n^2)\),所以总时间复杂度是 \(\mathcal O(n^2 \log V)\)

考虑优化,然后猛然发现转移可以斜率优化,这样内层 DP 的时间复杂度就是 \(\mathcal O(n)\) 了,总时间复杂度 \(\mathcal O(n \log V)\)

代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

constexpr int N = 1e6 + 10;

int n, m, r[N], g[N];
int head, tail, q[N];
ll pre[N], cnt[N], sum[N], f[N];
char s[2 * N];

inline ll x(int i) {return i;}

inline ll y(int i) {return f[i] + pre[i] + i;}

inline ll w(int l, int r) {return (cnt[r] - l) * r - sum[r] + pre[l];}

bool check(ll k) {
    for (int i = 1; i <= n; i++) f[i] = 1e18, g[i] = m + 1;
    q[head = tail = 1] = 0;
    for (int i = 1; i <= n; i++) {
        while (head < tail && (i + 1) * (x(q[head + 1]) - x(q[head])) > (y(q[head + 1]) - y(q[head]))) head++;
        f[i] = f[q[head]] + w(q[head], i) - k, g[i] = g[q[head]] + 1;
        while (head < tail && (y(i) - y(q[tail])) * (x(q[tail] - x(q[tail - 1]))) < (y(q[tail]) - y(q[tail - 1])) * (x(i) - x(q[tail]))) tail--;
        q[++tail] = i;
    }
    return g[n] <= m;
}

int main() {
    ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
    cin >> n >> m >> (s + 1);
    for (int i = 1, up = 0, right = 0; i <= 2 * n; i++) s[i] == 'A' ? right++ : r[++up] = right;
    ll init = 0;
    for (int i = 1; i <= n; i++) {
        if (r[i] < i) init += i - r[i], r[i] = i;
        pre[i] += pre[i - 1] + r[i];
        cnt[r[i]]++, sum[r[i]] += r[i];
    }
    for (int i = 1; i <= n; i++) cnt[i] += cnt[i - 1], sum[i] += sum[i - 1];
    ll L = -5e11, R = 0, mid, ans;
    while (L <= R) {
        mid = (L + R) >> 1;
        if (check(mid)) ans = mid, L = mid + 1;
        else R = mid - 1;
    }
    check(ans), cout << init + f[n] + m * ans;
    return 0;
}