classSolution: defnumSubarrayProductLessThanK(self, nums: List[int], k: int) -> int: if k == 0: return0 ans, n = 0, len(nums) logPrefix = [0] * (n + 1) for i, num inenumerate(nums): logPrefix[i + 1] = logPrefix[i] + log(num) logK = log(k) for j inrange(1, n + 1): l = bisect_right(logPrefix, logPrefix[j] - logK + 1e-10, 0, j) ans += j - l return ans
var numSubarrayProductLessThanK = function(nums, k) { if (k === 0) { return0; } const n = nums.length; const logPrefix = newArray(n + 1).fill(0); for (let i = 0; i < n; i++) { logPrefix[i + 1] = logPrefix[i] + Math.log(nums[i]); } const logk = Math.log(k); let ret = 0; for (let j = 0; j < n; j++) { let l = 0; let r = j + 1; let idx = j + 1; const val = logPrefix[j + 1] - logk + 1e-10; while (l <= r) { const mid = Math.floor((l + r) / 2); if (logPrefix[mid] > val) { idx = mid; r = mid - 1; } else { l = mid + 1; } } ret += j + 1 - idx; } return ret; };
[sol1-Golang]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
funcnumSubarrayProductLessThanK(nums []int, k int) (ans int) { if k == 0 { return } n := len(nums) logPrefix := make([]float64, n+1) for i, num := range nums { logPrefix[i+1] = logPrefix[i] + math.Log(float64(num)) } logK := math.Log(float64(k)) for j := 1; j <= n; j++ { l := sort.SearchFloat64s(logPrefix[:j], logPrefix[j]-logK+1e-10) ans += j - l } return }
我们固定子数组 [i, j] 的右端点 j 时,显然左端点 i 越大,子数组元素乘积越小。对于子数组 [i, j],当左端点 i \ge l_1 时,所有子数组的元素乘积都小于 k,当左端点 i \lt l_1 时,所有子数组的元素乘积都大于等于 k。那么对于右端点为 j + 1 的所有子数组,它的左端点 i 就不需要从 0 开始枚举,因为对于所有 i \lt l_1 的子数组,它们的元素乘积都大于等于 k。我们只要从 i = l_1 处开始枚举,直到子数组 i = l_2 时子数组 [l_2, j + 1] 的元素乘积小于 k,那么左端点 i \ge l_2 所有子数组的元素乘积都小于 k。
根据上面的分析,我们枚举子数组的右端点 j,并且左端点从 i = 0 开始,用 prod 记录子数组 [i, j] 的元素乘积。每枚举一个右端点 j,如果当前子数组元素乘积 prod 大于等于 k,那么我们右移左端点 i 直到满足当前子数组元素乘积小于 k 或者 i > j,那么元素乘积小于 k 的子数组数目为 j - i + 1。返回所有数目之和。
prod 的值始终不超过 k \times \max_l {\textit{nums}[l] \,因此无需担心整型溢出的问题。
代码
[sol2-Python3]
1 2 3 4 5 6 7 8 9 10
classSolution: defnumSubarrayProductLessThanK(self, nums: List[int], k: int) -> int: ans, prod, i = 0, 1, 0 for j, num inenumerate(nums): prod *= num while i <= j and prod >= k: prod //= nums[i] i += 1 ans += j - i + 1 return ans
[sol2-C++]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
classSolution { public: intnumSubarrayProductLessThanK(vector<int>& nums, int k){ int n = nums.size(), ret = 0; int prod = 1, i = 0; for (int j = 0; j < n; j++) { prod *= nums[j]; while (i <= j && prod >= k) { prod /= nums[i]; i++; } ret += j - i + 1; } return ret; } };
[sol2-Java]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
classSolution { publicintnumSubarrayProductLessThanK(int[] nums, int k) { intn= nums.length, ret = 0; intprod=1, i = 0; for (intj=0; j < n; j++) { prod *= nums[j]; while (i <= j && prod >= k) { prod /= nums[i]; i++; } ret += j - i + 1; } return ret; } }
[sol2-C#]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
publicclassSolution { publicintNumSubarrayProductLessThanK(int[] nums, int k) { int n = nums.Length, ret = 0; int prod = 1, i = 0; for (int j = 0; j < n; j++) { prod *= nums[j]; while (i <= j && prod >= k) { prod /= nums[i]; i++; } ret += j - i + 1; } return ret; } }
[sol2-C]
1 2 3 4 5 6 7 8 9 10 11 12 13
intnumSubarrayProductLessThanK(int* nums, int numsSize, int k){ int ret = 0; int prod = 1, i = 0; for (int j = 0; j < numsSize; j++) { prod *= nums[j]; while (i <= j && prod >= k) { prod /= nums[i]; i++; } ret += j - i + 1; } return ret; }
[sol2-JavaScript]
1 2 3 4 5 6 7 8 9 10 11 12 13
var numSubarrayProductLessThanK = function(nums, k) { let n = nums.length, ret = 0; let prod = 1, i = 0; for (let j = 0; j < n; j++) { prod *= nums[j]; while (i <= j && prod >= k) { prod /= nums[i]; i++; } ret += j - i + 1; } return ret; };
[sol2-Golang]
1 2 3 4 5 6 7 8 9 10 11
funcnumSubarrayProductLessThanK(nums []int, k int) (ans int) { prod, i := 1, 0 for j, num := range nums { prod *= num for ; i <= j && prod >= k; i++ { prod /= nums[i] } ans += j - i + 1 } return }
复杂度分析
时间复杂度:O(n),其中 n 是数组 nums 的长度。两个端点 i 和 j 的增加次数都不超过 n。