classSolution: defminimumTotalPrice(self, n: int, edges: List[List[int]], price: List[int], trips: List[List[int]]) -> int: g = [[] for _ inrange(n)] for x, y in edges: g[x].append(y) g[y].append(x) # 建树
cnt = [0] * n for start, end in trips: defdfs(x: int, fa: int) -> bool: if x == end: # 到达终点(注意树只有唯一的一条简单路径) cnt[x] += 1# 统计从 start 到 end 的路径上的点经过了多少次 returnTrue# 找到终点 for y in g[x]: if y != fa and dfs(y, x): cnt[x] += 1# 统计从 start 到 end 的路径上的点经过了多少次 returnTrue# 找到终点 returnFalse# 未找到终点 dfs(start, -1)
# 类似 337. 打家劫舍 III https://leetcode.cn/problems/house-robber-iii/ defdfs(x: int, fa: int) -> (int, int): not_halve = price[x] * cnt[x] # x 不变 halve = not_halve // 2# x 减半 for y in g[x]: if y != fa: nh, h = dfs(y, x) # 计算 y 不变/减半的最小价值总和 not_halve += min(nh, h) # x 不变,那么 y 可以不变,可以减半,取这两种情况的最小值 halve += nh # x 减半,那么 y 只能不变 return not_halve, halve returnmin(dfs(0, -1))
funcminimumTotalPrice(n int, edges [][]int, price []int, trips [][]int)int { g := make([][]int, n) for _, e := range edges { x, y := e[0], e[1] g[x] = append(g[x], y) g[y] = append(g[y], x) // 建树 }
cnt := make([]int, n) for _, t := range trips { end := t[1] var dfs func(int, int)bool dfs = func(x, fa int)bool { if x == end { // 到达终点(注意树只有唯一的一条简单路径) cnt[x]++ // 统计从 start 到 end 的路径上的点经过了多少次 returntrue// 找到终点 } for _, y := range g[x] { if y != fa && dfs(y, x) { cnt[x]++ // 统计从 start 到 end 的路径上的点经过了多少次 returntrue } } returnfalse// 未找到终点 } dfs(t[0], -1) }
// 类似 337. 打家劫舍 III https://leetcode.cn/problems/house-robber-iii/ var dfs func(int, int) (int, int) dfs = func(x, fa int) (int, int) { notHalve := price[x] * cnt[x] // x 不变 halve := notHalve / 2// x 减半 for _, y := range g[x] { if y != fa { nh, h := dfs(y, x) // 计算 y 不变/减半的最小价值总和 notHalve += min(nh, h) // x 不变,那么 y 可以不变,可以减半,取这两种情况的最小值 halve += nh // x 减半,那么 y 只能不变 } } return notHalve, halve } nh, h := dfs(0, -1) return min(nh, h) }
funcmin(a, b int)int { if a > b { return b }; return a }
复杂度分析
时间复杂度:O(nm),其中 m 为 trips 的长度。
空间复杂度:O(n)。
方法二:Tarjan 离线 LCA + 树上差分
核心思路:利用树上差分打标记,再通过一次 DFS 算出 cnt 值。
从 x=\textit{start 到 y=\textit{end 的路径可以视作从 x 向上到某个点「拐弯」,再向下到达 y。(拐弯的点也可能就是 x 或 y)
classSolution: defminimumTotalPrice(self, n: int, edges: List[List[int]], price: List[int], trips: List[List[int]]) -> int: g = [[] for _ inrange(n)] for x, y in edges: g[x].append(y) g[y].append(x) # 建树
qs = [[] for _ inrange(n)] for s, e in trips: qs[s].append(e) # 路径端点分组 if s != e: qs[e].append(s)
# 并查集模板 pa = list(range(n)) deffind(x: int) -> int: if x != pa[x]: pa[x] = find(pa[x]) return pa[x]
diff = [0] * n father = [0] * n color = [0] * n deftarjan(x: int, fa: int) -> None: father[x] = fa color[x] = 1# 递归中 for y in g[x]: if color[y] == 0: # 未递归 tarjan(y, x) pa[y] = x # 相当于把 y 的子树节点全部 merge 到 x for y in qs[x]: # color[y] == 2 意味着 y 所在子树已经遍历完 # 也就意味着 y 已经 merge 到它和 x 的 lca 上了 if y == x or color[y] == 2: # 从 y 向上到达 lca 然后拐弯向下到达 x diff[x] += 1 diff[y] += 1 lca = find(y) diff[lca] -= 1 if father[lca] >= 0: diff[father[lca]] -= 1 color[x] = 2# 递归结束 tarjan(0, -1)
defdfs(x: int, fa: int) -> (int, int, int): not_halve, halve, cnt = 0, 0, diff[x] for y in g[x]: if y != fa: nh, h, c = dfs(y, x) # 计算 y 不变/减半的最小价值总和 not_halve += min(nh, h) # x 不变,那么 y 可以不变,可以减半,取这两种情况的最小值 halve += nh # x 减半,那么 y 只能不变 cnt += c # 自底向上累加差分值 not_halve += price[x] * cnt # x 不变 halve += price[x] * cnt // 2# x 减半 return not_halve, halve, cnt returnmin(dfs(0, -1)[:2])
funcminimumTotalPrice(n int, edges [][]int, price []int, trips [][]int)int { g := make([][]int, n) for _, e := range edges { x, y := e[0], e[1] g[x] = append(g[x], y) g[y] = append(g[y], x) // 建树 }
qs := make([][]int, n) for _, t := range trips { x, y := t[0], t[1] qs[x] = append(qs[x], y) // 路径端点分组 if x != y { qs[y] = append(qs[y], x) } }
// 并查集模板 pa := make([]int, n) for i := range pa { pa[i] = i } var find func(int)int find = func(x int)int { if pa[x] != x { pa[x] = find(pa[x]) } return pa[x] }
diff := make([]int, n) father := make([]int, n) color := make([]int8, n) var tarjan func(int, int) tarjan = func(x, fa int) { father[x] = fa color[x] = 1// 递归中 for _, y := range g[x] { if color[y] == 0 { // 未递归 tarjan(y, x) pa[y] = x // 相当于把 y 的子树节点全部 merge 到 x } } for _, y := range qs[x] { // color[y] == 2 意味着 y 所在子树已经遍历完 // 也就意味着 y 已经 merge 到它和 x 的 lca 上了 if y == x || color[y] == 2 { // 从 y 向上到达 lca 然后拐弯向下到达 x diff[x]++ diff[y]++ lca := find(y) diff[lca]-- if f := father[lca]; f >= 0 { diff[f]-- } } } color[x] = 2// 递归结束 } tarjan(0, -1)
var dfs func(int, int) (int, int, int) dfs = func(x, fa int) (notHalve, halve, cnt int) { cnt = diff[x] for _, y := range g[x] { if y != fa { nh, h, c := dfs(y, x) // 计算 y 不变/减半的最小价值总和 notHalve += min(nh, h) // x 不变,那么 y 可以不变,可以减半,取这两种情况的最小值 halve += nh // x 减半,那么 y 只能不变 cnt += c // 自底向上累加差分值 } } notHalve += price[x] * cnt // x 不变 halve += price[x] * cnt / 2// x 减半 return } nh, h, _ := dfs(0, -1) return min(nh, h) }
funcmin(a, b int)int { if a > b { return b }; return a }
复杂度分析
时间复杂度:O(n+m\alpha),其中 m 为 trips 的长度,\alpha 为并查集的常数,可视作 O(1)。