2812-找出最安全路径

Raphael Liu Lv10

给你一个下标从 0 开始、大小为 n x n 的二维矩阵 grid ,其中 (r, c) 表示:

  • 如果 grid[r][c] = 1 ,则表示一个存在小偷的单元格
  • 如果 grid[r][c] = 0 ,则表示一个空单元格

你最开始位于单元格 (0, 0) 。在一步移动中,你可以移动到矩阵中的任一相邻单元格,包括存在小偷的单元格。

矩阵中路径的 安全系数 定义为:从路径中任一单元格到矩阵中任一小偷所在单元格的 最小 曼哈顿距离。

返回所有通向单元格 __(n - 1, n - 1) 的路径中的 最大安全系数

单元格 (r, c) 的某个 相邻 单元格,是指在矩阵中存在的 (r, c + 1)(r, c - 1)(r + 1, c)
(r - 1, c) 之一。

两个单元格 (a, b)(x, y) 之间的 曼哈顿距离 等于 | a - x | + | b - y | ,其中 |val|
表示 val 的绝对值。

示例 1:

**输入:** grid = [[1,0,0],[0,0,0],[0,0,1]]
**输出:** 0
**解释:** 从 (0, 0) 到 (n - 1, n - 1) 的每条路径都经过存在小偷的单元格 (0, 0) 和 (n - 1, n - 1) 。

示例 2:

**输入:** grid = [[0,0,1],[0,0,0],[0,0,0]]
**输出:** 2
**解释:**
上图所示路径的安全系数为 2:
- 该路径上距离小偷所在单元格(0,2)最近的单元格是(0,0)。它们之间的曼哈顿距离为 | 0 - 0 | + | 0 - 2 | = 2 。
可以证明,不存在安全系数更高的其他路径。

示例 3:

**输入:** grid = [[0,0,0,1],[0,0,0,0],[0,0,0,0],[1,0,0,0]]
**输出:** 2
**解释:**
上图所示路径的安全系数为 2:
- 该路径上距离小偷所在单元格(0,3)最近的单元格是(1,2)。它们之间的曼哈顿距离为 | 0 - 1 | + | 3 - 2 | = 2 。
- 该路径上距离小偷所在单元格(3,0)最近的单元格是(3,2)。它们之间的曼哈顿距离为 | 3 - 3 | + | 0 - 2 | = 2 。
可以证明,不存在安全系数更高的其他路径。

提示:

  • 1 <= grid.length == n <= 400
  • grid[i].length == n
  • grid[i][j]01
  • grid 至少存在一个小偷

视频讲解 第三题。

建议结合视频中画的图来理解。

  1. 从所有 1 出发,写一个多源 BFS,计算出每个格子 (i,j) 到最近的 1 的曼哈顿距离 dis}[i][j]。注意题目保证至少有一个 1。
  2. 答案不会超过 dis}[i][j] 的最大值,我们可以倒序枚举答案。
  3. 如何判断我们能从左上角 (0,0) 走到右下角 (n-1,n-1) 呢?并查集!
  4. 假设答案为 d,我们可以把所有 dis}[i][j]=d 的格子与其四周 \ge d 的格子用并查集连起来,在答案为 d 的情况下,这些格子之间是可以互相到达的。
  5. 用并查集判断 (0,0) 和 (n-1,n-1) 是否连通,只要连通就立刻返回 d 作为答案。
[sol-Python3]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution:
def maximumSafenessFactor(self, grid: List[List[int]]) -> int:
n = len(grid)
q = []
dis = [[-1] * n for _ in range(n)]
for i, row in enumerate(grid):
for j, x in enumerate(row):
if x:
q.append((i, j))
dis[i][j] = 0

groups = [q]
while q: # 多源 BFS
tmp = q
q = []
for i, j in tmp:
for x, y in (i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1):
if 0 <= x < n and 0 <= y < n and dis[x][y] < 0:
q.append((x, y))
dis[x][y] = len(groups)
groups.append(q) # 相同 dis 分组记录

# 并查集模板
fa = list(range(n * n))
def find(x: int) -> int:
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]

