// C++ code
// Loop helpers. Every argument except the loop variable name is
// parenthesised so that an expression bound such as `FOR(i, 1, x ? y : z)`
// is compared as a whole instead of being split by operator precedence.
// The bound is re-evaluated on every iteration, like a hand-written `for`.
#define FOR(i, a, b) for (int i = (a); i <= (b); ++i)  // inclusive, ascending
#define ROF(i, a, b) for (int i = (a); i >= (b); --i)  // inclusive, descending
#define mem(a, b) memset((a), (b), sizeof(a))          // byte-fill a whole array object
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand used for all sums.
using ll = long long;

// Per-test-case parameters read by main(): m is compared against the number
// of positive blocks / positive values, n is the number of input values.
int n;
int m;
// Count of strictly positive input values in the current test case.
int num = 0;

// Input values, 1-indexed.
// NOTE(review): short assumes every value fits in 16 bits; confirm with the problem limits.
short a[1000010];

// inq[x] == false marks list node x as merged away, so heap entries that
// still refer to it are stale and must be skipped.
bool inq[1000010];
// Pool of doubly linked list nodes, one per sign-homogeneous block of input.
// Real nodes use indices 1..nn; main() links them between two sentinel
// indices whose val field is never written (so it stays 0).
struct ListNode {
ll val;  // sum of the values in this block
int l;   // index of the left neighbour
int r;   // index of the right neighbour
};
ListNode node[1000010];

// Highest pool index handed out so far.
int nn;
// Unlinks node x by pointing its neighbours at each other.
// The l/r fields of x itself are left as they were.
void del(int x) {
const int left = node[x].l;
const int right = node[x].r;
node[right].l = left;
node[left].r = right;
}
void push(int pos, int x) {
int r = node[pos].r, l = pos;
node[++nn].l = l, node[nn].r = r, node[l].r = nn, node[r].l = nn;
node[nn].val = x;
}
// 1-indexed binary heap storage; heap[1..tot] are the stored entries.
// up()/down() keep the largest val at heap[1].
struct HeapEntry {
ll val;   // priority key
int cnt;  // index of the list node this entry refers to
};
HeapEntry heap[1000010];

// Number of entries currently stored in heap[1..tot].
int tot;
// Sift-up: swaps heap[x] with its parent while it is strictly larger,
// restoring max-heap order after an entry is appended at position x.
void up(int x) {
for (int parent = x / 2; x > 1 && heap[parent].val < heap[x].val; parent = x / 2) {
swap(heap[x], heap[parent]);
x = parent;
}
}
// Appends the entry (val, cnt) at the end of the heap and sifts it up.
void insert(ll val, int cnt) {
++tot;
heap[tot].val = val;
heap[tot].cnt = cnt;
up(tot);
}
// Sift-down from the root: while the larger child is strictly larger than
// the current entry, swap them and continue from that child.
void down() {
int p = 1;
while (2 * p <= tot) {
int child = 2 * p;
// Prefer the right child only when it exists and is strictly larger.
if (child + 1 <= tot && heap[child + 1].val > heap[child].val) ++child;
if (!(heap[child].val > heap[p].val)) return;
swap(heap[p], heap[child]);
p = child;
}
}
// Removes the root: the last entry is moved to heap[1] and sifted down.
// NOTE(review): assumes tot >= 1; on an empty heap tot would become negative.
void pop() {
heap[1] = heap[tot];
--tot;
down();
}
// Reads test cases of the form "m n a1 .. an" until EOF and prints one answer
// per case.
// - If at most m inputs are positive, the answer is the sum of the m largest
//   values.
// - Otherwise, the non-zero values are compressed into alternating-sign
//   blocks. The block with the smallest |sum| is then merged away
//   repeatedly, until only m positive blocks remain.
int main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
// Sentinel list indices: s is the head and t the tail. Real nodes use 1..n.
int s = 0, t = 1000001;
FOR(i, 1, 1000000) inq[i] = 1;
while (cin >> m >> n) {
nn = tot = num = 0;
node[s].l = node[s].r = t, node[t].l = node[t].r = s;
FOR(i, 1, n) {
cin >> a[i];
if (a[i] > 0) num++;
}
if (num <= m) {
// Sum the m largest values. The count is clamped to n so that we never
// read entries left over from an earlier test case, or past the array.
const int take = std::min(m, n);
sort(a + 1, a + n + 1, greater<int>());
ll ans = 0;
FOR(i, 1, take) ans += a[i];
cout << ans << '\n';
}
else {
int k = 0;            // number of positive blocks in the list
ll now = 0, ans = 0;  // now: running sum of the open block; ans: sum of all positives
int l = 1, r = n;
// Skip non-positive values at both ends. Assuming m >= 0, num > m
// implies a positive value exists, so both loops terminate.
while (a[l] <= 0) l++;
while (a[r] <= 0) r--;
FOR(i, l, r) {
if (a[i] == 0) continue;  // zeros never change a block sum
if (a[i] > 0) ans += a[i];
if (now == 0) now += a[i];
else if ((a[i] < 0) != (now < 0)) {
// Sign change: close the open block as list node nn + 1.
insert(-abs(now), nn + 1);
push(nn, now);
if (now > 0) k++;
now = a[i];
}
else now += a[i];
}
// Close the last block. It ends at a[r] > 0, so it is positive.
k++;
insert(-abs(now), nn + 1);
push(nn, now);
// Heap keys are -|block sum|, so heap[1] is the block with the
// smallest absolute sum. Each productive step lowers k by one.
while (k > m) {
int cnt = heap[1].cnt;
// Skip entries whose node has already been merged away.
while (!inq[cnt]) {pop(); cnt = heap[1].cnt;}
if ((node[cnt].l != s && node[cnt].r != t) || node[cnt].val > 0) {
// Lose |sum| of this block and fuse it with both neighbours.
// A sentinel neighbour contributes val 0.
ans += heap[1].val;
inq[node[cnt].l] = inq[node[cnt].r] = 0;
node[cnt].val += node[node[cnt].l].val + node[node[cnt].r].val;
if (node[cnt].l != s) del(node[cnt].l);
if (node[cnt].r != t) del(node[cnt].r);
pop();
insert(-abs(node[cnt].val), cnt);
k--;
}
else pop();  // negative block touching an end: discarded without merging
}
cout << ans << '\n';
// Node indices never exceed n, so this restores every flag used above.
FOR(i, 1, n) inq[i] = 1;
}
}
return 0;
}