0982-按位与为零的三元组

Raphael Liu Lv10

给你一个整数数组 nums ,返回其中 按位与三元组 的数目。

按位与三元组 是由下标 (i, j, k) 组成的三元组,并满足下述全部条件:

  • 0 <= i < nums.length
  • 0 <= j < nums.length
  • 0 <= k < nums.length
  • nums[i] & nums[j] & nums[k] == 0 ,其中 & 表示按位与运算符。

示例 1:

**输入:** nums = [2,1,3]
**输出:** 12
**解释:** 可以选出如下 i, j, k 三元组:
(i=0, j=0, k=1) : 2 & 2 & 1
(i=0, j=1, k=0) : 2 & 1 & 2
(i=0, j=1, k=1) : 2 & 1 & 1
(i=0, j=1, k=2) : 2 & 1 & 3
(i=0, j=2, k=1) : 2 & 3 & 1
(i=1, j=0, k=0) : 1 & 2 & 2
(i=1, j=0, k=1) : 1 & 2 & 1
(i=1, j=0, k=2) : 1 & 2 & 3
(i=1, j=1, k=0) : 1 & 1 & 2
(i=1, j=2, k=0) : 1 & 3 & 2
(i=2, j=0, k=1) : 3 & 2 & 1
(i=2, j=1, k=0) : 3 & 1 & 2

示例 2:

**输入:** nums = [0,0,0]
**输出:** 27

提示:

  • 1 <= nums.length <= 1000
  • 0 <= nums[i] < 216

方法一:枚举

思路与算法

最容易想到的做法是使用三重循环枚举三元组 (i,j,k),再判断 nums}[i] & \textit{nums}[j] & \textit{nums}[k] 的值是否为 0。但这样做的时间复杂度是 O(n^3),其中 n 是数组 nums 的长度,会超出时间限制。

注意到题目中给定了一个限制:数组 nums 的元素不会超过 2^{16。这说明,nums}[i] & \textit{nums}[j] 的值也不会超过 2^{16。因此,我们可以首先使用二重循环枚举 i 和 j,并使用一个长度为 2^{16 的数组(或哈希表)存储每一种 nums}[i] & \textit{nums}[j] 以及它出现的次数。随后,我们再使用二重循环,其中的一重枚举记录频数的数组,另一重枚举 k,这样就可以将时间复杂度从 O(n^3) 降低至 O(n^2 + 2^{16} \cdot n)。

代码

[sol1-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
int countTriplets(vector<int>& nums) {
vector<int> cnt(1 << 16);
for (int x: nums) {
for (int y: nums) {
++cnt[x & y];
}
}
int ans = 0;
for (int x: nums) {
for (int mask = 0; mask < (1 << 16); ++mask) {
if ((x & mask) == 0) {
ans += cnt[mask];
}
}
}
return ans;
}
};
[sol1-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Solution {
public int countTriplets(int[] nums) {
int[] cnt = new int[1 << 16];
for (int x : nums) {
for (int y : nums) {
++cnt[x & y];
}
}
int ans = 0;
for (int x : nums) {
for (int mask = 0; mask < (1 << 16); ++mask) {
if ((x & mask) == 0) {
ans += cnt[mask];
}
}
}
return ans;
}
}
[sol1-C#]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class Solution {
public int CountTriplets(int[] nums) {
int[] cnt = new int[1 << 16];
foreach (int x in nums) {
foreach (int y in nums) {
++cnt[x & y];
}
}
int ans = 0;
foreach (int x in nums) {
for (int mask = 0; mask < (1 << 16); ++mask) {
if ((x & mask) == 0) {
ans += cnt[mask];
}
}
}
return ans;
}
}
[sol1-Python3]
1
2
3
4
5
6
7
8
9
10
class Solution:
def countTriplets(self, nums: List[int]) -> int:
cnt = Counter((x & y) for x in nums for y in nums)

ans = 0
for x in nums:
for mask, freq in cnt.items():
if (x & mask) == 0:
ans += freq
return ans
[sol1-C]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int countTriplets(int* nums, int numsSize) {
int *cnt = (int *)calloc(sizeof(int), 1 << 16);
for (int i = 0; i < numsSize; i++) {
int x = nums[i];
for (int j = 0; j < numsSize; j++) {
int y = nums[j];
++cnt[x & y];
}
}
int ans = 0;
for (int i = 0; i < numsSize; i++) {
int x = nums[i];
for (int mask = 0; mask < (1 << 16); ++mask) {
if ((x & mask) == 0) {
ans += cnt[mask];
}
}
}
free(cnt);
return ans;
}
[sol1-JavaScript]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
var countTriplets = function(nums) {
const cnt = new Array(1 << 16).fill(0);
for (const x of nums) {
for (const y of nums) {
++cnt[x & y];
}
}
let ans = 0;
for (const x of nums) {
for (let mask = 0; mask < (1 << 16); ++mask) {
if ((x & mask) === 0) {
ans += cnt[mask];
}
}
}
return ans;
};
[sol1-Golang]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func countTriplets(nums []int) int {
cnt := make(map[int]int)
for i := range nums {
for j := range nums {
cnt[nums[i]&nums[j]]++
}
}
res := 0
for i := range nums {
for k, v := range cnt {
if k&nums[i] == 0 {
res += v
}
}
}
return res
}

复杂度分析

  • 时间复杂度:O(n^2 + C \cdot n),其中 n 是数组 nums 的长度,C 是数组 nums 中的元素范围,在本题中 C = 2^{16。

  • 空间复杂度:O(C),即为数组(或哈希表)需要使用的空间。

方法二:枚举 + 子集优化

思路与算法

在方法一的第二个二重循环中,我们需要枚举 [0, 2^{16}) 中的所有整数。即使我们使用哈希表代替数组,在数据随机的情况下,nums}[i] & \textit{nums}[j] 也会覆盖 [0, 2^{16}) 中的大部分整数,使得哈希表不会有明显更好的表现。

这里我们介绍另一个常数级别的优化。当我们在第二个二重循环中枚举 k 时,我们希望统计出所有与 nums}[k] 按位与为 0 的二元组数量。也就是说:

如果 nums}[k] 的第 t 个二进制位是 0,那么二元组的第 t 个二进制位才可以是 1,否则一定不能是 1。

