对于本题而言,我们可以首先求出第 0 行和第 1 行的前 k 个最小数组和,将该结果与第 2 行再求出前 k 个最小数组和,再将该结果与第 3 行再求出前 k 个最小数组和,以此类推。当使用完最后一行后,就可以得到整个矩阵的前 k 个最小数组和,也就得到了第 k 个最小数组和。
在求解的过程中,如果数组和的数量不够 k 个,就求出所有可能的数组和。
这样做的时间复杂度为 O(m \times F(k, n)),其中 F(k, n) 表示当两个数组的长度分别是 k 和 n 时,求出前 k 个最小数组和的时间复杂度。下面的题解部分只会讲解 m = 2 时的做法,读者也可以直接参考「373. 查找和最小的 K 对数字」的官方题解 。为了叙述方便,记这两行分别为 f 和 g,长度分别为 l_f 和 l_g。
方法一:小根堆
思路与算法
我们可以将两个数组 f 和 g 求解前 k 个最小数组和的问题转换成类似「归并排序」的问题:
我们构造 l_g 个序列,第 i 个序列包含了 f[0] + g[i], f[1] + g[i], \cdots, f[l_f - 1] + g[i]。由于 f 是非递减的,因此这个这个序列也是非递减的;
所有序列的并集恰好就是所有的 l_f \times l_g 个数组和。要想求出前 k 个最小数组和,我们就可以使用小根堆。初始时,我们将所有的 l_g 个序列的首项放入堆中,随后进行 k 次操作,每次操作我们从堆顶取出当前的最小值,再将它后面的那一项(如果有)放回堆中。这样一来,第 j~(j \geq 1) 次操作时我们得到的就是第 j 个最小数组和。
细节
上述做法的时间复杂度为 O(l_g + k \log l_g),与 l_f 无关。在实际的代码编写中,我们可以交换 f 和 g 使得 l_g 一定小于等于 l_f。
var kthSmallest = function(mat, k) { const m = mat.length; let prev = mat[0]; for (let i = 1; i < m; ++i) { prev = merge(prev, mat[i], k); } return prev[k - 1]; }
如果 l_f \times l_g < k,我们需要将 k 减少至 l_f \times l_g,因为二元组的数量并没有 k 个。
当二分查找完成并得到 thres 后,我们可以使用二重循环遍历数组 f 和 g,找出所有和小于等于 thres 的二元组。需要注意的是:
时间复杂度为 O(l_f \times l_g),较高;
和小于等于 thres 的二元组数量可能会大于 k,因为有若干个和恰好等于 thres 的二元组。
为了解决上面的这些问题,我们可以对二重循环遍历进行优化:当内层遍历 g 的循环已经不满足要求时,可以直接退出,因为后续 g 中的元素只会更大。并且在遍历的过程中,我们的判断条件改为「和小于 thres」而不是「和小于等于 thres」,这样二重循环最多只会添加 k 个二元组,时间复杂度减少至 O(k)。在这之后,如果答案的长度没有到 k,我们再补上对应数量的 thres 即可。
publicclassSolution { publicintKthSmallest(int[][] mat, int k) { int m = mat.Length; int[] prev = mat[0]; for (int i = 1; i < m; ++i) { prev = Merge(prev, mat[i], k); } return prev[k - 1]; }
publicint[] Merge(int[] f, int[] g, int k) { int left = f[0] + g[0], right = f[f.Length - 1] + g[g.Length - 1], thres = 0; k = Math.Min(k, f.Length * g.Length); while (left <= right) { int mid = (left + right) / 2; int rptr = g.Length - 1, cnt = 0; for (int lptr = 0; lptr < f.Length; ++lptr) { while (rptr >= 0 && f[lptr] + g[rptr] > mid) { --rptr; } cnt += rptr + 1; } if (cnt >= k) { thres = mid; right = mid - 1; } else { left = mid + 1; } }
IList<int> list = new List<int>(); int index = 0; for (int i = 0; i < f.Length; ++i) { for (int j = 0; j < g.Length; ++j) { int sum = f[i] + g[j]; if (sum < thres) { list.Add(sum); } else { break; } } } while (list.Count < k) { list.Add(thres); } int[] ans = list.ToArray(); Array.Sort(ans); return ans; } }
classSolution: defkthSmallest(self, mat: List[List[int]], k: int) -> int: defmerge(f: List[int], g: List[int], k: int) -> List[int]: left, right, thres = f[0] + g[0], f[-1] + g[-1], 0 k = min(k, len(f) * len(g)) while left <= right: mid = (left + right) // 2 rptr, cnt = len(g) - 1, 0 for lptr, x inenumerate(f): while rptr >= 0and x + g[rptr] > mid: rptr -= 1 cnt += rptr + 1 if cnt >= k: thres = mid; right = mid - 1; else: left = mid + 1;
ans = list() for i, fi inenumerate(f): for j, gj inenumerate(g): if (total := fi + gj) < thres: ans.append(total) else: break ans += [thres] * (k - len(ans)) ans.sort() return ans prev = mat[0] for i inrange(1, len(mat)): prev = merge(prev, mat[i], k) return prev[k - 1]
int *merge(constint *f, int fSize, constint *g, int gSize, int k, int *returnSize) { int left = f[0] + g[0], right = f[fSize - 1] + g[gSize - 1], thres = 0; k = MIN(k, fSize * gSize); while (left <= right) { int mid = (left + right) / 2; int rptr = gSize - 1, cnt = 0; for (int lptr = 0; lptr < fSize; ++lptr) { while (rptr >= 0 && f[lptr] + g[rptr] > mid) { --rptr; } cnt += rptr + 1; } if (cnt >= k) { thres = mid; right = mid - 1; } else { left = mid + 1; } }
int *ans = (int *)calloc(k, sizeof(int)); int pos = 0; for (int i = 0; i < fSize; ++i) { for (int j = 0; j < gSize; ++j) { int sum = f[i] + g[j]; if (sum < thres) { ans[pos++] = sum; } else { break; } } } while (pos < k) { ans[pos++] = thres; } qsort(ans, k, sizeof(int), cmp); *returnSize = k; return ans; }
intkthSmallest(int** mat, int matSize, int* matColSize, int k) { int m = matSize; int n = matColSize[0]; int *prev = mat[0]; int prevSize = n; for (int i = 1; i < m; i++) { int arrSize = 0; int *arr = merge(prev, prevSize, mat[i], n, k, &arrSize); prevSize = arrSize; prev = (int *)malloc(sizeof(int) * prevSize); memcpy(prev, arr, sizeof(int) * prevSize); free(arr); } return prev[k - 1]; }
var kthSmallest = function(mat, k) { const m = mat.length; let prev = mat[0]; for (let i = 1; i < m; ++i) { prev = merge(prev, mat[i], k); } return prev[k - 1]; }
constmerge = (f, g, k) => { let left = f[0] + g[0], right = f[f.length - 1] + g[g.length - 1], thres = 0; k = Math.min(k, f.length * g.length); while (left <= right) { const mid = Math.floor((left + right) / 2); let rptr = g.length - 1, cnt = 0; for (let lptr = 0; lptr < f.length; ++lptr) { while (rptr >= 0 && f[lptr] + g[rptr] > mid) { --rptr; } cnt += rptr + 1; } if (cnt >= k) { thres = mid; right = mid - 1; } else { left = mid + 1; } }
const list = []; let index = 0; for (let i = 0; i < f.length; ++i) { for (let j = 0; j < g.length; ++j) { let sum = f[i] + g[j]; if (sum < thres) { list.push(sum); } else { break; } } } while (list.length < k) { list.push(thres); } const ans = newArray(list.length).fill(0); for (let i = 0; i < list.length; ++i) { ans[i] = list[i]; } ans.sort((a, b) => a - b); return ans; };
复杂度分析
时间复杂度:O(m \times (k \log k + n) \times \log C)。在一次二分查找的过程中:
双指针部分需要的时间为 O(k + n);
二重循环遍历需要的时间为 O(k);
排序需要的时间为 O(k \log k)。
它们的和为 O(k \log k + n)。二分查找需要 O(\log C) 次,其中 C 是和的上界与下界之差,它的范围不会超过 5000 \cdot m。