浅谈斜率优化DP

发布时间 2023-11-13 15:16:45作者: 北烛青澜

前言

考试 T2 出题人放了个树上斜率优化 DP,直接被同校 OIER 吊起来锤。

离 NOIP 还有不到一周,赶紧学一点。

引入

斜率

斜率,数学、几何学名词,是表示一条直线(或曲线的切线)关于(横)坐标轴倾斜程度的量。它通常用直线(或曲线的切线)与(横)坐标轴夹角的正切,或两点的纵坐标之差与横坐标之差的比来表示。

斜率可以用来描述一个坡的倾斜程度,公式 \(k = \frac{\Delta y}{\Delta x}\)

初中学过一元一次函数 \(y = kx + b\),这里的 \(k\) 就是这个函数表示的直线的斜率。

解决什么

一般对于形如 \(f[i] = \min(a[i] \times b[j] + c[i] + d[j])\) 这种类型的 DP 转移式子都可以用上斜率优化。

其中 \(b\) 要满足单调递增。

看到中间有一部分与 \(i,j\) 都有关,所以这个时候要用到斜率优化。

理解

下面来以一道题目为例进行讲解。

P3195 [HNOI2008] 玩具装箱

看完题目应该都可以想出来一个 \(O(n^2)\) 的 DP,那就是:

\(f[i]\) 表示考虑到第 \(i\) 个玩具所用的最小花费,\(sum[i]\) 为从 \(1\sim i\) 的玩具长度总和。

\[f[i] = \min\{f[j] + (sum[i] - sum[j] + i - j - L - 1)^2\} \]

我们尝试把这一堆东西分分类,把只有 \(i\) 的挪到一起,只有 \(j\) 的挪到一起,剩下的挪到中间。

得到:

\[f[i] = \min\{f[j] + (sum[i] + i - sum[j] - j - L - 1)^2\} \]

\(A=sum[i] + i, B = sum[j] - j - L - 1\)

那么就是 :

\[f[i] = f[j] + A^2 -2AB + B^2 \]

显然的,\(A^2\) 我们可以预处理,是已知的,由于前缀和,而且玩具长度至少为 \(1\),所以 \(2A\) 是严格单调递增的,\(B\) 数组我们也可以直接预处理。

\[f[j] + B^2 = 2AB + f[j] + A^2 \]

这个式子是把只与 \(j\) 有关的移到左边了,可以发现形式上是和 \(y = kx + b\) 一样的。

那么我们就可以把一个之前转移完成的状态看成是一个 \((B, f[j] + B^2)\) 的点,而 \(2A\) 就是经过他们的直线的斜率。

那么我们要求 \(f[i]\) 的话,就是求这个点和这个斜率为 \(2A\) 的直线的最大可能截距是多少。

于图像中

假设下面的三个点是我们待选的状态:

image

假设我们当前要求的斜率画出来是下面这样:

image

我们就从下往上,一点一点向上挪,直到碰到的第一个点,此时的截距一定最大。我们也能看出的确 \(C\) 点最优。

那么此时的 \(A\) 点好像没有什么用了,可以扔掉吗?

答案是可以,因为斜率是单调递增的,既然这次第一个碰不到 \(A\),那么后面肯定也不是第一个碰到。

但是我们如何做到最快找出呢?

队列维护

image

观察这张图片,假设里面的点都是之前转移完的状态。

比较 \(AE,AB\) 的斜率。

不难发现 \(AB\) 的斜率比 \(AE\) 小,想一下之前说的,如果拿一条直线去碰这个图形,从各个角度去碰,最外层的点会形成一个凸包,而这个凸包内的点,是无论如何都碰不到的。

这个我们可以用一个队列来维护一个下凸壳,也就是凸包的一部分。

然后根据上面说的,要是队列头的两个元素形成的直线斜率比当前的小,也可以直接弹出。

这样队列的队头元素就是我们要转移的值了。

code:


/*
 * @Author: Aisaka_Taiga
 * @Date: 2023-11-13 14:11:27
 * @LastEditTime: 2023-11-13 15:09:40
 * @LastEditors: Aisaka_Taiga
 * @FilePath: \Desktop\P3195.cpp
 * The heart is higher than the sky, and life is thinner than paper.
 */
#include <bits/stdc++.h>

#define pf(x) ((x) * (x))
#define int long long
#define DB double
#define N 1000100

using namespace std;

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
    while(c <= '9' && c >= '0') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
    return x * f;
}

int n, L, q[N], c[N], f[N], sum[N], A[N], B[N];

inline int X(int x){return B[x];}

inline int Y(int x){return f[x] + pf(B[x]);}

inline DB xl(int i, int j){return (Y(i) - Y(j)) * 1.0 / (X(i) - X(j));}

signed main()
{
    n = read(), L = read();
    for(int i = 1; i <= n; i ++) c[i] = read();
    for(int i = 1; i <= n; i ++)
    {
        sum[i] = sum[i - 1] + c[i];
        B[i] = sum[i] + i + L + 1;
        A[i] = sum[i] + i;
    }
    B[0] = L + 1;
    int h = 1, t = 1;
    for(int i = 1; i <= n; i ++)
    {
        while(h < t && xl(q[h], q[h + 1]) < 2 * A[i]) h ++;
        int j = q[h];
        f[i] = f[j] + pf(A[i] - B[j]);
        while(h < t && xl(q[t - 1], i) < xl(q[t - 1], q[t])) t --;
        q[++ t] = i;
    }
    cout << f[n] << endl;
    return 0;
}

参考:https://www.cnblogs.com/terribleterrible/p/9669614.html