[ABC319G] Counting Shortest Paths 题解

发布时间 2023-09-11 07:35:09作者: User-Unauthorized

题意

给定由 \(N\) 个节点组成的无向完全图 \(G\),并删去 \(M\) 条边,求该图的最短路数量。

\(2 \le N \le 2 \times 10^5, 0 \le M \le \min\left\{2 \times 10^5, \dfrac{N(N - 1)}{2}\right\}\))。

题解

首先考虑若有一张边数较少的无向图该如何对最短路计数,由于边不带权,所以一个节点到 \(1\) 号节点的最短路就是该节点在以 \(1\) 号节点为根的 \(\tt{BFS}\) 生成树深度。故可进行转移,设 \(\operatorname{count}_u\) 表示从 \(1\) 号节点到节点 \(u\) 的最短路个数,\(\operatorname{dist}_u\) 表示从 \(1\) 号节点到节点 \(u\) 的最短路长度,那么有转移

\[\operatorname{count}_u = \sum\limits_{\left(u, v\right) \in G \land \operatorname{dist}_v = \operatorname{dist}_u - 1} \operatorname{count}_v \]

但是由于此图边数过多,暴力转移是不可行的。考虑也按完全图删边的思想进行转移,即首先设当前图为完全图进行转移,然后将通过被删除的边转移而来的非法贡献去除,设被删除的边组成的图为 \(G^{\prime}\),那么有转移

\[\operatorname{count}_u = \sum\limits_{\operatorname{dist}_v = \operatorname{dist}_u - 1} \operatorname{count}_v - \sum\limits_{\left(u, v\right) \in G^{\prime} \land \operatorname{dist}_v = \operatorname{dist}_u - 1} \operatorname{count}_v \]

在求出每个节点的 \(\operatorname{count}\) 后维护一个数组 \(sum_x = \sum\limits_{\operatorname{dist}_u = x} \operatorname{count}_u\) 即可快速转移,这部分复杂度为 \(\mathcal{O}(N + M)\)

现在还需要解决的问题是如何快速的从当前节点扩展到其他节点,可以发现每个节点只会进入队列一次,故可以维护一个列表,代表未进队的节点,每次扩展节点时遍历列表,并检查边 \((u, v)\) 是否被删除即可。具体的,可以维护一个 \(\tt{bool}\) 数组,每次扩展节点前遍历所有从当前节点出发被删除的边,并打上标记,然后遍历列表后清空即可。可以发现每条被删除的边最多引起一次对列表单个元素的重复遍历,故这部分复杂度为 \(\mathcal{O}(N + M)\)

总复杂度为 \(\mathcal{O}(N + M)\),可以通过本题。

Code

//G
#include <bits/stdc++.h>

typedef long long valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::vector<bool> bitset;
typedef std::queue<valueType> queue;
typedef std::list<valueType> list;

constexpr valueType MOD = 998244353;

template<typename T1, typename T2, typename T3 = valueType>
void Inc(T1 &a, T2 b, const T3 &mod = MOD) {
    a = a + b;

    if (a >= mod)
        a -= mod;
}

template<typename T1, typename T2, typename T3 = valueType>
void Dec(T1 &a, T2 b, const T3 &mod = MOD) {
    a = a - b;

    if (a < 0)
        a += mod;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    valueType N, M;

    std::cin >> N >> M;

    ValueMatrix G(N + 1);

    for (valueType i = 0; i < M; ++i) {
        valueType u, v;

        std::cin >> u >> v;

        G[u].push_back(v);
        G[v].push_back(u);
    }

    ValueVector dist(N + 1, N + 100), count(N + 1, -1), sum(N + 1, 0);
    bitset removed(N + 1, false);
    list Q(N - 1);

    queue que;

    dist[1] = 1;
    count[1] = 1;
    sum[0] = 1;
    que.push(1);
    std::iota(Q.begin(), Q.end(), 2);

    while (!que.empty()) {
        valueType const u = que.front();

        que.pop();

        count[u] = sum[dist[u] - 1];

        for (auto const &iter: G[u]) {
            if (dist[iter] == dist[u] - 1) {
                Dec(count[u], count[iter]);
            }

            removed[iter] = true;
        }

        Inc(sum[dist[u]], count[u]);

        for (auto iter = Q.begin(); iter != Q.end();) {
            if (!removed[*iter]) {
                dist[*iter] = dist[u] + 1;

                que.push(*iter);

                iter = Q.erase(iter);
            } else {
                ++iter;
            }
        }

        for (auto const &iter: G[u])
            removed[iter] = false;
    }

    std::cout << count[N] << std::endl;

    return 0;
}