给定一个数组 points
,其中 points[i] = [xi, yi]
表示 X-Y 平面上的一个点,并且是一个整数 k
,返回离原点 (0,0)
最近的 k
个点。
这里,平面上两点之间的距离是 欧几里德距离 ( √(x1 - x2)2 + (y1 - y2)2
)。
你可以按 任何顺序 返回答案。除了点坐标的顺序之外,答案 确保 是 唯一 的。
示例 1:
**输入:** points = [[1,3],[-2,2]], k = 1
**输出:** [[-2,2]]
**解释:**
(1, 3) 和原点之间的距离为 sqrt(10),
(-2, 2) 和原点之间的距离为 sqrt(8),
由于 sqrt(8) < sqrt(10),(-2, 2) 离原点更近。
我们只需要距离原点最近的 K = 1 个点,所以答案就是 [[-2,2]]。
示例 2:
**输入:** points = [[3,3],[5,-1],[-2,4]], k = 2
**输出:** [[3,3],[-2,4]]
(答案 [[-2,4],[3,3]] 也会被接受。)
提示:
1 <= k <= points.length <= 104
-104 < xi, yi < 104
前言 当我们计算出每个点到原点的欧几里得距离的平方后,本题和「剑指 Offer 40. 最小的k个数 」是完全一样的题。
为什么是欧几里得距离的「平方」?这是因为欧几里得距离并不一定是个整数,在进行计算和比较时可能会引进误差;但它的平方一定是个整数,这样我们就无需考虑误差了。
方法一:排序 思路和算法
将每个点到原点的欧几里得距离的平方从小到大排序后,取出前 k 个即可。
代码
[sol1-C++] 1 2 3 4 5 6 7 8 9 class Solution {public : vector<vector<int >> kClosest (vector<vector<int >>& points, int k) { sort (points.begin (), points.end (), [](const vector<int >& u, const vector<int >& v) { return u[0 ] * u[0 ] + u[1 ] * u[1 ] < v[0 ] * v[0 ] + v[1 ] * v[1 ]; }); return {points.begin (), points.begin () + k}; } };
[sol1-Java] 1 2 3 4 5 6 7 8 9 10 class Solution { public int [][] kClosest(int [][] points, int k) { Arrays.sort(points, new Comparator <int []>() { public int compare (int [] point1, int [] point2) { return (point1[0 ] * point1[0 ] + point1[1 ] * point1[1 ]) - (point2[0 ] * point2[0 ] + point2[1 ] * point2[1 ]); } }); return Arrays.copyOfRange(points, 0 , k); } }
[sol1-Python3] 1 2 3 4 class Solution : def kClosest (self, points: List [List [int ]], k: int ) -> List [List [int ]]: points.sort(key=lambda x: (x[0 ] ** 2 + x[1 ] ** 2 )) return points[:k]
[sol1-Golang] 1 2 3 4 5 6 7 func kClosest (points [][]int , k int ) [][]int { sort.Slice(points, func (i, j int ) bool { p, q := points[i], points[j] return p[0 ]*p[0 ]+p[1 ]*p[1 ] < q[0 ]*q[0 ]+q[1 ]*q[1 ] }) return points[:k] }
[sol1-C] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 int cmp (const void * _a, const void * _b) { int *a = *(int **)_a, *b = *(int **)_b; return a[0 ] * a[0 ] + a[1 ] * a[1 ] - b[0 ] * b[0 ] - b[1 ] * b[1 ]; } int ** kClosest (int ** points, int pointsSize, int * pointsColSize, int k, int * returnSize, int ** returnColumnSizes) { qsort(points, pointsSize, sizeof (int *), cmp); *returnSize = k; *returnColumnSizes = malloc (sizeof (int ) * k); int ** ret = malloc (sizeof (int *) * k); for (int i = 0 ; i < k; i++) { (*returnColumnSizes)[i] = 2 ; ret[i] = malloc (sizeof (int ) * 2 ); ret[i][0 ] = points[i][0 ], ret[i][1 ] = points[i][1 ]; } return ret; }
复杂度分析
方法二:堆 思路和算法
我们可以使用一个大根堆实时维护前 k 个最小的距离平方。
首先我们将前 k 个点的编号(为了方便最后直接得到答案)以及对应的距离平方放入大根堆中,随后从第 k+1 个点开始遍历:如果当前点的距离平方比堆顶的点的距离平方要小,就把堆顶的点弹出,再插入当前的点。当遍历完成后,所有在大根堆中的点就是前 k 个距离最小的点。
不同的语言提供的堆的默认情况不一定相同。在 C++ 语言中,堆(即优先队列)为大根堆,但在 Python 语言中,堆为小根堆,因此我们需要在小根堆中存储(以及比较)距离平方的相反数。
代码
[sol2-C++] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class Solution {public : vector<vector<int >> kClosest (vector<vector<int >>& points, int k) { priority_queue<pair<int , int >> q; for (int i = 0 ; i < k; ++i) { q.emplace (points[i][0 ] * points[i][0 ] + points[i][1 ] * points[i][1 ], i); } int n = points.size (); for (int i = k; i < n; ++i) { int dist = points[i][0 ] * points[i][0 ] + points[i][1 ] * points[i][1 ]; if (dist < q.top ().first) { q.pop (); q.emplace (dist, i); } } vector<vector<int >> ans; while (!q.empty ()) { ans.push_back (points[q.top ().second]); q.pop (); } return ans; } };
[sol2-Java] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 class Solution { public int [][] kClosest(int [][] points, int k) { PriorityQueue<int []> pq = new PriorityQueue <int []>(new Comparator <int []>() { public int compare (int [] array1, int [] array2) { return array2[0 ] - array1[0 ]; } }); for (int i = 0 ; i < k; ++i) { pq.offer(new int []{points[i][0 ] * points[i][0 ] + points[i][1 ] * points[i][1 ], i}); } int n = points.length; for (int i = k; i < n; ++i) { int dist = points[i][0 ] * points[i][0 ] + points[i][1 ] * points[i][1 ]; if (dist < pq.peek()[0 ]) { pq.poll(); pq.offer(new int []{dist, i}); } } int [][] ans = new int [k][2 ]; for (int i = 0 ; i < k; ++i) { ans[i] = points[pq.poll()[1 ]]; } return ans; } }
[sol2-Python3] 1 2 3 4 5 6 7 8 9 10 11 12 13 class Solution : def kClosest (self, points: List [List [int ]], k: int ) -> List [List [int ]]: q = [(-x ** 2 - y ** 2 , i) for i, (x, y) in enumerate (points[:k])] heapq.heapify(q) n = len (points) for i in range (k, n): x, y = points[i] dist = -x ** 2 - y ** 2 heapq.heappushpop(q, (dist, i)) ans = [points[identity] for (_, identity) in q] return ans
[sol2-Golang] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 type pair struct { dist int point []int } type hp []pairfunc (h hp) Len() int { return len (h) }func (h hp) Less(i, j int ) bool { return h[i].dist > h[j].dist }func (h hp) Swap(i, j int ) { h[i], h[j] = h[j], h[i] }func (h *hp) Push(v interface {}) { *h = append (*h, v.(pair)) }func (h *hp) Pop() interface {} { a := *h; v := a[len (a)-1 ]; *h = a[:len (a)-1 ]; return v }func kClosest (points [][]int , k int ) (ans [][]int ) { h := make (hp, k) for i, p := range points[:k] { h[i] = pair{p[0 ]*p[0 ] + p[1 ]*p[1 ], p} } heap.Init(&h) for _, p := range points[k:] { if dist := p[0 ]*p[0 ] + p[1 ]*p[1 ]; dist < h[0 ].dist { h[0 ] = pair{dist, p} heap.Fix(&h, 0 ) } } for _, p := range h { ans = append (ans, p.point) } return }
复杂度分析
方法三:快速选择(快速排序的思想) 思路和算法
我们也可以借鉴快速排序的思想。
快速排序中的划分操作每次执行完后,都能将数组分成两个部分,其中小于等于分界值 pivot 的元素都会被放到左侧部分,而大于 pivot 的元素都都会被放到右侧部分。与快速排序不同的是,在本题中我们可以根据 k 与 pivot 下标的位置关系,只处理划分结果的某一部分(而不是像快速排序一样需要处理两个部分)。
我们定义函数 random_select(left, right, k)
表示划分数组 points 的 [\textit{left},\textit{right}] 区间,并且需要找到其中第 k 个距离最小的点。在一次划分操作完成后,设 pivot 的下标为 i,即区间 [\textit{left}, i-1] 中的点的距离都小于等于 pivot,而区间 [i+1,\textit{right}] 的点的距离都大于 pivot。此时会有三种情况:
如果 k = i-\textit{left}+1,那么说明 pivot 就是第 k 个距离最小的点,我们可以结束整个过程;
如果 k < i-\textit{left}+1,那么说明第 k 个距离最小的点在 pivot 左侧,因此递归调用 random_select(left, i - 1, k)
;
如果 k > i-\textit{left}+1,那么说明第 k 个距离最小的点在 pivot 右侧,因此递归调用 random_select(i + 1, right, k - (i - left + 1))
。
在整个过程结束之后,第 k 个距离最小的点恰好就在数组 points 中的第 k 个位置,并且其左侧的所有点的距离都小于它。此时,我们就找到了前 k 个距离最小的点。
代码
[sol3-C++] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 class Solution {private : mt19937 gen{random_device{}()}; public : void random_select (vector<vector<int >>& points, int left, int right, int k) { int pivot_id = uniform_int_distribution<int >{left, right}(gen); int pivot = points[pivot_id][0 ] * points[pivot_id][0 ] + points[pivot_id][1 ] * points[pivot_id][1 ]; swap (points[right], points[pivot_id]); int i = left - 1 ; for (int j = left; j < right; ++j) { int dist = points[j][0 ] * points[j][0 ] + points[j][1 ] * points[j][1 ]; if (dist <= pivot) { ++i; swap (points[i], points[j]); } } ++i; swap (points[i], points[right]); if (k < i - left + 1 ) { random_select (points, left, i - 1 , k); } else if (k > i - left + 1 ) { random_select (points, i + 1 , right, k - (i - left + 1 )); } } vector<vector<int >> kClosest (vector<vector<int >>& points, int k) { int n = points.size (); random_select (points, 0 , n - 1 , k); return {points.begin (), points.begin () + k}; } };
[sol3-C++api] 1 2 3 4 5 6 7 8 9 class Solution {public : vector<vector<int >> kClosest (vector<vector<int >>& points, int k) { nth_element (points.begin (), points.begin () + k - 1 , points.end (), [](const vector<int >& u, const vector<int >& v) { return u[0 ] * u[0 ] + u[1 ] * u[1 ] < v[0 ] * v[0 ] + v[1 ] * v[1 ]; }); return {points.begin (), points.begin () + k}; } };
[sol3-Java] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 class Solution { Random rand = new Random (); public int [][] kClosest(int [][] points, int k) { int n = points.length; random_select(points, 0 , n - 1 , k); return Arrays.copyOfRange(points, 0 , k); } public void random_select (int [][] points, int left, int right, int k) { int pivotId = left + rand.nextInt(right - left + 1 ); int pivot = points[pivotId][0 ] * points[pivotId][0 ] + points[pivotId][1 ] * points[pivotId][1 ]; swap(points, right, pivotId); int i = left - 1 ; for (int j = left; j < right; ++j) { int dist = points[j][0 ] * points[j][0 ] + points[j][1 ] * points[j][1 ]; if (dist <= pivot) { ++i; swap(points, i, j); } } ++i; swap(points, i, right); if (k < i - left + 1 ) { random_select(points, left, i - 1 , k); } else if (k > i - left + 1 ) { random_select(points, i + 1 , right, k - (i - left + 1 )); } } public void swap (int [][] points, int index1, int index2) { int [] temp = points[index1]; points[index1] = points[index2]; points[index2] = temp; } }
[sol3-Python] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 class Solution : def kClosest (self, points: List [List [int ]], k: int ) -> List [List [int ]]: def random_select (left: int , right: int , k: int ): pivot_id = random.randint(left, right) pivot = points[pivot_id][0 ] ** 2 + points[pivot_id][1 ] ** 2 points[right], points[pivot_id] = points[pivot_id], points[right] i = left - 1 for j in range (left, right): if points[j][0 ] ** 2 + points[j][1 ] ** 2 <= pivot: i += 1 points[i], points[j] = points[j], points[i] i += 1 points[i], points[right] = points[right], points[i] if k < i - left + 1 : random_select(left, i - 1 , k) elif k > i - left + 1 : random_select(i + 1 , right, k - (i - left + 1 )) n = len (points) random_select(0 , n - 1 , k) return points[:k]
[sol3-Golang] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 func less (p, q []int ) bool { return p[0 ]*p[0 ]+p[1 ]*p[1 ] < q[0 ]*q[0 ]+q[1 ]*q[1 ] } func kClosest (points [][]int , k int ) (ans [][]int ) { rand.Shuffle(len (points), func (i, j int ) { points[i], points[j] = points[j], points[i] }) var quickSelect func (left, right int ) quickSelect = func (left, right int ) { if left == right { return } pivot := points[right] lessCount := left for i := left; i < right; i++ { if less(points[i], pivot) { points[i], points[lessCount] = points[lessCount], points[i] lessCount++ } } points[right], points[lessCount] = points[lessCount], points[right] if lessCount+1 == k { return } else if lessCount+1 < k { quickSelect(lessCount+1 , right) } else { quickSelect(left, lessCount-1 ) } } quickSelect(0 , len (points)-1 ) return points[:k] }
[sol3-C] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 void swap (int ** a, int ** b) { int * t = *a; *a = *b, *b = t; } void random_select (int ** points, int left, int right, int k) { int pivot_id = rand() % (right - left + 1 ) + left; int pivot = points[pivot_id][0 ] * points[pivot_id][0 ] + points[pivot_id][1 ] * points[pivot_id][1 ]; swap(points[right], points[pivot_id]); int i = left - 1 ; for (int j = left; j < right; ++j) { int dist = points[j][0 ] * points[j][0 ] + points[j][1 ] * points[j][1 ]; if (dist <= pivot) { ++i; swap(&points[i], &points[j]); } } ++i; swap(&points[i], &points[right]); if (k < i - left + 1 ) { random_select(points, left, i - 1 , k); } else if (k > i - left + 1 ) { random_select(points, i + 1 , right, k - (i - left + 1 )); } } int ** kClosest (int ** points, int pointsSize, int * pointsColSize, int k, int * returnSize, int ** returnColumnSizes) { srand(time(0 )); random_select(points, 0 , pointsSize - 1 , k); *returnSize = k; *returnColumnSizes = malloc (sizeof (int ) * k); int ** ret = malloc (sizeof (int *) * k); for (int i = 0 ; i < k; i++) { (*returnColumnSizes)[i] = 2 ; ret[i] = malloc (sizeof (int ) * 2 ); ret[i][0 ] = points[i][0 ], ret[i][1 ] = points[i][1 ]; } return ret; }
复杂度分析
时间复杂度:期望为 O(n),其中 n 是数组 points 的长度。由于证明过程很繁琐,所以不在这里展开讲。具体证明可以参考《算法导论》第 9 章第 2 小节。
最坏情况下,时间复杂度为 O(n^2)。具体地,每次的划分点都是最大值或最小值,一共需要划分 n-1 次,而一次划分需要线性的时间复杂度,所以最坏情况下时间复杂度为 O(n^2)。
空间复杂度:期望为 O(\log n),即为递归调用的期望深度。
最坏情况下,空间复杂度为 O(n),此时需要划分 n-1 次,对应递归的深度为 n-1 层,所以最坏情况下时间复杂度为 O(n)。
然而注意到代码中的递归都是「尾递归」,因此如果编译器支持尾递归优化,那么空间复杂度总为 O(1)。即使不支持尾递归优化,我们也可以很方便地将上面的代码改成循环迭代的写法。