设 LCA 为 a 和 b 的最近公共祖先,那么环长等于 LCA 到 a 的距离加 LCA 到 b 的距离加一。
如何找 LCA?
注意到在完全二叉树中,深度越深的点,其编号必定大于上一层的节点编号,根据这个性质,我们可以不断循环,每次循环比较 a 和 b 的大小:
如果 a>b,则 a 的深度大于等于 b 的深度,那么把 a 移动到其父节点,即 a=a/2;
如果 a<b,则 a 的深度小于等于 b 的深度,那么把 b 移动到其父节点,即 b=b/2;
如果 a=b,则找到了 LCA,退出循环。
循环次数加一即为环长。
[sol1-Python3]
1 2 3 4 5 6 7 8 9 10
classSolution: defcycleLengthQueries(self, n: int, queries: List[List[int]]) -> List[int]: for i, (a, b) inenumerate(queries): res = 1 while a != b: if a > b: a //= 2 else: b //= 2 res += 1 queries[i] = res return queries
[sol1-Java]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
classSolution { publicint[] cycleLengthQueries(int n, int[][] queries) { varm= queries.length; varans=newint[m]; for (vari=0; i < m; ++i) { intres=1, a = queries[i][0], b = queries[i][1]; while (a != b) { if (a > b) a /= 2; else b /= 2; ++res; } ans[i] = res; } return ans; } }
[sol1-C++]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
classSolution { public: vector<int> cycleLengthQueries(int n, vector<vector<int>> &queries){ int m = queries.size(); vector<int> ans(m); for (int i = 0; i < m; ++i) { int res = 1, a = queries[i][0], b = queries[i][1]; while (a != b) { a > b ? a /= 2 : b /= 2; ++res; } ans[i] = res; } return ans; } };
[sol1-Go]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
funccycleLengthQueries(_ int, queries [][]int) []int { ans := make([]int, len(queries)) for i, q := range queries { res := 1 for a, b := q[0], q[1]; a != b; res++ { if a > b { a /= 2 } else { b /= 2 } } ans[i] = res } return ans }
复杂度分析
时间复杂度:O(nm),其中 m 为 queries 的长度。回答一个询问的时间复杂度为 O(n)。
空间复杂度:O(1),仅用到若干额外变量。
方法二:位运算优化
进一步挖掘完全二叉树的性质:节点编号的二进制的长度恰好等于节点深度。
以二进制下的 a=110,\ b=11101 为例:
算出两个节点的深度,分别为 3 和 5,深度之差 d=5-3=2,那么把 b 右移 d 位(相当于上跳 d 步),得到 111,这样 b 就和 a 在同一层了。
如果此时 a=b,说明 a 就是 LCA,答案为 d+1。
如果此时 a\ne b,计算 a 异或 b 的结果 1,它的二进制长度为 L=1,那么 a 和 b 需要各上跳 L 步才能到达 LCA,答案为 d+2L+1。
[sol2-Python3]
1 2 3 4 5 6 7
classSolution: defcycleLengthQueries(self, n: int, queries: List[List[int]]) -> List[int]: for i, (a, b) inenumerate(queries): if a > b: a, b = b, a # 保证 a <= b d = b.bit_length() - a.bit_length() queries[i] = d + (a ^ (b >> d)).bit_length() * 2 + 1 return queries
[sol2-Java]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
classSolution { publicint[] cycleLengthQueries(int n, int[][] queries) { varm= queries.length; varans=newint[m]; for (vari=0; i < m; ++i) { inta= queries[i][0], b = queries[i][1]; if (a > b) { vartmp= a; a = b; b = tmp; // 交换,保证 a <= b } vard= Integer.numberOfLeadingZeros(a) - Integer.numberOfLeadingZeros(b); ans[i] = d + (32 - Integer.numberOfLeadingZeros(a ^ (b >> d))) * 2 + 1; } return ans; } }
[sol2-C++]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
classSolution { public: vector<int> cycleLengthQueries(int n, vector<vector<int>> &queries){ int m = queries.size(); vector<int> ans(m); for (int i = 0; i < m; ++i) { int a = queries[i][0], b = queries[i][1]; if (a > b) swap(a, b); // 保证 a <= b int d = __builtin_clz(a) - __builtin_clz(b); b >>= d; // 上跳,和 a 在同一层 ans[i] = a == b ? d + 1 : d + (32 - __builtin_clz(a ^ b)) * 2 + 1; } return ans; } };
[sol2-Go]
1 2 3 4 5 6 7 8 9 10 11 12
funccycleLengthQueries(_ int, queries [][]int) []int { ans := make([]int, len(queries)) for i, q := range queries { a, b := uint(q[0]), uint(q[1]) if a > b { a, b = b, a // 保证 a <= b } d := bits.Len(b) - bits.Len(a) ans[i] = d + bits.Len(b>>d^a)*2 + 1 } return ans }