【luogu题解】P9749 [CSP-J 2023] 公路

发布时间 2023-11-22 16:20:41作者: daiyulong

\(Meaning\)

\(Solution\)

这道题我来讲一个不一样的解法:\(dp\)

在写 \(dp\) 之前,我们需要明确以下几个东西:状态的表示,状态转移方程,边界条件和答案的表示。

状态的表示

\(dp[i]\) 表示到达第 \(i\) 个站点所需要的最少钱数, \(w[i]\) 表示在使用最少钱数到达第 \(i\) 个站点时多余的路程。

状态转移方程

\[ dp[i]=dp[i-1]+\bigg\lceil\frac{v[i-1]-w[i-1]}{d}\bigg\rceil\times pre\_min(i-1) \]

\[ w[i]=\bigg\lceil\frac{v[i-1]-w[i-1]}{d}\bigg\rceil-v[i-1]+w[i-1] \]

其中 \(pre\_min(i)\) 表示前 \(i\) 个站点中最小的油价。

边界条件

\[ dp[i]=0,w[i]=0 \]

答案的表示

\[ dp[n] \]

问题

在状态转移方程中,怎样在 \(O(1)\) 的时间复杂度下完成 \(pre\_min\) 函数呢?

这就涉及到了一个算法:

\(ST\)

在算法和数据结构中,ST表(Sparse Table)是一种用于解决区间查询问题的数据结构。它可以有效地回答各种形式的查询,例如最小值、最大值、区间和等。

简介

ST表的主要思想是通过预处理来加速区间查询。它使用倍增 DP 的思想将一个数组分割成多个子区间,并在每个子区间上计算出某种操作的结果。然后,根据这些预先计算好的结果,我们可以根据需要合并区间来回答各种查询。

具体的实现过程如下:

  1. 初始化ST表,ST表是一个二维数组。
  2. 将输入的原始数组填充到ST表的第一行。
  3. 使用递推关系填充ST表的其他行,直到得到完整的ST表。
  4. 根据查询的起始位置和区间长度,在ST表中找到对应区间的值,结合适当的操作得出最终结果。

查询操作

对于任何查询操作,我们可以使用以下步骤来回答:

  1. 计算出查询区间的长度len。

  2. 找到大于等于len的最大值j,使得2^j <= len。

  3. 使用预处理的结果和递推关系,在ST表中找到对应的值,并结合适当的操作得到查询结果。

这种方法的时间复杂度为O(1),因为我们只需进行几次常数级别的操作即可回答查询。

应用场景

ST表在解决各种区间查询问题时非常有用。以下是一些常见的应用场景:

  • 查询最小值/最大值:通过选择适当的查询操作,在O(1)的时间复杂度内回答每个查询。
  • 区间和查询:可以通过使用累积和来实现区间和查询。
  • 区间gcd查询:可以通过预处理和递推关系计算区间内的最大公约数。

总结

ST表是一种高效解决区间查询问题的数据结构。通过预先计算和递推关系,我们可以在O(1)的时间复杂度内回答各种形式的查询。它的实现相对简单且灵活,适用于多种应用场景。

模板

初始化(时间复杂度 \(O(\log_2n)\)

for(int i=1;i<=n;i++) {
   	st[i][0]=a[i];
}
for(int j=1;(1<<j)<=n;j++) {
	for(int i=1;i+(1<<j)-1<=n;i++) {
		st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
	}
}

查询(时间复杂度 \(O(1)\)

l=1,r=i-1,len=log2(r-l+1);
pm=min(st[l][len],st[r-(1<<len)+1][len]);

解决问题

有了ST表,我们就可以在O(1)的时间复杂度中查询最值了,那我们程序的最终问题:TLE也解决了。程序整体时间复杂度为O(n),可以通过此题。

AC代码如下。

\(Accept\ Code\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=1e5+5;
ll v[N],a[N],w[N],dp[N],st[N][20];
ll n,d,l,r,len,pm;
int main() {
    cin>>n>>d;
    for(int i=1;i<n;i++) {
        cin>>v[i];
    }
    for(int i=1;i<=n;i++) {
        cin>>a[i];
    }
    for(int i=1;i<=n;i++) {
    	st[i][0]=a[i];
	}
	for(int j=1;(1<<j)<=n;j++) {
		for(int i=1;i+(1<<j)-1<=n;i++) {
			st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
		}
	}
    for(int i=2;i<=n;i++) {
		l=1,r=i-1,len=log2(r-l+1);
		pm=min(st[l][len],st[r-(1<<len)+1][len]);
		dp[i]=dp[i-1]+ceil(1.0*(v[i-1]-w[i-1])/d)*pm;
		w[i]=ceil(1.0*(v[i-1]-w[i-1])/d)*d-(v[i-1]-w[i-1]);
	}
	cout<<dp[n];
    return 0;
}