2023.7.16 树中距离之和

发布时间 2023-07-16 15:42:20作者: 烤肉kr

image

是道挺明显的换根dp,但是我还是不会orz

首先暴力的思路是容易想的,比如以0为根节点,然后做一遍dfs,统计出所有节点的深度,即为到根节点的距离,累加起来就是0节点对应的答案。对于所有的点都进行这样的操作就可以得到全部答案,但是这样的时间复杂度是\(O(n^2)\),过不了。

所以肯定要想办法做换根dp,从而只需要两次dfs就可以得到所有节点的答案。但是换根dp需要推导出变换根节点后信息的变换公式,这个需要找规律。
image

参考灵神这张图,上方为0作为根节点得到的距离,下方为2作为根节点得到的距离。可以观察到规律,属于节点2的子树中的所有节点的距离值都比起原先以0作为根节点的距离值减少了1,而不属于的那些节点增加了1。
假设所有的节点数量为n,属于节点2的子树的节点数量为\(size[2]\),那么距离值就增加了\(n - size[2]\),同时距离值也减少了\(size[2]\),两者合并,距离值就是加上\(n - 2\times size[2]\)

只要得到了这个公式,那么换根dp就很简单了,先取0号节点做一次dp,然后第二次dp,对每个结点都用上面的公式直接算出答案即可。(记得先预处理出每个节点的子树大小)

class Solution {
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) 
    {
        vector<vector<int>> g(n);
        for (auto &edge : edges)
        {
            int u = edge[0], v = edge[1];
            g[u].push_back(v);
            g[v].push_back(u);
        }

        vector<int> sz(n, 1);
        function<int(int, int)> get_size = [&] (int u, int fa) {
            for (int i = 0; i < g[u].size(); ++i)
            {
                int v = g[u][i];
                if (v != fa)
                    sz[u] += get_size(v, u);
            }
            return sz[u];
        };
        get_size(0, -1);

        vector<int> res(n);
        function<void(int, int, int)> dp = [&] (int u, int fa, int depth)
        {
            res[0] += depth;
            for (int i = 0; i < g[u].size(); ++i)
            {
                int v = g[u][i];
                if (v != fa) dp(v, u, depth + 1);
            }
        };
        function<void(int, int)> reroot = [&] (int u, int fa)
        {
            for (int i = 0; i < g[u].size(); ++i)
            {
                int v = g[u][i];
                if (v != fa)
                {
                    res[v] = res[u] + n - 2 * sz[v];
                    reroot(v, u);
                }
            }
        };
        dp(0, -1, 0);
        reroot(0, -1);

        return res;
    }
};