CSP-J 2023 T4 旅游巴士(CSP-J考纲范围内的解法:BFS+二分)

发布时间 2023-11-23 17:22:46作者: 江城伍月

原题连接:https://www.luogu.com.cn/problem/P9751

题意解读:

给定n个点,m条边的有向带权图(权重为能通过该条边的最小时间),求从起点1到终点n的最短距离,由于出发和达到时间都需为k的倍数,所以这个最短距离也必须是k的倍数。限制条件:每通过一条路径,时长比上一个节点+1,需要根据当前节点的时长判断能否通过下一个节点。

样例模拟:

解题思路:

如果你没有任何思路,请注意此题无解时输出-1,那么直接输出-1将获得宝贵的5分:

5分的代码:

#include <bits/stdc++.h>
using namespace std;
int main()
{
    cout << -1;
    return 0;
}

接下来言归正传,对于此题,网上有很多解析给出了多个思路:反向建图、分层图、Dijikstra等,而对于CSP-J的选手,想在不超出考纲知识点范围内用最直观的方法解决该问题,解法如下:

第一步:分析如果出发时间越大,这样道路的限制就越少,必定能走出最短路径,而出发时间越小,道路受到限制越多,所以可以直接枚举出发时间0、k、2k......

第二步:在确定了出发时间之后,如何求最短路呢?这里通过每一条路径的时间都是1单位,因此可以用BFS来求一个正好可以被k整除的最短路。

观察一下测试数据:测试点6~7,k <= 1, ai = 0, 说明任意时间出发都可以通过每一条路径,直接取1-n的最短路即可,这样做至少可以得10分,运气好更多:

15分的代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 10010;
const int M = 20010;

int n, m, k;
int head[N], e[M], w[M], ne[M], idx; //用数组模拟邻接表来存储图
queue<int> q;
int depth[N]; //depth[i]表示从起点走到i节点后的时间

void add(int u, int v, int a)
{
    e[idx] = v;
    w[idx] = a;
    ne[idx] = head[u];
    head[u] = idx++;
}

int bfs(int start)
{
    memset(depth, -1, sizeof(depth));
    q.push(1); 
    depth[1] = start; //起点的时间设置为出发时间
    while(q.size())
    {
        int u = q.front();q.pop();
        for(int i = head[u]; i != -1; i = ne[i])
        {
            int v = e[i]; // 对于每一个邻接点,取节点号
            int a = w[i]; // 对于每一个邻接点,取权值,即可通行时间
            if(depth[v] == -1 && depth[u] >= a) //如果节点没有访问过,并且当前时间可以通行
            {
                q.push(v);
                depth[v] = depth[u] + 1; //当前节点时间+1即为下一个节点的时间
            }
        }
    }
    if(depth[n] != -1 && depth[n] % k == 0) return depth[n]; //到节点n的时间%k是0,则说明找到一个可行解
    return -1;
}

int main()
{
    memset(head, -1, sizeof(head));
    
    cin >> n >> m >> k;
    int u, v, a;
    int mintime = 0, maxtime = 0; //设置两个变量来保存可以通行的最小、最大时间,用于减少遍历的次数
    while(m--)
    {
        cin >> u >> v >> a;
        add(u, v, a); //建图
        mintime = min(mintime, a);
        maxtime = max(maxtime, a);
    }

    int ans = 0x3f3f3f3f;
   
    for(int i = (maxtime / k + 1) * k; i >= mintime / k * k; i -= k) // 从超过最大时间的第一个被k整除的时间,遍历到最小的能被k整除的时间
    {
        int res = bfs(i); //bfs求解答案
        if(res % k == 0) 
        {
            ans = min(ans, res); //求答案中最小值
        }
    }
    
    if(ans != 0x3f3f3f3f) cout << ans;
    else cout << -1;

    return 0;
}

第三步:继续往下思考,单纯用朴素的BFS求最短路只能得部分分,其原因在于从1-n的路径往往有多条,而最短路不一定符合能被k整除,所以一个直观的想法是能否把所有从1-n的最短路的时间都计算出来,然后找到第一个能被k整除的最短路即可。如何做到呢?在计算从起点走到某个节点之后的时间时,用到了depth[N]数组,这里只能保存一个唯一路径,下次就不会再走该节点,如果给depth数据增加一维,定义depth[N][K],K的取值范围从0~k-1,这样可以把所有从1~n的步数%k为0~k-1最短路径都保存在depth中。

举例:

还是以样例为例,当从3时出发可以从1->2->5,这样经过了2个单位时间,到节点5的时间为3+2=5,可以保存为depth[5][2]=3;

另外一条路径1->3->4->5,经过了3个单位时间,在判断节点5是否访问过时,是用depth[5][3],显然前面只保存了depth[5][2],所以还可以访问节点5,然后更新depth[5][3]=3+3=6。

