834. Sum of Distances in Tree (Hard)

Posted: 2023-07-18 10:40:42  Author: zwyyy456

Description

There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.

Example 1:

Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: Node 0 is connected to nodes 1 and 2, and node 2 is connected to nodes 3, 4 and 5.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.

Example 2:

Input: n = 1, edges = []
Output: [0]

Example 3:

Input: n = 2, edges = [[1,0]]
Output: [1,1]

Constraints:

  • 1 <= n <= 3 * 10^4
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • The given input represents a valid tree.

Solution

Let $dp[i]$ be the sum of distances from node $i$ to all other nodes. A single DFS from one node (say node $0$) computes $dp[0]$ in $O(n)$ time. Repeating that from every one of the $n$ nodes, however, takes $O(n^2)$ time, which is too slow for $n$ up to $3 \times 10^4$.
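For reference, the $O(n^2)$ approach can be sketched as one traversal per source node. This minimal sketch uses BFS rather than DFS (either visits every node once per source), and the function name `sumOfDistancesBruteForce` is hypothetical. It is only meant for small inputs or for cross-checking the faster solution.

```cpp
#include <queue>
#include <vector>

// Hypothetical brute-force baseline: one BFS from every node, summing distances.
// O(n) per source, O(n^2) overall.
std::vector<int> sumOfDistancesBruteForce(int n, const std::vector<std::vector<int>> &edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto &e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    std::vector<int> answer(n, 0);
    for (int src = 0; src < n; ++src) {
        std::vector<int> dist(n, -1);  // -1 marks "not visited yet"
        std::queue<int> q;
        dist[src] = 0;
        q.push(src);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            answer[src] += dist[u];  // add this node's distance from src
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return answer;
}
```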

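However, there is a recurrence between the answer $dp[j]$ of a parent node $j$ and the answer $dp[i]$ of its child $i$, where $cnt[i]$ is the size of $i$'s subtree with the tree rooted at node $0$. Because $i$ and $j$ are joined by a single edge, moving from $j$ to $i$ brings the $cnt[i]$ nodes in $i$'s subtree one step closer and pushes the other $n - cnt[i]$ nodes one step farther away, so $dp[i] = dp[j] - cnt[i] + (n - cnt[i]) = dp[j] + n - 2\,cnt[i]$.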
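As a quick check, here is the recurrence applied to Example 1. Rooting the tree at node $0$ gives $dp[0] = 8$ and subtree sizes $cnt = [6, 1, 4, 1, 1, 1]$ (node 2's subtree is $\{2, 3, 4, 5\}$); the results match the expected output $[8, 12, 6, 10, 10, 10]$.

```latex
\begin{aligned}
dp[1] &= dp[0] + n - 2\,cnt[1] = 8 + 6 - 2 \cdot 1 = 12,\\
dp[2] &= dp[0] + n - 2\,cnt[2] = 8 + 6 - 2 \cdot 4 = 6,\\
dp[3] &= dp[2] + n - 2\,cnt[3] = 6 + 6 - 2 \cdot 1 = 10.
\end{aligned}
```

Nodes 4 and 5 follow exactly like node 3, giving 10 each.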

So the remaining problem is computing $cnt[i]$, the number of nodes in the subtree rooted at node $i$ when the undirected tree is rooted at node $0$. One DFS that skips the edge back to the parent computes every $cnt[i]$, along with every node's depth (the depths sum to $dp[0]$), in $O(n)$ time. For more details, see the post "Tree Organized as an Undirected Graph".
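The subtree sizes can also be computed without recursion. The sketch below (the function name `subtreeSizes` is hypothetical) takes an already-built adjacency list like the `tree` vector in the code, records a preorder with an explicit stack, and then folds each node's size into its parent in reverse order, so every child is finished before its parent.

```cpp
#include <vector>

// Hypothetical helper: cnt[v] = number of nodes in v's subtree when the tree is rooted at `root`.
std::vector<int> subtreeSizes(int n, const std::vector<std::vector<int>> &adj, int root) {
    std::vector<int> parent(n, -1), order, cnt(n, 1);  // each node counts itself
    order.reserve(n);
    std::vector<int> stack = {root};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);  // preorder: u appears after its parent
        for (int v : adj[u]) {
            if (v != parent[u]) {  // skip the edge back to the parent
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }
    for (int k = n - 1; k > 0; --k) {  // order[0] is the root, which has no parent
        int v = order[k];
        cnt[parent[v]] += cnt[v];
    }
    return cnt;
}
```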

Code

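#include <queue>
#include <utility>
#include <vector>
using namespace std;

class Solution {
  public:
    // DFS from `pa` (whose parent is `grandpa`): records each node's depth in dis[]
    // and the size of the subtree rooted at `pa` in cnt[]; returns that size.
    int count(vector<vector<int>> &tree, vector<int> &dis, vector<int> &cnt, int pa, int grandpa) {
        int res = 1; // the node itself
        for (int child : tree[pa]) {
            if (child == grandpa) { // skip the edge back to the parent
                continue;
            }
            dis[child] = dis[pa] + 1;
            res += count(tree, dis, cnt, child, pa);
        }
        cnt[pa] = res;
        return res;
    }
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>> &edges) {
        vector<vector<int>> tree(n);
        for (auto &vec : edges) {
            tree[vec[0]].push_back(vec[1]);
            tree[vec[1]].push_back(vec[0]); // add both directions: the tree is undirected
        }
        vector<int> cnt(n); // cnt[i]: size of the subtree rooted at i (root = 0)
        vector<int> dp(n);  // dp[i]: sum of distances from i to all other nodes
        vector<int> dis(n); // dis[i]: depth of i, i.e. its distance from node 0
        count(tree, dis, cnt, 0, -1);
        for (int i = 0; i < n; ++i) {
            dp[0] += dis[i]; // dp[0] is the sum of all depths
        }
        // BFS from the root, deriving each child's answer from its parent's
        queue<pair<int, int>> q;
        q.push({0, -1}); // {node, parent}
        while (!q.empty()) {
            auto [pa, grandpa] = q.front();
            q.pop();
            for (int child : tree[pa]) {
                if (child == grandpa) { // skip the parent so we only move downward
                    continue;
                }
                dp[child] = dp[pa] + n - 2 * cnt[child]; // rerooting recurrence
                q.push({child, pa});
            }
        }
        return dp;
    }
};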
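A fully iterative variant of the same rerooting idea is sketched below; it is not the code above, and the function name `sumOfDistancesIterative` is hypothetical. Instead of storing depths, it uses the identity $dp[0] = \sum_{v \ne 0} cnt[v]$: each node $w$ is counted once for every non-root node on its path to the root, which is exactly its depth.

```cpp
#include <vector>

// Hypothetical iterative version:
// 1) preorder of the tree rooted at 0 with an explicit stack,
// 2) subtree sizes accumulated in reverse preorder, adding each finished cnt[v] to dp[0],
// 3) answers pushed down in preorder with dp[child] = dp[parent] + n - 2 * cnt[child].
std::vector<int> sumOfDistancesIterative(int n, const std::vector<std::vector<int>> &edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto &e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    std::vector<int> parent(n, -1), order, cnt(n, 1), dp(n, 0);
    order.reserve(n);
    std::vector<int> stack = {0};
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {  // skip the edge back to the parent
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }
    for (int k = n - 1; k > 0; --k) {  // children are finished before their parents
        int v = order[k];
        cnt[parent[v]] += cnt[v];
        dp[0] += cnt[v];  // cnt[v] is final here; summing over non-root v gives dp[0]
    }
    for (int k = 1; k < n; ++k) {  // parents are computed before their children
        int v = order[k];
        dp[v] = dp[parent[v]] + n - 2 * cnt[v];
    }
    return dp;
}
```

Because it never recurses, this variant does not depend on call-stack depth for path-shaped trees; whether that matters at $n \le 3 \times 10^4$ depends on the judge's stack size.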