题目描述
难度分:$1800$
输入$n$、$m(m \lt n \leq 2 \times 10^6, 1 \leq m \leq 5000)$和长为$n$的数组$r(-m \leq r[i] \leq m)$,其中恰好有$m$个$0$。
一开始$x=y=0$。从左到右遍历$r$:
- 如果$r[i]=0$,你可以把$x$加一,或者把$y$加一。
- 如果$r[i] \lt 0$且$x \geq |r[i]|$,那么得$1$分。
- 如果$r[i] \gt 0$且$y \geq |r[i]|$,那么得$1$分。
输出最大总得分。
输入样例$1$
10 5
0 1 0 2 0 -3 0 -4 0 -5
输出样例$1$
3
输入样例$2$
3 1
1 -1 0
输出样例$2$
0
输入样例$3$
9 3
0 0 1 0 2 -3 -2 -2 1
输出样例$3$
4
算法
动态规划
这个题乍一看没什么思路,因为这个$n$太大了,但是$m \leq 5000$是比较小的,感觉就应该从$m$入手来设计算法。而注意到只有$r[i]=0$的时候才能对$x$或$y$自增,所以$x$或$y$最终可以变得的大小也不会很大,隐隐感觉要用DP
来做。
状态定义
$dp[i][x]$表示当前考虑第$i$个$0$(从第$0$个$0$开始),且前面$x$的值大小为$x$的情况下,考虑完后面的所有$0$能够得到的最大得分。在这个定义下,答案就是$dp[0][0]$,从第$0$个$0$开始考虑,初始的$x=0$。
状态转移
而到了第$i$个$0$时,前面已经经过了$i$个$0$,所以从前位置的$x$值也可以推测出$y=i-x$。这时候有两种策略,要么当前对$x$自增,要么当前对$y$自增,状态转移方程分别为$dp[i][x]=cnt_1+dp[i+1][x+1]$,$dp[i][x]=cnt_2+dp[i+1][x]$,两种情况选较大值转移。其中$cnt_1=get1(i,i+1,x+1)+get2(i,i+1,y)$,$cnt_2=get1(i,i+1,x)+get2(i,i+1,y+1)$。$get1(i,i+1,val)$计算的是第$i$个$0$和第$i+1$个$0$之间有多少个小于$0$的$r$值绝对值$\leq val$,$get2(i,i+1,val)$计算的是第$i$个$0$和第$i+1$个$0$之间有多少个大于$0$的$r$值$\leq val$。
我们可以先把正数都保存在二元组数组$ypos$中,存$(r[i],i)$;把负数保存在二元组数组$xpos$中,存$(-r[i],i)$。提前将位于两个相邻$0$之间的段按照第一关键字排序,这样在状态转移的时候就可以通过二分来加速了。
复杂度分析
时间复杂度
状态数量是$O(m^2)$,单次转移的时间复杂度为$O(log_2n)$。对$ypos$和$xpos$分段排序,最终也相当于对$O(n)$规模的数组排序,时间复杂度为$O(nlog_2n)$。因此,整个算法的时间复杂度为$O(nlog_2n+m^2log_2n)$。
空间复杂度
DP
数组的空间复杂度为$O(m^2)$。为了加速转移,预处理出两个$O(m^2)$规模的二元组数组$xmp$和$ymp$,$xmp[i][i+1]$表示第$i$个$0$和第$i+1$个$0$之间的负$r$值在$xpos$数组上是哪个子区间$(low,high)$,$ymp[i][i+1]$表示第$i$个$0$和第$i+1$个$0$之间的正$r$值在$ypos$数组上是哪个子区间$(low,high)$。因此,整个算法的额外空间复杂度为$O(m^2)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 2000010, M = 5001;
int n, m, r[N], dp[M][M];
array<int, 2> xmp[M][M], ymp[M][M];
int main() {
scanf("%d%d", &n, &m);
vector<array<int, 2>> xpos, ypos;
vector<int> zpos;
for(int i = 1; i <= n; i++) {
scanf("%d", &r[i]);
if(r[i] < 0) {
xpos.push_back({-r[i], i});
}else if(r[i] > 0) {
ypos.push_back({r[i], i});
}else {
zpos.push_back(i);
}
}
zpos.push_back(n + 1);
for(int i = 1; i < zpos.size(); i++) {
int cur = zpos[i - 1], nxt = zpos[i];
int l = 0, r = xpos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(xpos[mid][1] > cur) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
int low = index;
l = 0, r = xpos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(xpos[mid][1] < nxt) {
index = mid;
l = mid + 1;
}else {
r = mid - 1;
}
}
int high = index;
xmp[i - 1][i] = {low, high};
if(low != -1 && high != -1) {
sort(xpos.begin() + low, xpos.begin() + high + 1);
}
}
for(int i = 1; i < zpos.size(); i++) {
int cur = zpos[i - 1], nxt = zpos[i];
int l = 0, r = ypos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(ypos[mid][1] > cur) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
int low = index;
l = 0, r = ypos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(ypos[mid][1] < nxt) {
index = mid;
l = mid + 1;
}else {
r = mid - 1;
}
}
int high = index;
ymp[i - 1][i] = {low, high};
if(low != -1 && high != -1) {
sort(ypos.begin() + low, ypos.begin() + high + 1);
}
}
function<int(int, int, int, int)> get = [&](int cur, int nxt, int x, int flag) {
auto& pir = flag? xmp[cur][nxt]: ymp[cur][nxt];
int l = pir[0], r = pir[1], index = r + 1;
if(l == -1 || r == -1) return 0;
while(l <= r) {
int mid = l + r >> 1;
if((flag? xpos[mid][0]: ypos[mid][0]) > x) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
return index - pir[0];
};
for(int x = 0; x <= m; x++) {
dp[m][x] = 0;
}
for(int i = zpos.size() - 2; i >= 0; i--) {
int cur = zpos[i], nxt = zpos[i + 1];
for(int x = 0; x <= i; x++) {
int y = i - x;
int cnt1 = get(i, i + 1, x + 1, 1) + get(i, i + 1, y, 0);
dp[i][x] = max(dp[i][x], cnt1 + dp[i + 1][x + 1]);
int cnt2 = get(i, i + 1, x, 1) + get(i, i + 1, y + 1, 0);
dp[i][x] = max(dp[i][x], cnt2 + dp[i + 1][x]);
}
}
printf("%d\n", dp[0][0]);
return 0;
}