834. 树中距离之和 (Hard)

发布时间 2023-07-18 10:40:42作者: zwyyy456

问题描述

834. 树中距离之和 (Hard)

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edgesedges[i] = [aᵢ, bᵢ] 表示树中的节点 aᵢbᵢ 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

示例 1:

输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

示例 2:

输入: n = 1, edges = []
输出: [0]

示例 3:

输入: n = 2, edges = [[1,0]]
输出: [1,1]

提示:

  • 1 <= n <= 3 * 10⁴
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= aᵢ, bᵢ < n
  • aᵢ != bᵢ
  • 给定的输入保证为有效的树

解题思路

求单个结点(例如 $0$)的距离之和 $dp[0]$,我们可以很容易利用 dfs 以 $O(n)$ 的时间复杂度求出,求 $n$ 个点就是 $O(n^2)$,明显会超时。

但是,不难看出,父结点 $j$ 的 $dp[j]$ 和子结点 $i$ 的 $dp[i]$ 之间存在某种递推关系,即 $dp[i] = dp[j] - cnt[i] + n - cnt[i]$(因为 $i, j$ 直接相连)。

那么,问题就只剩下如何求 cnt[i] 了,即在无向图表示的树中,如何求以当前结点为根结点的子树的结点数,参见 无向图形式组织的树

代码

class Solution {
  public:
    int count(vector<vector<int>> &tree, vector<int> &dis, vector<int> &cnt, int pa, int grandpa) {
        int res = 1;
        for (int child : tree[pa]) {
            if (child == grandpa) { // 防止重复遍历,保证 dfs 遍历时的单向性
                continue;
            }
            dis[child] = dis[pa] + 1;
            res += count(tree, dis, cnt, child, pa);
        }
        cnt[pa] = res;
        return res;
    }
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>> &edges) {
        vector<vector<int>> tree(n);
        for (auto &vec : edges) {
            tree[vec[0]].push_back(vec[1]);
            tree[vec[1]].push_back(vec[0]); // 注意,建立无向图要 push_back 两次!
        }
        vector<int> cnt(n);
        vector<int> dp(n);
        vector<int> dis(n); // 表示结点 0 到其他结点的最短距离
        count(tree, dis, cnt, 0, -1);
        for (int i = 0; i < n; ++i) {
            dp[0] += dis[i];
        }
        queue<pair<int, int>> q;
        q.push({0, -1}); // pa, grandpa
        while (!q.empty()) {
            auto [pa, grandpa] = q.front();
            q.pop();
            for (int child : tree[pa]) {
                if (child == grandpa) { // 保证 bfs 遍历时的单向性
                    continue;
                }
                dp[child] = dp[pa] + n - 2 * cnt[child];
                q.push({child, pa});
            }
        }
        return dp;
    }
};