树中距离之和

发布时间 2023-07-16 03:25:19作者: 失控D大白兔

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。
给定整数 n 和数组 edges , edges[i] = [ai, bi]表示树中的节点 ai 和 bi 之间有一条边。
返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和

1. 树状动规

假设以0号节点为源点
首先考虑计算源点到其他子节点的距离之和
同时计算以0节点为源点,其他各节点子树包含子节点数目
然后再次深度优先进行换源,根据数学推导计算其他节点为源点的距离和

数学推导
dp[i]定义为以0为源点,i的子树中所有子节点到i的距离和
dis[i]定义为以i为源点,其余所有节点到i节点的距离和
假设0节点的相邻节点有1、2
dis[0] = dp[0] = dp[1] + dp[2] + points[1] + points[2]

同时
dis[2] = dp[2] + dp[1] + points[1]*2 + 1

其实也就是dp值保持不变,上一个源点扩张一步,其他分支节点扩张两步,由此推导出换源后的距离变化
dis[2] = dp[1] + dp[2] + (n-points[2]-1)*2 + 1 = dis[0] + (n-points[2]-1) - (points[2]-1)

即dis[2] = dis[0] + (n-points[2]) - poinst[2]
本质上就是其余点距离增加,所属点距离减少

class Solution {
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        vector<vector<int>> graph(n); 
        vector<int> res(n);
        vector<int> points(n,1);
        for(int i=0;i<n-1;i++){ //构建邻接表
            int from = edges[i][0];
            int to = edges[i][1];
            graph[from].push_back(to);
            graph[to].push_back(from);
        }
        int dis0 = 0;//计算第一个点到其他点距离之和
        //第一个深度优先,计算以0节点为源点,各个节点包含的子节点个数
        function<int(int, int, int)> dfs1 = [&](int start, int pre, int depth) -> int{//这里记录上一个节点,避免循环访问
            dis0 += depth;//计算距离和
            for(int next:graph[start]){
                if(next==pre) continue;
                points[start] +=  dfs1(next,start,depth+1);
            }
            return points[start];
        };
        dfs1(0,-1,0);

        //第二个深度优先,进行换源的同时,计算各节点到源点距离和
        function<void(int,int,int)> dfs2 = [&](int start,int pre,int dist)->void{
            res[start] = dist;
            for(int next:graph[start]){
                if(next==pre) continue;
                //把源点换为next,
                dfs2(next,start, dist + (n-points[next]) - points[next]);
            }
        };
        dfs2(0,-1,dis0);
        return res;
    }
};