题目描述
给你一个由 n
个整数组成的数组 nums
,以及两个整数 k
和 x
。
数组的 x-sum 计算按照以下步骤进行:
- 统计数组中所有元素的出现次数。
- 仅保留出现次数最多的前
x
个元素的每次出现。如果两个元素的出现次数相同,则数值 较大 的元素被认为出现次数更多。 - 计算结果数组的和。
注意,如果数组中的不同元素少于 x
个,则其 x-sum 是数组的元素总和。
返回一个长度为 n - k + 1
的整数数组 answer
,其中 answer[i]
是 子数组 nums[i..i + k - 1]
的 x-sum。
子数组 是数组内的一个连续 非空 的元素序列。
样例
输入:nums = [1,1,2,2,3,4,2,3], k = 6, x = 2
输出:[6,10,12]
解释:
对于子数组 [1, 1, 2, 2, 3, 4],只保留元素 1 和 2。
因此,answer[0] = 1 + 1 + 2 + 2。
对于子数组 [1, 2, 2, 3, 4, 2],只保留元素 2 和 4。
因此,answer[1] = 2 + 2 + 2 + 4。
注意 4 被保留是因为其数值大于出现其他出现次数相同的元素(3 和 1)。
对于子数组 [2, 2, 3, 4, 2, 3],只保留元素 2 和 3。
因此,answer[2] = 2 + 2 + 2 + 3 + 3。
输入:nums = [3,8,7,8,7,5], k = 2, x = 2
输出:[11,15,15,15,12]
解释:
由于 k == x,answer[i] 等于子数组 nums[i..i + k - 1] 的总和。
限制
nums.length == n
1 <= n <= 10^5
1 <= nums[i] <= 10^9
1 <= x <= k <= nums.length
算法
(双有序集) $O(n \log n)$
- 使用双有序集存储维护前 $x$ 个出现次数最大的数字。
- 第一个有序集存储 $x$ 个出现次数最大的数字,剩余的数字由第二个有序集维护。
- 首先将前 $k-1$ 个数字维护出现次数,并维护两个有序集。
- 然后开始遍历所有子数组。
- 每次先维护新出现数字的出现次数,即先删除原来的出现次数,再插入新的出现次数。
- 更新答案后,仍然按照前面的顺序,删除子数组左边界数字的出现次数。
时间复杂度
- 每次维护有序集的时间复杂度为 $O(\log n)$,故总时间复杂度为 $O(n \log n)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储有序集、哈希表和答案。
C++ 代码
#define LL long long
class Solution {
public:
vector<LL> findXSum(vector<int>& nums, int k, int x) {
set<pair<int, int>> s1, s2;
LL tot = 0;
auto ins = [&](const pair<int, int> &p) {
if (s1.size() < x) {
s1.insert(p);
tot += (LL)(p.first) * p.second;
} else if (p > *s1.begin()) {
s2.insert(*s1.begin());
tot -= (LL)(s1.begin()->first) * s1.begin()->second;
s1.erase(s1.begin());
s1.insert(p);
tot += (LL)(p.first) * p.second;
} else {
s2.insert(p);
}
};
auto del = [&](const pair<int, int> &p) {
auto it = s1.find(p);
if (it == s1.end()) {
s2.erase(s2.find(p));
return;
}
tot -= (LL)(p.first) * p.second;
s1.erase(it);
if (!s2.empty()) {
s1.insert(*s2.rbegin());
tot += (LL)(s2.rbegin()->first) * s2.rbegin()->second;
s2.erase(*s2.rbegin());
}
};
const int n = nums.size();
unordered_map<int, int> h;
for (int i = 0; i < k - 1; i++)
++h[nums[i]];
for (const auto &[t, v] : h)
ins(make_pair(v, t));
vector<LL> ans;
for (int i = k - 1; i < n; i++) {
int t = nums[i];
if (h[t] > 0)
del(make_pair(h[t], t));
ins(make_pair(++h[t], t));
ans.push_back(tot);
t = nums[i - k + 1];
del(make_pair(h[t], t));
if (--h[t] > 0)
ins(make_pair(h[t], t));
}
return ans;
}
};