浅谈斜率优化

发布时间 2023-09-16 12:11:57作者: 球君

  众所周知,动态规划推出状态转移方程是很困难的,推出状态转移方程后发现复杂度爆炸是很炸裂的,所以这就是斜率优化存在的意义----降低转移方程的复杂度

  在看具体例子之前,我们先大致的介绍一下斜率优化的原理

  考虑一个这样的状态转移方程,f[i]=min{f[j]-k[i]*j+s[i]}   j<i,f 用于储存/转移状态,k、s 提前给定

  对于该阶段的决策,我们不妨把所有只跟i有关项的看作是常量,跟j有关的看作是变量,可以得到 f[i]=min{f[j]-k[i]*j}+s[i];

  即,对于该阶段的决策,我们要找到最小的 f[j]-k[i]*j

  可以看作是两个函数的差

  不妨建立坐标,对于函数 y=f[x] x∈[1,i) ,点的坐标均为 (x,f[x]) ,举个例子,如下图所示

 

  然后我们再绘制 y=k[i]*x 的图象,大致就长这样  

  挑选其中一个点 (t,f[t]) 作 y=k[i]*x 的平行线

  显然J点的纵坐标就是 f[t]-k[i]*t, 我们要让这个J点的纵坐标尽可能地小

  为寻找这个最小值,不妨设有一条斜率为 k[i] 的直线从下往上平移,J点的纵坐标最小的时候就是这条直线上移到第一次穿过一个点(即第一次该直线上有一点可表示为 (j,f[j]))时的截距(该直线与y轴的交点的纵坐标)

   很显然,只有图象下凸的一段是可能成为最优解的,因为一条直线从下往上平移到上凸的点(举个例子,B点)之前一定会先经过它两侧的点

  

 

  如上图,AM 和 CN 在经过 B 前都会要么先经过 A 要么先经过 C

  既然上凸的点是不会再被选中进行更新 f[i] 的,那么不妨把这些点剔除出去,如图

  这样这个图象就是斜率单调递增的了,即,这个图象都是下凸的

  观察直线取到最值的情况,很显然除了直线经过的那个点,剩下的点都高于直线

  也就是说那个交点左侧的线段的斜率小于直线斜率,而焦点右侧的线段的斜率大于直线斜率,因为斜率单调递增,所以满足要求的点最多只有一个

  最粗暴的方法是从左往右扫一边,如果右侧线段斜率小于直线斜率就接着往下扫

  更优一点的做法就是二分查找,如果左侧线段斜率大于直线斜率就继续往左区间查找,如果右侧线段斜率小于直线斜率就继续往右区间查找

  下面我们来结合一道题看一下具体实现


 

   P5785 [SDOI2012] 任务安排

  这题还有个弱化版P2365 任务安排

  应该不难想到一个朴素的动态规划,f[i][j]  表示前i个任务,分成j批完成所需的最小费用

  预处理出 t 和 c 的前缀和 sumc、sumt

  状态转移方程为 f[i][j]=min{f[k][j-1]+(s*j+sumt[i])*(sumc[i]-sumc[k])}  k∈[0,i)

  枚举 i,j,k 进行状态转移,时间复杂度O(n^3)

  考虑优化,可以看出枚举j的主要作用是计算要启动机器多少次,把上面min的括号拆开,即 f[k][j-1]+s*j*(sumc[i]-sumc[k])+sumt[i]*(sumc[i]-sumc[k])

  其中 sumt[i]*(sumc[i]-sumc[k]) 跟 j 无关,主要是看 s*j*(sumc[i]-sumc[k])

  观察式子,发现j的作用仅是为了计算此前个过程中的启动时间和,而每次启动机器都会给后面每一步带来 s*c[j] 的额外费用,那我们不妨在此前过程中就把启动机器对后面产生的费用计入当前费用

  通过这种思想我们可以去掉 f 的一维

  令 f[i] 表示将前i个任务分成若干批执行的最小费用,状态转移方程为 f[i]=min{f[j]+sumt[i]*(sumc[i]-sumc[j])+s*(sumc[n]-sumc[j])}  j∈[0,i)

  枚举 i 和 j 进行状态转移

  然后这就是个O(n^2)的复杂度的状态转移了

  做到这一步你就可以通过P2365 任务安排

  看一下完整代码

