是道挺明显的换根dp,但是我还是不会orz
首先暴力的思路是容易想的,比如以0为根节点,然后做一遍dfs,统计出所有节点的深度,即为到根节点的距离,累加起来就是0节点对应的答案。对于所有的点都进行这样的操作就可以得到全部答案,但是这样的时间复杂度是\(O(n^2)\),过不了。
所以肯定要想办法做换根dp,从而只需要两次dfs就可以得到所有节点的答案。但是换根dp需要推导出变换根节点后信息的变换公式,这个需要找规律。
参考灵神这张图,上方为0作为根节点得到的距离,下方为2作为根节点得到的距离。可以观察到规律,属于节点2的子树中的所有节点的距离值都比起原先以0作为根节点的距离值减少了1,而不属于的那些节点增加了1。
假设所有的节点数量为n,属于节点2的子树的节点数量为\(size[2]\),那么距离值就增加了\(n - size[2]\),同时距离值也减少了\(size[2]\),两者合并,距离值就是加上\(n - 2\times size[2]\)。
只要得到了这个公式,那么换根dp就很简单了,先取0号节点做一次dp,然后第二次dp,对每个结点都用上面的公式直接算出答案即可。(记得先预处理出每个节点的子树大小)
class Solution {
public:
vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges)
{
vector<vector<int>> g(n);
for (auto &edge : edges)
{
int u = edge[0], v = edge[1];
g[u].push_back(v);
g[v].push_back(u);
}
vector<int> sz(n, 1);
function<int(int, int)> get_size = [&] (int u, int fa) {
for (int i = 0; i < g[u].size(); ++i)
{
int v = g[u][i];
if (v != fa)
sz[u] += get_size(v, u);
}
return sz[u];
};
get_size(0, -1);
vector<int> res(n);
function<void(int, int, int)> dp = [&] (int u, int fa, int depth)
{
res[0] += depth;
for (int i = 0; i < g[u].size(); ++i)
{
int v = g[u][i];
if (v != fa) dp(v, u, depth + 1);
}
};
function<void(int, int)> reroot = [&] (int u, int fa)
{
for (int i = 0; i < g[u].size(); ++i)
{
int v = g[u][i];
if (v != fa)
{
res[v] = res[u] + n - 2 * sz[v];
reroot(v, u);
}
}
};
dp(0, -1, 0);
reroot(0, -1);
return res;
}
};