题目链接:6356. 收集树中金币
方法:拓扑排序
解题思路
参考:拓扑排序 + 记录入队时间(Python/Java/C++/Go)
代码
class Solution {
public:
int collectTheCoins(vector<int>& coins, vector<vector<int>>& edges) {
int n = coins.size(), vertex = n;
vector<vector<int>> g(n);
int deg[n]; memset(deg, 0, sizeof(deg));
for (auto &e : edges) {
g[e[0]].push_back(e[1]);
g[e[1]].push_back(e[0]);
deg[e[0]] ++ ;
deg[e[1]] ++ ;
}
queue<int> q;
for (int i = 0; i < n; i ++ ) {
if (deg[i] == 1 && !coins[i]) q.push(i);
}
while (!q.empty()) { // 拓扑排序去除没有硬币的叶子节点
int front = q.front();
q.pop();
deg[front] -- ; // deg = 0
vertex -- ;
for (int &adj : g[front]) {
deg[adj] -- ;
if (deg[adj] == 1 && !coins[adj]) q.push(adj);
}
}
// 然后减去倒数后两层的节点
for (int i = 0; i < n; i ++ ) { // 减去倒数第一层的节点
if (deg[i] == 1) {
q.push(i);
vertex -- ;
}
}
while (!q.empty()) {
int front = q.front();
q.pop();
deg[front] -- ;
for (int &adj : g[front]) {
deg[adj] -- ;
}
}
for (int i = 0; i < n; i ++ ) { // 减去倒数第二层的节点
if (deg[i] == 1) vertex -- ;
}
return vertex == 0 ? 0 : (vertex - 1) * 2; // 返回需要走的边的数量
}
};
复杂度分析
时间复杂度:\(O(n)\);
空间复杂度:\(O(n)\)。
题目拓展
若再传进来一个\(queries\)数组,对于每次查询\(queries[i] = x\),表示每个节点收集距离当前节点距离为 \(x\) 以内的所有金币,在该情况下计算当前最少经过的边数。返回一个数组\(vector<int> ans\),存储对应\(queries\)的答案。本题即是返回\(queries[i] = 2\)时,最少经过的边数。
代码
class Solution {
public:
vector<int> collectTheCoins(vector<int>& coins, vector<vector<int>>& edges, vector<int>& queries) {
int n = coins.size(), vertex = n;
vector<vector<int>> g(n);
int deg[n]; memset(deg, 0, sizeof(deg));
for (auto &e : edges) {
g[e[0]].push_back(e[1]);
g[e[1]].push_back(e[0]);
deg[e[0]] ++ ;
deg[e[1]] ++ ;
}
queue<int> q;
for (int i = 0; i < n; i ++ ) {
if (deg[i] == 1 && !coins[i]) q.push(i);
}
while (!q.empty()) {
int front = q.front();
q.pop();
deg[front] -- ;
vertex -- ;
for (int &adj : g[front]) {
deg[adj] -- ;
if (deg[adj] == 1 && !coins[adj]) q.push(adj);
}
}
int idx = 0, vtx[n];
memset(vtx, 0, sizeof(vtx));
vtx[idx ++ ] = vertex;
while (vertex > 1) {
for (int i = 0; i < n; i ++ ) {
if (deg[i] == 1) {
q.push(i);
vertex -- ;
}
}
while (!q.empty()) {
int front = q.front();
q.pop();
deg[front] -- ;
for (int &adj : g[front]) {
deg[adj] -- ;
}
}
vtx[idx ++ ] = vertex;
}
int m = queries.size();
vector<int> ans(m);
for (int i = 0; i < m; i ++ ) {
int x = queries[i];
if (x >= n - 1) ans[i] = 0;
else ans[i] = vtx[x] == 0 ? 0 : (vtx[x] - 1) * 2;
}
return ans;
}
};