顺着这个思路,实现代码,就能保存程序的正确性了:

35分的代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 10010;
const int M = 20010;
const int K = 110;

int n, m, k;
int head[N], e[M], w[M], ne[M], idx;
queue<pair<int, int> > q; //队列中需要保存节点号以及到达该节点是从起点经过了几步
int depth[N][K]; //depth[i][j]表示从起点到达i节点,经过j步的时间,j从上一个节点的对应值加1即可得到

void add(int u, int v, int a)
{
    e[idx] = v;
    w[idx] = a;
    ne[idx] = head[u];
    head[u] = idx++;
}

int bfs(int start)
{
    memset(depth, -1, sizeof(depth));
    q.push(make_pair(1, 0)); // 把起点1放入队列,对应的步数为0
    depth[1][0] = start; // 设置初始时间
    while(q.size())
    {
        pair<int, int> p = q.front();q.pop();
        int u = p.first; //当前节点号
        int x = p.second; //到达当前节点步数
        for(int i = head[u]; i != -1; i = ne[i])
        {
            int v = e[i]; // 对于每一个邻接点,取节点号
            int a = w[i]; // 对于每一个邻接点,取权值,即可通行时间
            if(depth[v][(x + 1) % k] == -1 && depth[u][x] >= a) //邻接点的步数是上一个节点+1,%k是便于求最短路,不会再绕k步;同时判断当前时间能否通行
            {
                q.push(make_pair(v, (x + 1) % k));
                depth[v][(x + 1) % k] = depth[u][x] + 1; //当前节点步数+1即为下一个节点的步数
            }
        }
    }
    if(depth[n][0] != -1) return depth[n][0]; //到节点n,经过步数%k是0有值,则说明找到一个可行解
    return -1;
}

int main()
{
    memset(head, -1, sizeof(head));
    
    cin >> n >> m >> k;
    int u, v, a;
    int mintime = 0, maxtime = 0;
    while(m--)
    {
        cin >> u >> v >> a;
        add(u, v, a);
        mintime = min(mintime, a);
        maxtime = max(maxtime, a);
    }

    int ans = 0x3f3f3f3f;
   
    for(int i = (maxtime / k + 1) * k; i >= mintime / k * k; i -= k)
    {
        int res = bfs(i);
        if(res % k == 0) 
        {
            ans = min(ans, res);
        }
    }
    
    if(ans != 0x3f3f3f3f) cout << ans;
    else cout << -1;

    return 0;
}

第四步:目前为止,程序逻辑都对,唯一的问题是性能,其实一开始就应该可以想到,此问题具有单调性:假设出发时间够大,所有节点都能通行的情况,一定能找到一个最短路,所以可以用二分答案来不断缩小出发时间,逼近可以找到可行解的那个最小时间,下面给出代码:

100分的代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 10010;
const int M = 20010;
const int K = 110;

int n, m, k;
int head[N], e[M], w[M], ne[M], idx;
queue<pair<int, int> > q;
int depth[N][K]; 

void add(int u, int v, int a)
{
    e[idx] = v;
    w[idx] = a;
    ne[idx] = head[u];
    head[u] = idx++;
}

int bfs(int start)
{
    memset(depth, -1, sizeof(depth));
    q.push(make_pair(1, 0));
    depth[1][0] = start; 
    while(q.size())
    {
        pair<int, int> p = q.front();q.pop();
        int u = p.first; 
        int x = p.second; 
        for(int i = head[u]; i != -1; i = ne[i])
        {
            int v = e[i]; 
            int a = w[i]; 
            if(depth[v][(x + 1) % k] == -1 && depth[u][x] >= a)
            {
                q.push(make_pair(v, (x + 1) % k));
                depth[v][(x + 1) % k] = depth[u][x] + 1; 
            }
        }
    }
    if(depth[n][0] != -1) return depth[n][0]; 
    return -1;
}

int main()
{
    memset(head, -1, sizeof(head));
    
    cin >> n >> m >> k;
    int u, v, a;
    int mintime = 0, maxtime = 0;
    while(m--)
    {
        cin >> u >> v >> a;
        add(u, v, a);
        mintime = min(mintime, a);
        maxtime = max(maxtime, a);
    }

    int ans = 0x3f3f3f3f;
   
    //二分答案优化
    int l = mintime / k, r = 2 * maxtime / k  + 1;
    while(l < r)
    {
        int mid = (l + r) >> 1;
        int res;
        if((res = bfs(mid * k)) != -1){
            r = mid;
            ans = min(ans, res);
        } 
        else l = mid + 1;
    }
 
    if(ans != 0x3f3f3f3f) cout << ans;
    else cout << -1;

    return 0;
}