0834-树中距离之和

Raphael Liu Lv10

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edgesedges[i] = [ai, bi]表示树中的节点 aibi 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

示例 1:

**输入:** n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
**输出:** [8,12,6,10,10,10]
**解释:** 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

示例 2:

**输入:** n = 1, edges = []
**输出:** [0]

示例 3:

**输入:** n = 2, edges = [[1,0]]
**输出:** [1,1]

提示:

  • 1 <= n <= 3 * 104
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • 给定的输入保证为有效的树

方法一:树形动态规划

思路与算法

首先我们来考虑一个节点的情况,即每次题目指定一棵树,以 root 为根,询问节点 root 与其他所有节点的距离之和。

很容易想到一个树形动态规划:定义 dp}[u] 表示以 u 为根的子树,它的所有子节点到它的距离之和,同时定义 sz}[u] 表示以 u 为根的子树的节点数量,不难得出如下的转移方程:

\textit{dp}[u]=\sum_{v\in \textit{son}[u]}\textit{dp}[v] + \textit{sz}[v]

其中 son}[u] 表示 u 的所有后代节点集合。转移方程表示的含义就是考虑每个后代节点 v,已知 v 的所有子节点到它的距离之和为 dp}[v],那么这些节点到 u 的距离之和还要考虑 u\rightarrow v 这条边的贡献。考虑这条边长度为 1,一共有 sz[v] 个节点到节点 u 的距离会包含这条边,因此贡献即为 1\times \textit{sz}[v]=\textit{sz}[v]。我们遍历整棵树,从叶子节点开始自底向上递推到根节点 root 即能得出最后的答案为 dp}[\textit{root}]。

回到本题中,题目要求的其实是上题的扩展,即要求我们求出每个节点为根节点的时候,它与其他所有节点的距离之和。暴力的角度我们可以考虑对每个节点都做一次如上的树形动态规划,这样时间复杂度即为 O(n^2),那么有没有更优雅的方法呢?

经过一次树形动态规划后其实我们获得了在 u 为根的树中,每个节点为根的子树的答案 dp,我们可以利用这些已有信息来优化时间复杂度。

假设 u 的某个后代节点为 v,如果要算 v 的答案,本来我们要以 v 为根来进行一次树形动态规划。但是利用已有的信息,我们可以考虑树的形态做一次改变,让 v 换到根的位置,u 变为其孩子节点,同时维护原有的 dp 信息。在这一次的转变中,我们观察到除了 u 和 v 的 dp 值,其他节点的 dp 值都不会改变,因此只要更新 dp}[u] 和 dp}[v] 的值即可。

那么我们来看 v 换到根的位置的时候怎么利用已有信息求出 dp}[u] 和 dp}[v] 的值。重新回顾第一次树形动态规划的转移方程,我们可以知道当 u 变为 v 的孩子的时候 v 不在 u 的后代集合 son}[u] 中了,因此此时 dp}[u] 需要减去 v 的贡献,即

\textit{dp}[u]=\textit{dp}[u]-(\textit{dp}[v]+\textit{sz}[v])

同时 sz}[u] 也要相应减去 sz}[v]。

而 v 的后代节点集合中多出了 u,因此 dp}[v] 的值要由 u 更新上来,即

\textit{dp}[v]=\textit{dp}[v]+(\textit{dp}[u]+\textit{sz}[u])

同时 sz}[v] 也要相应加上 sz}[u]。

至此我们完成了一次「换根」操作,在 O(1) 的时间内维护了 dp 的信息,且此时的树结构以 v 为根。那么接下来我们不断地进行换根的操作,即能在 O(n) 的时间内求出每个节点为根的答案,实现了时间复杂度的优化。

<ppt1,ppt2,ppt3,ppt4,ppt5,ppt6,ppt7>

代码

[sol1-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
public:
vector<int> ans, sz, dp;
vector<vector<int>> graph;

void dfs(int u, int f) {
sz[u] = 1;
dp[u] = 0;
for (auto& v: graph[u]) {
if (v == f) {
continue;
}
dfs(v, u);
dp[u] += dp[v] + sz[v];
sz[u] += sz[v];
}
}

void dfs2(int u, int f) {
ans[u] = dp[u];
for (auto& v: graph[u]) {
if (v == f) {
continue;
}
int pu = dp[u], pv = dp[v];
int su = sz[u], sv = sz[v];

dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];

dfs2(v, u);

dp[u] = pu, dp[v] = pv;
sz[u] = su, sz[v] = sv;
}
}

vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
ans.resize(n, 0);
sz.resize(n, 0);
dp.resize(n, 0);
graph.resize(n, {});
for (auto& edge: edges) {
int u = edge[0], v = edge[1];
graph[u].emplace_back(v);
graph[v].emplace_back(u);
}
dfs(0, -1);
dfs2(0, -1);
return ans;
}
};
[sol1-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class Solution {
int[] ans;
int[] sz;
int[] dp;
List<List<Integer>> graph;

public int[] sumOfDistancesInTree(int n, int[][] edges) {
ans = new int[n];
sz = new int[n];
dp = new int[n];
graph = new ArrayList<List<Integer>>();
for (int i = 0; i < n; ++i) {
graph.add(new ArrayList<Integer>());
}
for (int[] edge: edges) {
int u = edge[0], v = edge[1];
graph.get(u).add(v);
graph.get(v).add(u);
}
dfs(0, -1);
dfs2(0, -1);
return ans;
}

public void dfs(int u, int f) {
sz[u] = 1;
dp[u] = 0;
for (int v: graph.get(u)) {
if (v == f) {
continue;
}
dfs(v, u);
dp[u] += dp[v] + sz[v];
sz[u] += sz[v];
}
}

public void dfs2(int u, int f) {
ans[u] = dp[u];
for (int v: graph.get(u)) {
if (v == f) {
continue;
}
int pu = dp[u], pv = dp[v];
int su = sz[u], sv = sz[v];

dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];

dfs2(v, u);

dp[u] = pu;
dp[v] = pv;
sz[u] = su;
sz[v] = sv;
}
}
}
[sol1-JavaScript]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
let ans, sz, dp, graph;
const dfs = (u, f) => {
sz[u] = 1;
dp[u] = 0;
for (const v of graph[u]) {
if (v === f) {
continue;
}
dfs(v, u);
dp[u] += dp[v] + sz[v];
sz[u] += sz[v];
}
}
const dfs2 = (u, f) => {
ans[u] = dp[u];
for (const v of graph[u]) {
if (v === f) {
continue;
}
const pu = dp[u], pv = dp[v];
const su = sz[u], sv = sz[v];

dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];

dfs2(v, u);

dp[u] = pu, dp[v] = pv;
sz[u] = su, sz[v] = sv;
}
}
var sumOfDistancesInTree = function(n, edges) {
ans = new Array(n).fill(0);
sz = new Array(n).fill(0);
dp = new Array(n).fill(0);
graph = new Array(n).fill(0).map(v => []);
for (const [u, v] of edges) {
graph[u].push(v);
graph[v].push(u);
}
dfs(0, -1);
dfs2(0, -1);
return ans;
};
[sol1-Golang]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
func sumOfDistancesInTree(n int, edges [][]int) []int {
graph := make([][]int, n)
for _, e := range edges {
u, v := e[0], e[1]
graph[u] = append(graph[u], v)
graph[v] = append(graph[v], u)
}

sz := make([]int, n)
dp := make([]int, n)
var dfs func(u, f int)
dfs = func(u, f int) {
sz[u] = 1
for _, v := range graph[u] {
if v == f {
continue
}
dfs(v, u)
dp[u] += dp[v] + sz[v]
sz[u] += sz[v]
}
}
dfs(0, -1)

ans := make([]int, n)
var dfs2 func(u, f int)
dfs2 = func(u, f int) {
ans[u] = dp[u]
for _, v := range graph[u] {
if v == f {
continue
}
pu, pv := dp[u], dp[v]
su, sv := sz[u], sz[v]

dp[u] -= dp[v] + sz[v]
sz[u] -= sz[v]
dp[v] += dp[u] + sz[u]
sz[v] += sz[u]

dfs2(v, u)

dp[u], dp[v] = pu, pv
sz[u], sz[v] = su, sv
}
}
dfs2(0, -1)
return ans
}
[sol1-C]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
int *ans, *sz, *dp;
int *hd, *nx, *e;

void dfs(int u, int f) {
sz[u] = 1;
dp[u] = 0;
for (int i = hd[u]; i; i = nx[i]) {
int v = e[i];
if (v == f) {
continue;
}
dfs(v, u);
dp[u] += dp[v] + sz[v];
sz[u] += sz[v];
}
}

void dfs2(int u, int f) {
ans[u] = dp[u];
for (int i = hd[u]; i; i = nx[i]) {
int v = e[i];
if (v == f) {
continue;
}
int pu = dp[u], pv = dp[v];
int su = sz[u], sv = sz[v];

dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];

dfs2(v, u);

dp[u] = pu, dp[v] = pv;
sz[u] = su, sz[v] = sv;
}
}

int* sumOfDistancesInTree(int n, int** edges, int edgesSize, int* edgesColSize, int* returnSize) {
ans = malloc(sizeof(int) * n);
sz = malloc(sizeof(int) * n);
dp = malloc(sizeof(int) * n);
hd = malloc(sizeof(int) * n);
nx = malloc(sizeof(int) * (edgesSize * 2 + 1));
e = malloc(sizeof(int) * (edgesSize * 2 + 1));
for (int i = 0; i < n; i++) {
ans[i] = sz[i] = dp[i] = hd[i] = 0;
}
for (int i = 0, num = 0; i < edgesSize; i++) {
int u = edges[i][0], v = edges[i][1];
nx[++num] = hd[u], hd[u] = num, e[num] = v;
nx[++num] = hd[v], hd[v] = num, e[num] = u;
}
dfs(0, -1);
dfs2(0, -1);
*returnSize = n;
return ans;
}

复杂度分析

  • 时间复杂度:O(n),其中 n 是树中的节点个数。我们只需要遍历整棵树两次即可得到答案,其中每个节点被访问两次,因此时间复杂度为 O(2n)=O(n)。

  • 空间复杂度:O(n)。我们需要线性的空间存图,n 个节点的树包含 n-1 条边,数组 dp 和 sz 的长度均为 n。

 Comments
On this page
0834-树中距离之和