因此,我们可以将 nums}[k] 与 2^{16}-1(即二进制表示下的 16 个 1)进行按位异或运算。这样一来,满足要求的二元组的二进制表示中包含的 1 必须是该数的子集,例如该数是 (100111)_2,那么满足要求的二元组可以是 (100010)_2 或者 (000110)_2,但不能是 (010001)_2。

此时,要想得到所有该数的子集,我们可以使用「二进制枚举子集」的技巧。这里给出对应的步骤:

  • 记该数为 x。我们用 sub 表示当前枚举到的子集。初始时 sub} = x,因为 x 也是本身的子集;

  • 我们不断地令 sub} = (\textit{sub} - 1) & x,其中 & 表示按位与运算。这样我们就可以从大到小枚举 x 的所有子集。当 sub} = 0 时枚举结束。

我们可以粗略估计这样做可以优化的时间复杂度:当数据随机时,x 的二进制表示中期望有 16/2=8 个 1,那么「二进制枚举子集」需要枚举 2^8 次。在优化前,我们需要枚举 2^{16 次,因此常数项就缩减到原来的 \dfrac{1/2^8。但在最坏情况下,x 的二进制表示有 16 个 1,两种方法的表现没有区别。

代码

[sol2-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
int countTriplets(vector<int>& nums) {
vector<int> cnt(1 << 16);
for (int x: nums) {
for (int y: nums) {
++cnt[x & y];
}
}
int ans = 0;
for (int x: nums) {
x = x ^ 0xffff;
for (int sub = x; sub; sub = (sub - 1) & x) {
ans += cnt[sub];
}
ans += cnt[0];
}
return ans;
}
};
[sol2-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class Solution {
public int countTriplets(int[] nums) {
int[] cnt = new int[1 << 16];
for (int x : nums) {
for (int y : nums) {
++cnt[x & y];
}
}
int ans = 0;
for (int x : nums) {
x = x ^ 0xffff;
for (int sub = x; sub != 0; sub = (sub - 1) & x) {
ans += cnt[sub];
}
ans += cnt[0];
}
return ans;
}
}
[sol2-C#]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class Solution {
public int CountTriplets(int[] nums) {
int[] cnt = new int[1 << 16];
foreach (int x in nums) {
foreach (int y in nums) {
++cnt[x & y];
}
}
int ans = 0;
foreach (int x in nums) {
int y = x ^ 0xffff;
for (int sub = y; sub != 0; sub = (sub - 1) & y) {
ans += cnt[sub];
}
ans += cnt[0];
}
return ans;
}
}
[sol2-Python3]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution:
def countTriplets(self, nums: List[int]) -> int:
cnt = Counter((x & y) for x in nums for y in nums)

ans = 0
for x in nums:
sub = x = x ^ 0xffff
while True:
if sub in cnt:
ans += cnt[sub]
if sub == 0:
break
sub = (sub - 1) & x

return ans
[sol2-C]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int countTriplets(int* nums, int numsSize) {
int *cnt = (int *)calloc(sizeof(int), 1 << 16);
for (int i = 0; i < numsSize; i++) {
int x = nums[i];
for (int j = 0; j < numsSize; j++) {
int y = nums[j];
++cnt[x & y];
}
}
int ans = 0;
for (int i = 0; i < numsSize; i++) {
int x = nums[i] ^ 0xffff;
for (int sub = x; sub; sub = (sub - 1) & x) {
ans += cnt[sub];
}
ans += cnt[0];
}
free(cnt);
return ans;
}
[sol2-JavaScript]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
var countTriplets = function(nums) {
const cnt = new Array(1 << 16).fill(0);
for (const x of nums) {
for (const y of nums) {
++cnt[x & y];
}
}
let ans = 0;
for (let x of nums) {
x = x ^ 0xffff;
for (let sub = x; sub !== 0; sub = (sub - 1) & x) {
ans += cnt[sub];
}
ans += cnt[0];
}
return ans;
};
[sol1-Golang]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func countTriplets(nums []int) int {
var cnt [1 << 16]int
for i := range nums {
for j := range nums {
cnt[nums[i]&nums[j]]++
}
}
res := 0
for i := range nums {
x := nums[i] ^ 0xffff
for sub := x; sub > 0; sub = (sub - 1) & x {
res += cnt[sub]
}
res += cnt[0]
}
return res
}

复杂度分析

  • 时间复杂度:O(n^2 + C \cdot n),其中 n 是数组 nums 的长度,C 是数组 nums 中的元素范围,在本题中 C = 2^{16。

  • 空间复杂度:O(C),即为数组(或哈希表)需要使用的空间。

 Comments
On this page
0982-按位与为零的三元组