P5336 [THUSC2016] 成绩单

发布时间 2023-10-27 12:58:16作者: purplevine

这得是区间 dp。还需要限制一下值域。因此 dp 状态时 \(f_{l, r, x, y}\) 表示使 \([l, r]\) 区间所有值都处于 \([x, y]\) 的最小花费。设 \(g_{l, r} = \min\{f_{l, r, x, y} + a + b (x - y) ^ 2\}\)

思考一个序列可以被怎么消掉。维护一个类似括号序列的东西,左右匹配的括号代表消这一段区间。维护一个栈,遇到左括号入栈,遇到空字符也入栈,遇到右括号弹栈到匹配的左括号为止。因此一个合法的消除方案可以抽象成入栈、一段后缀出栈。

因此转移时拓展右端点并枚举入栈或出栈。入栈时

\[f_{l, r, x, y} \to f_{l, r+1, \min\{x, w_{r+1}\}, \max\{w_{r+1}\}} \]

出栈时

\[f_{l, k, x, y} + g_{k+1, r+1} \to f_{l, r+1, x, y} \]

这是好的做法。

思考序列被消掉的方法,是为了寻找一个好的 dp 转移方向。dp 的限制够严了,足够把求出答案的信息囊括入内。

写记搜感觉也不错,但是要改成统计式而非贡献式。转移顺序是按 \(r\) 从小到大,\(l\) 从大到小。需要离散化,这不是重点。复杂度是 \(O(n^5)\)

谢谢 这篇题解,感觉是最自然的一篇。但是似乎没有题解说清楚为什么要按那样子分类着转移啊。想到那个栈后,对于转移方向,会更清晰一些吧。

心路回顾:这是决策单调性吧……这个函数怎么没有四边形不等式啊……区间 dp???……\(n=50\)???

#include <bits/stdc++.h>

using LL = long long;

const LL inf = 0x3f3f3f3f3f3f3f3f;

void cmin(LL &x, LL y) { x = std::min(x, y); }

int main() {
  int n, a, b; scanf("%d %d %d", &n, &a, &b);
  std::vector<int> w(n);
  for (auto &u : w) scanf("%d", &u);
  auto s = w;
  std::sort(s.begin(), s.end());
  s.erase(std::unique(s.begin(), s.end()), s.end());
  for (auto &u : w) u = std::lower_bound(s.begin(), s.end(), u) - s.begin();
  int m = s.size();
  std::vector f(n, std::vector (n, std::vector (m, std::vector<LL> (m, inf))));
  for (int i = 0; i < n; i++)
    f[i][i][w[i]][w[i]] = 0;
  std::vector g(n, std::vector<LL> (n, inf));
  for (int r = 0; r < n; r++) 
    for (int l = r; l >= 0; l--) {
      if (r < n - 1) {
        for (int x = 0; x < m; x++) 
          for (int y = x; y < m; y++) {
            cmin(f[l][r + 1][std::min(x, w[r + 1])][std::max(y, w[r + 1])], f[l][r][x][y]);
        }
      }
      for (int x = 0; x < m; x++) 
        for (int y = 0; y < m; y++) {
          // if (f[l][r][x][y] != inf) printf("%d %d %d %d : %lld\n", l, r, x, y, f[l][r][x][y]);
          cmin(g[l][r], f[l][r][x][y] + a + 1ll * b * (s[y] - s[x]) * (s[y] - s[x]));
        }
      // printf("%d %d : %lld\n", l, r, g[l][r]);
      for (int x = 0; x < m; x++) 
        for (int y = x; y < m; y++)
          for (int k = 0; k < l; k++)
            cmin(f[k][r][x][y], f[k][l - 1][x][y] + g[l][r]);
      
    }
  printf("%lld\n", g[0][n - 1]);
}