#include<bits/stdc++.h>
using namespace std;
const int mn=5005;
long long f[mn],sumt[mn],sumc[mn];
int n,s;
int main()
{
	cin>>n>>s;
	for(int i=1;i<=n;i++)
	{
		int x,y;
		cin>>x>>y;
		sumt[i]=sumt[i-1]+x;
		sumc[i]=sumc[i-1]+y;
	}
	memset(f,0x7f,sizeof f);
	f[0]=0;
	for(int i=1;i<=n;i++)
	for(int j=0;j<i;j++) f[i]=min(f[i],f[j]+sumt[i]*(sumc[i]-sumc[j])+s*(sumc[n]-sumc[j]));
	cout<<f[n]<<endl; 
	return 0;
}

  不过P5785 [SDOI2012] 任务安排的数据范围增大后即便是 O(N^2) 的复杂度也过不了了

  那么可以考虑斜率优化

  看一下刚刚的状态转移方程 f[i]=min{f[j]+sumt[i]*(sumc[i]-sumc[j])+s*(sumc[n]-sumc[j])}  j∈[0,i)

  暂且把跟 i 有关的的项看作是常数项,状态转移方程可以变形为 f[i]=min{f[j]+sumt[i]*(sumc[i]-sumc[j])+s*(sumc[n]-sumc[j])}  j∈[0,i)

  即  f[i]=min{f[j]-(sumt[i]+s)*sumc[j]}+sumt[i]*sumc[i]+s*sumc[n]  j∈[0,i)

  后面的常数项不用管,我们的目的在于将 f[j]-(sumt[i]+s)*sumc[j] 取到最小值

  按照我们前文介绍过的斜率优化的方法,建立平面坐标系,对于平面上的点,坐标为(sumc[x],f[x]),比如像这个坐标系(原谅我画图比较偷懒)

  关于寻找最优解的方法上文已经说过了,这里就来讲讲代码实现

  首先我们要维护这个函数图象的下凸性,如果加入一个点时,发现这个点和它左侧第一个点的连线的斜率小于左侧第一个点和左侧第二个点的斜率,那么就将它左侧第一个点删除,然后继续检查上一个点的与新点的下凸性……

  具体而言就是加入点 C 的时候检查点 B 在加入点 C 后是否成上凸了,若是就将点B删去

  用一个队列来维护,删去B即弹出队尾

  最后的图象就会是下凸的

  而后面的二分查找在前面已经说的很清楚了就不赘述了

  但是我们知道斜率是Δy/Δx,所以如果精度不够或者除出来是无限小数的话我们在比较斜率的时候就可能出差错

  比如说设 C(sumc[c],f[c]),D(sumc[d],f[d]),E(sumc[e],f[e])

  E 能够直接加入队尾的条件是 (f[e]-f[d])/(sumc[e]-sumc[d])>(f[d]-f[c])/(sumc[d]-sumc[c])

  即(f[e]-f[d])*(sumc[d]-sumc[c])>(sumc[e]-sumc[d])*(f[d]-f[c])

  精度的问题就解决了

  同理在查找过程中比较斜率也是用这种方法dp[q[mid+1]]-dp[q[mid]]<=(sumt[i]+s)*(sumc[q[mid+1]]-sumc[q[mid]])

  接下来看一下完整代码

#include<bits/stdc++.h>
#define int long long
const int mn=300005;
using namespace std;
int n,s,head=1,tail=1,q[mn];
int c[mn],t[mn],sumc[mn],sumt[mn],dp[mn];
int binary_search(int i)
{
	if(head==tail) return q[head];//只有唯一的选择 
	int l=head,r=tail;
	while(l<r) 
	{
		int mid=l+r>>1;
		if(dp[q[mid+1]]-dp[q[mid]]<=(sumt[i]+s)*(sumc[q[mid+1]]-sumc[q[mid]])) l=mid+1;//当左侧线段的斜率大于新线段斜率 
		else r=mid;//当右侧线段斜率小于新线段斜率 
	}
	return q[l];
}
signed main() 
{
	cin>>n;
	cin>>s;
	for(int i=1;i<=n;i++) 
	{
		cin>>t[i]>>c[i];
		sumc[i]=sumc[i-1]+c[i];
		sumt[i]=sumt[i-1]+t[i];
	}
	memset(dp,0x3f,sizeof(dp));
	q[head]=0;
	dp[0]=0;
	for(int i=1;i<=n;i++) 
	{
		int p=binary_search(i);//斜率优化 
		dp[i]=dp[p]-(sumt[i]+s)*sumc[p]+sumt[i]*sumc[i]+s*sumc[n];//状态转移 
		while(head<tail&&(dp[q[tail]]-dp[q[tail-1]])*(sumc[i]-sumc[q[tail]])>=(dp[i]-dp[q[tail]])*(sumc[q[tail]]-sumc[q[tail-1]])) tail--;//将所有上凸的节点删除 
		q[++tail]=i;//加入新节点 
	}
	cout<<dp[n]<<endl;
	return 0;
}

  复杂度O(nlogn)

   然后就可以过了