题目描述
给定一个整数数组 nums
(下标 从 0
开始 计数)以及两个整数:low
和 high
,请返回 漂亮数对 的数目。
漂亮数对 是一个形如 (i, j)
的数对,其中 0 <= i < j < nums.length
且 low <= (nums[i] XOR nums[j]) <= high
。
样例
输入:nums = [1,4,2,7], low = 2, high = 6
输出:6
解释:所有漂亮数对 (i, j) 列出如下:
- (0, 1): nums[0] XOR nums[1] = 5
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5
输入:nums = [9,8,4,2,1], low = 5, high = 14
输出:8
解释:所有漂亮数对 (i, j) 列出如下:
- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5
限制
1 <= nums.length <= 2 * 10^4
1 <= nums[i] <= 2 * 10^4
1 <= low <= high <= 2 * 10^4
算法
(Trie 树) $O(n \log S)$
- 将给定的数组建立 01 Trie 树,树中节点记录以当前节点为根的叶子节点的个数。
- 将问题分解为两部分:求出小于等于
high
的数对个数,然后再求出小于等于low - 1
的数对个数,两个值相减得到答案。 - 对于每个查询,从树的根节点开始向下,尽量走与当前值相反的路径(这样才可能获得更多的值),具体参考代码注释。
时间复杂度
- 每个节点都有一条长度为 $O(\log S)$ 的路径,查询时同样需要遍历长度为 $O(\log S)$ 的路径。其中 $S$ 为数字的最大值。
- 故总时间复杂度为 $O(n \log S)$。
空间复杂度
- 需要 $O(n \log S)$ 的空间存储所有节点。
C++ 代码
struct Node {
int nums;
Node *nxt[2];
Node() {
nums = 0;
nxt[0] = nxt[1] = NULL;
}
};
class Solution {
private:
Node *rt;
void insert(int x) {
Node *p = rt;
for (int i = 15; i >= 0; i--) {
int v = (x >> i) & 1;
if (p->nxt[v] == NULL) p->nxt[v] = new Node();
p->nums++;
p = p->nxt[v];
}
p->nums++;
}
int query(int x, int limit) {
int res = 0, cur = 0;
// res 是符合要求的叶子节点的数量,cur 是当前匹配的值(不能超过 limit)
Node *p = rt;
for (int i = 15; i >= 0; i--) {
int v = (x >> i) & 1;
if (p->nxt[v] == NULL) { // v 的分支为空(只有 v^1 分支)
if ((cur | (1 << i)) <= limit) { // 判断是否可以往 v^1 走
cur |= 1 << i;
p = p->nxt[v^1];
} else { // 不可以,则直接返回了,当前叶子节点的值也不能取
return res;
}
} else if (p->nxt[v^1] == NULL) { // v^1 的分支为空(只有 v 分支)
p = p->nxt[v]; // 直接往 v 分支走即可,cur 不会超过 limit
} else { // 两个分支都存在,优先尝试从 v^1 走
if ((cur | (1 << i)) <= limit) { // 可以从 v^1 走
res += p->nxt[v]->nums; // 需要累计上 v 分支下的节点个数
cur |= 1 << i;
p = p->nxt[v^1];
} else {
p = p->nxt[v];
}
}
}
return res + p->nums; // 最后不要忘了叶子节点自身的值
}
int calc(const vector<int> &nums, int limit) {
const int n = nums.size();
int tot = 0;
for (int i = 0; i < n; i++)
tot += query(nums[i], limit) - 1; // 这里减 1 是为了去掉 (i, i) 这个非法数对
return tot / 2; // 每个合法数对被统计了两次,所以除以 2
}
public:
int countPairs(vector<int>& nums, int low, int high) {
const int n = nums.size();
rt = new Node();
for (int i = 0; i < n; i++)
insert(nums[i]);
return calc(nums, high) - calc(nums, low - 1);
}
};
一遍插入,一遍计算,这样不需要减1,也不需要除2,这么麻烦
强
O(n \log S) 这个复杂度不对吧,还是符号的问题
符号的问题,已修复