我们可以从字典树的根结点开始遍历,遍历的参照对象是 a_i 和 x。假设我们当前遍历到了第 k 个二进制位:
如果 x 的第 k 个二进制位为 0,那么此时不存在使得 a_i \oplus a_j \lt x 条件成立的 a_j,设 r 是 a_i 的第 k 个二进制位,我们需要往表示 r 的子节点走,这保证了路径上的数值与 a_i 做异或后前缀与 x 相同。
如果 x 的第 k 个二进制位为 1,设 r 是 a_i 的第 k 个二进制位,那么此时表示 r 的子节点中记录的数字,就是使得 a_i \oplus a_j \lt x 条件成立的 a_j 的个数,将它累加到答案中。然后我们需要往表示 r\oplus 1 的子节点走,这保证了路径上的数值与 a_i 做异或后前缀与 x 相同。
如果在过程中,出现某个子节点不存在使得过程无法继续,我们需要立刻返回答案。否则在最后,我们遍历完所有的 15 个二进制位后,到达的最后一个节点中记录的数字是使得 a_i \oplus a_j = x 条件成立的 a_j 的个数,也将其累加到答案中。至此,我们求出来所有使得 a_i \oplus a_j \le x 条件成立的 a_j 的个数。
defadd(self, num: int) -> None: cur = self.root for k inrange(HIGH_BIT, -1, -1): bit = (num >> k) & 1 ifnot cur.children[bit]: cur.children[bit] = TrieNode() cur = cur.children[bit] cur.sum += 1
defget(self, num: int, x: int) -> int: res = 0 cur = self.root for k inrange(HIGH_BIT, -1, -1): bit = (num >> k) & 1 if (x >> k) & 1: if cur.children[bit]: res += cur.children[bit].sum ifnot cur.children[bit ^ 1]: return res cur = cur.children[bit ^ 1] else: ifnot cur.children[bit]: return res cur = cur.children[bit] res += cur.sum return res
classSolution: defcountPairs(self, nums: List[int], low: int, high: int) -> int: deff(nums: List[int], x: int) -> int: res = 0 trie = Trie() for i inrange(1, len(nums)): trie.add(nums[i - 1]) res += trie.get(nums[i], x) return res return f(nums, high) - f(nums, low - 1)
public: voidadd(int num){ Trie* cur = root; for (int k = HIGH_BIT; k >= 0; k--) { int bit = (num >> k) & 1; if (cur->son[bit] == nullptr) { cur->son[bit] = newTrie(); } cur = cur->son[bit]; cur->sum++; } }
intget(int num, int x){ Trie* cur = root; int sum = 0; for (int k = HIGH_BIT; k >= 0; k--) { int r = (num >> k) & 1; if ((x >> k) & 1) { if (cur->son[r] != nullptr) { sum += cur->son[r]->sum; } if (cur->son[r ^ 1] == nullptr) { return sum; } cur = cur->son[r ^ 1]; } else { if (cur->son[r] == nullptr) { return sum; } cur = cur->son[r]; } } sum += cur->sum; return sum; }
intf(vector<int>& nums, int x){ root = newTrie(); int res = 0; for (int i = 1; i < nums.size(); i++) { add(nums[i - 1]); res += get(nums[i], x); } return res; }
intcountPairs(vector<int>& nums, int low, int high){ returnf(nums, high) - f(nums, low - 1); } };
publicintCountPairs(int[] nums, int low, int high) { return F(nums, high) - F(nums, low - 1); }
publicintF(int[] nums, int x) { root = new Trie(); int res = 0; for (int i = 1; i < nums.Length; i++) { Add(nums[i - 1]); res += Get(nums[i], x); } return res; }
publicvoidAdd(int num) { Trie cur = root; for (int k = HIGH_BIT; k >= 0; k--) { int bit = (num >> k) & 1; if (cur.son[bit] == null) { cur.son[bit] = new Trie(); } cur = cur.son[bit]; cur.sum++; } }
publicintGet(int num, int x) { Trie cur = root; int sum = 0; for (int k = HIGH_BIT; k >= 0; k--) { int r = (num >> k) & 1; if (((x >> k) & 1) != 0) { if (cur.son[r] != null) { sum += cur.son[r].sum; } if (cur.son[r ^ 1] == null) { return sum; } cur = cur.son[r ^ 1]; } else { if (cur.son[r] == null) { return sum; } cur = cur.son[r]; } } sum += cur.sum; return sum; } }
classTrie { // son[0] 表示左子树,son[1] 表示右子树 public Trie[] son = new Trie[2]; publicint sum;
voidadd(int num, Trie *root) { Trie* cur = root; for (int k = HIGH_BIT; k >= 0; k--) { int bit = (num >> k) & 1; if (cur->son[bit] == NULL) { cur->son[bit] = creatTrieNode(); } cur = cur->son[bit]; cur->sum++; } }
intget(int num, int x, const Trie *root) { const Trie* cur = root; int sum = 0; for (int k = HIGH_BIT; k >= 0; k--) { int r = (num >> k) & 1; if ((x >> k) & 1) { if (cur->son[r] != NULL) { sum += cur->son[r]->sum; } if (cur->son[r ^ 1] == NULL) { return sum; } cur = cur->son[r ^ 1]; } else { if (cur->son[r] == NULL) { return sum; } cur = cur->son[r]; } } sum += cur->sum; return sum; }
intf(constint *nums, int numsSize, int x) { Trie *root = creatTrieNode(); int res = 0; for (int i = 1; i < numsSize; i++) { add(nums[i - 1], root); res += get(nums[i], x, root); } freeTrie(root); return res; }
intcountPairs(int* nums, int numsSize, int low, int high) { return f(nums, numsSize, high) - f(nums, numsSize, low - 1); }
constf = (nums, x) => { root = newTrie(); let res = 0;
constadd = (num) => { let cur = root; for (let k = HIGH_BIT; k >= 0; k--) { let bit = (num >> k) & 1; if (!cur.son[bit]) { cur.son[bit] = newTrie(); } cur = cur.son[bit]; cur.sum++; } }
constget = (num, x) => { let cur = root; let sum = 0; for (let k = HIGH_BIT; k >= 0; k--) { let r = (num >> k) & 1; if (((x >> k) & 1) !== 0) { if (cur.son[r]) { sum += cur.son[r].sum; } if (!cur.son[r ^ 1]) { return sum; } cur = cur.son[r ^ 1]; } else { if (!cur.son[r]) { return sum; } cur = cur.son[r]; } } sum += cur.sum; return sum; }
for (let i = 1; i < nums.length; i++) { add(nums[i - 1]); res += get(nums[i], x); } return res; }
func(t *trie) put(v int) *trieNode { o := t.root for i := trieBitLen; i >= 0; i-- { b := v >> i & 1 if o.son[b] == nil { o.son[b] = &trieNode{} } o = o.son[b] o.cnt++ } return o }
func(t *trie) countLimitXOR(v, limit int) (cnt int) { o := t.root for i := trieBitLen; i >= 0; i-- { b := v >> i & 1 if limit>>i&1 > 0 { if o.son[b] != nil { cnt += o.son[b].cnt } b ^= 1 } if o.son[b] == nil { return } o = o.son[b] } return }
funccountPairs(nums []int, low, high int) (ans int) { t := &trie{&trieNode{} } t.put(nums[0]) for _, v := range nums[1:] { ans += t.countLimitXOR(v, high+1) - t.countLimitXOR(v, low) t.put(v) } return }
复杂度分析
时间复杂度:O(n\log C)。其中 n 是 nums 的长度,C 是数组中的元素范围,在本题中 C \lt 2^{15。我们需要将 a_0,a_1,\cdots,a_{n-2 加入到字典树中,并且需要以 a_1,a_2,\cdots,a_{n-1 以及 x 作为「参照对象」在字典树上进行遍历,每一项操作的单次时间复杂度为 O(\log C),因此总时间复杂度为 O(n\log C)。
空间复杂度:O(n\log C)。每一个元素在字典树中需要使用 O(\log C) 的空间,因此总空间复杂度为 O(n\log C)。