题意
给定由 \(N\) 个节点组成的无向完全图 \(G\),并删去 \(M\) 条边,求该图的最短路数量。
(\(2 \le N \le 2 \times 10^5, 0 \le M \le \min\left\{2 \times 10^5, \dfrac{N(N - 1)}{2}\right\}\))。
题解
首先考虑若有一张边数较少的无向图该如何对最短路计数,由于边不带权,所以一个节点到 \(1\) 号节点的最短路就是该节点在以 \(1\) 号节点为根的 \(\tt{BFS}\) 生成树深度。故可进行转移,设 \(\operatorname{count}_u\) 表示从 \(1\) 号节点到节点 \(u\) 的最短路个数,\(\operatorname{dist}_u\) 表示从 \(1\) 号节点到节点 \(u\) 的最短路长度,那么有转移
但是由于此图边数过多,暴力转移是不可行的。考虑也按完全图删边的思想进行转移,即首先设当前图为完全图进行转移,然后将通过被删除的边转移而来的非法贡献去除,设被删除的边组成的图为 \(G^{\prime}\),那么有转移
在求出每个节点的 \(\operatorname{count}\) 后维护一个数组 \(sum_x = \sum\limits_{\operatorname{dist}_u = x} \operatorname{count}_u\) 即可快速转移,这部分复杂度为 \(\mathcal{O}(N + M)\)。
现在还需要解决的问题是如何快速的从当前节点扩展到其他节点,可以发现每个节点只会进入队列一次,故可以维护一个列表,代表未进队的节点,每次扩展节点时遍历列表,并检查边 \((u, v)\) 是否被删除即可。具体的,可以维护一个 \(\tt{bool}\) 数组,每次扩展节点前遍历所有从当前节点出发被删除的边,并打上标记,然后遍历列表后清空即可。可以发现每条被删除的边最多引起一次对列表单个元素的重复遍历,故这部分复杂度为 \(\mathcal{O}(N + M)\)。
总复杂度为 \(\mathcal{O}(N + M)\),可以通过本题。
Code
//G
#include <bits/stdc++.h>
typedef long long valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::vector<bool> bitset;
typedef std::queue<valueType> queue;
typedef std::list<valueType> list;
constexpr valueType MOD = 998244353;
template<typename T1, typename T2, typename T3 = valueType>
void Inc(T1 &a, T2 b, const T3 &mod = MOD) {
a = a + b;
if (a >= mod)
a -= mod;
}
template<typename T1, typename T2, typename T3 = valueType>
void Dec(T1 &a, T2 b, const T3 &mod = MOD) {
a = a - b;
if (a < 0)
a += mod;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
valueType N, M;
std::cin >> N >> M;
ValueMatrix G(N + 1);
for (valueType i = 0; i < M; ++i) {
valueType u, v;
std::cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
ValueVector dist(N + 1, N + 100), count(N + 1, -1), sum(N + 1, 0);
bitset removed(N + 1, false);
list Q(N - 1);
queue que;
dist[1] = 1;
count[1] = 1;
sum[0] = 1;
que.push(1);
std::iota(Q.begin(), Q.end(), 2);
while (!que.empty()) {
valueType const u = que.front();
que.pop();
count[u] = sum[dist[u] - 1];
for (auto const &iter: G[u]) {
if (dist[iter] == dist[u] - 1) {
Dec(count[u], count[iter]);
}
removed[iter] = true;
}
Inc(sum[dist[u]], count[u]);
for (auto iter = Q.begin(); iter != Q.end();) {
if (!removed[*iter]) {
dist[*iter] = dist[u] + 1;
que.push(*iter);
iter = Q.erase(iter);
} else {
++iter;
}
}
for (auto const &iter: G[u])
removed[iter] = false;
}
std::cout << count[N] << std::endl;
return 0;
}