题目描述
给定一个非负整数的数据流输入 $a_1,a_2,…,a_n,…,$ 将到目前为止看到的数字总结为不相交的区间列表。
例如,假设数据流中的整数为 1,3,7,2,6,…,每次的总结为:
[1, 1]
[1, 1], [3, 3]
[1, 1], [3, 3], [7, 7]
[1, 3], [7, 7]
[1, 3], [6, 7]
进阶
如果有很多合并,并且与数据流的大小相比,不相交区间的数量很小,该怎么办?
算法1
(平衡树)
题意是让我们动态的插入一些点,同时能够按照从小到大的顺序输出当前点所构成的区间。那么我们可以考虑当我们插入一个点时,我们需要知道它是不是已经在某些区间内了,同时如果它不在某些区间内,那么它能否将之前存在的区间给合并起来。总结起来就是当我们插入点时需要快速的知道当前点val
是否在某个区间[a, b]
内,并且是否存在以val - 1
为右端点的区间以及以val + 1
为左端点的区间。
如果我们将区间以vector
存起来,并按左端点排好序,那么插入点时我们可以用二分来找到最后一个左端点小于或等于当前点的区间,然后判断当前点是否已在区间中,之后再根据端点的大小关系来动态的维护区间。不过由于vector
是连续存储的,当我们插入或删除一个区间时需要 $O(n)$ 的复杂度,这里 $n$ 是区间的个数。但是当我们需要输出区间时,我们可以直接将当前维护好的区间返回,复杂度为 $O(n)$。
或者我们可以将区间用一棵平衡二叉搜索树来组织起来,每个区间之间的大小关系由左端点的大小来决定。这样我们同样可以用 $O(\log n)$ 的时间找到最后一个左端点小于或等于当前插入点的区间,并且可以用 $O(\log n)$ 的时间来动态的插入和删除区间。但是当我们需要输出区间时我们需要遍历整棵二叉搜索树来计算答案,时间复杂度为 $O(n)$。具体做法如下:
我们将区间 $[begin, end]$ 用一个平衡二叉树来维护,并且定义区间 $a, b$ 之间的大小关系为 $a < b \iff a.begin < b.begin$,注意这里右端点不参与比较,这是为了找到第一个左端点严格大于当前val
的区间,如果右端点也参与比较比如对于[3, 3]
来说第一个严格大于它的区间会是[3, 4]
。
找到第一个左端点严格大于当前[val, val]
的区间后将它减一就是最后一个左端点小于等于当前[val, val]
的区间。
之后我们可以分情况讨论:
- 如果当前值在区间内,我们直接返回,如果插入的话会造成冗余,比如我们在
[1, 3]
内插入2
会造成[1, 3]
和[2, 2]
同时存在。 - 如果存在
val - 1
的右端点和val + 1
的左端点,那么我们可以合并左右两个区间,因为STL的set
不支持修改元素,我们删除再插入就可以了。 - 如果只存在
val - 1
的右端点,合并左区间。 - 如果只存在
val + 1
的左端点,合并右区间。 - 都不存在,插入当前
[val, val]
区间。
这里我们分别用两个迭代器lower
和higher
指向最后一个左端点小于等于当前点和第一个左端点大于当前点的区间。当这两个区间的某一个不存在时,我们就让这两个迭代器指向同一个区间。当这两个区间都不存在时,即当前区间集合为空的时候我们直接将当前点直接插入。并且当我们删除区间时,我们让higher
指向删除后的位置,即新区间要插入的位置,之后在insert
的时候将higher
当做hint
传进去可以加快插入区间的速度。
C++ 代码
class SummaryRanges {
struct Interval
{
int begin, end;
bool operator<(const Interval& i) const
{
return begin < i.begin;
}
};
public:
/** Initialize your data structure here. */
set<Interval> mp;
SummaryRanges() {
}
void addNum(int val) {
if (mp.empty()) {
mp.insert({val, val});
return;
}
auto higher = mp.upper_bound({val, val});
auto lower = higher;
if (lower != mp.begin()) -- lower;
if (higher == mp.end()) -- higher;
if (lower->begin <= val && lower->end >= val) return;
int st = val, ed = val;
if (lower->end == val - 1 && higher->begin == val + 1) {
st = lower->begin, ed = higher->end;
mp.erase(lower);
higher = mp.erase(higher);
} else if (lower->end == val - 1) {
st = lower->begin;
higher = mp.erase(lower);
} else if (higher->begin == val + 1) {
ed = higher->end;
higher = mp.erase(higher);
}
mp.insert(higher, {st, ed});
}
vector<vector<int>> getIntervals() {
vector<vector<int>> res;
for (auto &i : mp) res.push_back({i.begin, i.end});
return res;
}
};
/**
* Your SummaryRanges object will be instantiated and called as such:
* SummaryRanges* obj = new SummaryRanges();
* obj->addNum(val);
* vector<vector<int>> param_2 = obj->getIntervals();
*/
算法2
(哈希表+区间合并)
这道题的 follow-up 是假设有很多区间合并的操作,并且区间的数量远小于输入数据流的数量时应该怎么做。换句话说就是让我们尽可能的降低addNum()
的复杂度而可以适当的提高getIntervals()
的复杂度。我们可以采用与 LeetCode 128. Longest Consecutive Sequence 相同的思路,就是用一个哈希表来维护当前的区间,即让我们维护的区间无序,当我们要输出区间时我们可以遍历哈希表将区间取出,排序去重后再输出。这样addNum()
的复杂度变为 $O(1)$,而getIntervals()
的复杂度变为 $O(n\log n)$。
我们用一个哈希表unordered_map
来维护每个区间两个端点的值,这里的值是区间的长度,这样对于 $n$ 个区间我们会有 $2n$ 个点存于哈希表中。另外由于哈希表里目前只有区间端点的值,没有区间内部点的值,为了判重我们另外开一个哈希表unordered_set
来存储数据流中的所有点。当我们在addNum
合并区间时要将旧的端点值从区间哈希表中删去,这样当我们输出区间时共有 $2n$ 个区间,排序去重后就可以得到答案了,这里去重是指将每个区间右端点所对应的区间给排除掉。
注意这里如果在addNum
时没有删除区间内部的点即旧的端点值会导致区间数量与数据中不同的点的数量相同,当数据范围比较大的时候会超时。
C++ 代码
class SummaryRanges {
public:
/** Initialize your data structure here. */
unordered_map<int, int> hash;
unordered_set<int> dup;
SummaryRanges() {
}
void addNum(int val) {
if (dup.count(val)) return;
dup.insert(val);
if (hash.count(val - 1) && hash.count(val + 1)) {
int left = hash[val - 1], right = hash[val + 1];
hash[val - left] = hash[val + right] = left + right + 1;
if (left > 1) hash.erase(val - 1);
if (right > 1) hash.erase(val + 1);
} else if (hash.count(val - 1)) {
int left = hash[val - 1];
hash[val - left] = hash[val] = left + 1;
if (left > 1) hash.erase(val - 1);
} else if (hash.count(val + 1)) {
int right = hash[val + 1];
hash[val + right] = hash[val] = right + 1;
if (right > 1) hash.erase(val + 1);
} else {
hash[val] = 1;
}
}
vector<vector<int>> getIntervals() {
vector<vector<int>> tmp;
for (auto &p : hash) tmp.push_back({p.first, p.first + p.second - 1});
sort(tmp.begin(), tmp.end());
vector<vector<int>> res;
res.push_back(tmp[0]);
int ed = res[0][1];
for (int i = 1; i < tmp.size(); i ++ ) {
if (tmp[i][0] > ed) {
res.push_back(tmp[i]);
ed = tmp[i][1];
}
}
return res;
}
};
/**
* Your SummaryRanges object will be instantiated and called as such:
* SummaryRanges* obj = new SummaryRanges();
* obj->addNum(val);
* vector<vector<int>> param_2 = obj->getIntervals();
*/