for d in range(len(groups) - 2, 0, -1):
for i, j in groups[d]:
for x, y in (i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1):
if 0 <= x < n and 0 <= y < n and dis[x][y] >= dis[i][j]:
fa[find(x * n + y)] = find(i * n + j)
if find(0) == find(n * n - 1): # 写这里判断更快些
return d
return 0
[sol-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Solution {
private final static int[][] DIRS = { {-1, 0}, {1, 0}, {0, -1}, {0, 1} };

public int maximumSafenessFactor(List<List<Integer>> grid) {
int n = grid.size();
var q = new ArrayList<int[]>();
var dis = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (grid.get(i).get(j) > 0) {
q.add(new int[]{i, j});
} else {
dis[i][j] = -1;
}
}
}

var groups = new ArrayList<List<int[]>>();
groups.add(q);
while (!q.isEmpty()) { // 多源 BFS
var tmp = q;
q = new ArrayList<>();
for (var p : tmp) {
for (var d : DIRS) {
int x = p[0] + d[0], y = p[1] + d[1];
if (0 <= x && x < n && 0 <= y && y < n && dis[x][y] < 0) {
q.add(new int[]{x, y});
dis[x][y] = groups.size();
}
}
}
groups.add(q); // 相同 dis 分组记录
}

// 并查集
fa = new int[n * n];
for (int i = 0; i < n * n; i++)
fa[i] = i;

for (int ans = groups.size() - 2; ans > 0; ans--) {
var g = groups.get(ans);
for (var p : groups.get(ans)) {
int i = p[0], j = p[1];
for (var d : DIRS) {
int x = i + d[0], y = j + d[1];
if (0 <= x && x < n && 0 <= y && y < n && dis[x][y] >= dis[i][j])
fa[find(x * n + y)] = find(i * n + j);
}
}
if (find(0) == find(n * n - 1)) // 写这里判断更快些
return ans;
}
return 0;
}

// 并查集模板
private int[] fa;

private int find(int x) {
if (fa[x] != x) fa[x] = find(fa[x]);
return fa[x];
}
}
[sol-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Solution {
static constexpr int dirs[4][2] = { {-1, 0}, {1, 0}, {0, -1}, {0, 1} };
public:
int maximumSafenessFactor(vector<vector<int>> &grid) {
int n = grid.size();
vector<pair<int, int>> q;
vector<vector<int>> dis(n, vector<int>(n, -1));
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (grid[i][j]) {
q.emplace_back(i, j);
dis[i][j] = 0;
}
}
}

vector<vector<pair<int, int>>> groups = {q};
while (!q.empty()) { // 多源 BFS
vector<pair<int, int>> nq;
for (auto &[i, j]: q) {
for (auto &d: dirs) {
int x = i + d[0], y = j + d[1];
if (0 <= x && x < n && 0 <= y && y < n && dis[x][y] < 0) {
nq.emplace_back(x, y);
dis[x][y] = groups.size();
}
}
}
groups.push_back(nq); // 相同 dis 分组记录
q = move(nq);
}

// 并查集模板
vector<int> fa(n * n);
iota(fa.begin(), fa.end(), 0);
function<int(int)> find = [&](int x) -> int { return fa[x] == x ? x : fa[x] = find(fa[x]); };

for (int ans = (int) groups.size() - 2; ans > 0; ans--) {
for (auto &[i, j]: groups[ans]) {
for (auto &d: dirs) {
int x = i + d[0], y = j + d[1];
if (0 <= x && x < n && 0 <= y && y < n && dis[x][y] >= dis[i][j])
fa[find(x * n + y)] = find(i * n + j);
}
}
if (find(0) == find(n * n - 1)) // 写这里判断更快些
return ans;
}
return 0;
}
};
[sol-Go]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
func maximumSafenessFactor(grid [][]int) int {
n := len(grid)
type pair struct{ x, y int }
q := []pair{}
dis := make([][]int, n)
for i, row := range grid {
dis[i] = make([]int, n)
for j, x := range row {
if x > 0 {
q = append(q, pair{i, j})
} else {
dis[i][j] = -1
}
}
}

dir4 := []pair{ {-1, 0}, {1, 0}, {0, -1}, {0, 1} }
groups := [][]pair{q}
for len(q) > 0 { // 多源 BFS
tmp := q
q = nil
for _, p := range tmp {
for _, d := range dir4 {
x, y := p.x+d.x, p.y+d.y
if 0 <= x && x < n && 0 <= y && y < n && dis[x][y] < 0 {
q = append(q, pair{x, y})
dis[x][y] = len(groups)
}
}
}
groups = append(groups, q) // 相同 dis 分组记录
}

// 并查集模板
fa := make([]int, n*n)
for i := range fa {
fa[i] = i
}
var find func(int) int
find = func(x int) int {
if fa[x] != x {
fa[x] = find(fa[x])
}
return fa[x]
}

for ans := len(groups) - 2; ans > 0; ans-- {
for _, p := range groups[ans] {
i, j := p.x, p.y
for _, d := range dir4 {
x, y := p.x+d.x, p.y+d.y
if 0 <= x && x < n && 0 <= y && y < n && dis[x][y] >= dis[i][j] {
fa[find(x*n+y)] = find(i*n + j)
}
}
}
if find(0) == find(n*n-1) { // 写这里判断更快些
return ans
}
}
return 0
}

复杂度分析

  • 时间复杂度:\mathcal{O}(n^2\log n) 或者 \mathcal{O}(n^2),其中 n 为 grid 的长度。时间复杂度取决于并查集的实现,加上按秩合并的话,均摊地说并查集的操作可以视作是 O(1) 的。
  • 空间复杂度:\mathcal{O}(n^2)。
 Comments
On this page
2812-找出最安全路径