题目描述
给你一个长度为 n
的数组 nums
和一个整数 k
。
对于 nums
中的每一个子数组,你可以对它进行 至多 k
次操作。每次操作中,你可以将子数组中的任意一个元素增加 1。
注意,每个子数组都是独立的,也就是说你对一个子数组的修改不会保留到另一个子数组中。
请你返回最多 k
次操作以内,有多少个子数组可以变成 非递减 的。
如果一个数组中的每一个元素都大于等于前一个元素(如果前一个元素存在),那么我们称这个数组是 非递减 的。
样例
输入:nums = [6,3,1,2,4,4], k = 7
输出:17
解释:
nums 的所有 21 个子数组中,只有子数组 [6, 3, 1] ,[6, 3, 1, 2] ,[6, 3, 1, 2, 4]
和 [6, 3, 1, 2, 4, 4] 无法在 k = 7 次操作以内变为非递减的。
所以非递减子数组的数目为 21 - 4 = 17 。
输入:nums = [6,3,1,3,6], k = 4
输出:12
解释:
子数组 [3, 1, 3, 6] 和 nums 中所有小于等于三个元素的子数组中,除了 [6, 3, 1] 以外,
都可以在 k 次操作以内变为非递减子数组。
总共有 5 个包含单个元素的子数组,4 个包含两个元素的子数组,
除 [6, 3, 1] 以外有 2 个包含三个元素的子数组,
所以总共有 1 + 5 + 4 + 2 = 12 个子数组可以变为非递减的。
限制
1 <= nums.length <= 10^5
1 <= nums[i] <= 10^9
1 <= k <= 10^9
算法
(双指针,单调栈,单调队列) $O(n)$
- 对于每个结束位置 $i$,都找到尽可能小的位置 $j$,满足 $[j, i], [j + 1, i], \dots, [i, i]$ 都是满足条件的子数组。
- 注意到 $j$ 是随着 $i$ 的增加而不减的,故可以使用双指针。
- 当移动 $i$ 时,可以使用单调队列维护当前区间内的最大值,将 $i$ 加入区间的代价为 $nums(i)$ 减去区间的最大值。
- 如果代价超过了 $k$,则需要移动 $j$。对于 $nums(j)$,其能影响到的范围为 $[j + 1, r(j) - 1]$,其中 $r(j)$ 为 $j$ 右侧第一个大于等于 $j$ 的位置。
- 先去掉 $nums(j)$ 的影响,则相当于将 $[j + 1, r(j) - 1]$ 的代价重置为 $0$。注意这里 $r(j) - 1$ 可能大于 $i$,需要取最小值。
- 然后再让 $[j + 1, i]$ 区间内所有 以 $j$ 作为左侧第一个最大值的位置 $x$($x \le i$),累加其 $[x, r(x) - 1]$ 的代价。
- 计算代价可以通过区间和来辅助计算。
时间复杂度
- 前缀和、单调栈预处理的时间复杂度为 $O(n)$。
- 遍历过程中,每个元素进队一次,出队一次,且每个 $x$ 也恰好出现一次。
- 故总时间复杂度为 $O(n)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储前缀和数组,栈和队列。
C++ 代码
#define LL long long
class Solution {
public:
LL countNonDecreasingSubarrays(vector<int>& nums, int k) {
const int n = nums.size();
vector<LL> sum(n);
sum[0] = 0;
for (int i = 1; i < n; i++)
sum[i] = sum[i - 1] + nums[i];
vector<int> r(n);
vector<vector<int>> left(n);
stack<int> st;
for (int i = 0; i <= n; i++) {
while (!st.empty() && (i == n || nums[st.top()] <= nums[i])) {
int top = st.top();
st.pop();
r[top] = i;
if (!st.empty())
left[st.top()].push_back(top);
}
st.push(i);
}
deque<int> mx;
LL ans = 0;
for (int i = 0, j = 0; i < n; i++) {
while (!mx.empty() && nums[mx.back()] <= nums[i])
mx.pop_back();
mx.push_back(i);
k -= nums[mx.front()] - nums[i];
while (k < 0) {
int t = min(i, r[j] - 1);
k += (LL)(t - j) * nums[j];
k -= sum[t] - sum[j];
for (int x : left[j]) {
if (x > i)
break;
t = min(i, r[x] - 1);
k -= (LL)(t - x) * nums[x];
k += sum[t] - sum[x];
}
if (j == mx.front())
mx.pop_front();
++j;
}
ans += i - j + 1;
}
return ans;
}
};