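// 这里是本题的加强版。
// 与原题相同的思路：用左偏树维护每一段的中位数。
// 单调不降、单调不升两种情况分开计算，取代价较小者即可。
//
// This is the harder version of the problem. The approach is the same as in
// the base version: maintain the median of each segment with a leftist heap.
// Handle the non-decreasing and non-increasing cases separately and take the
// smaller cost.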
#include <bits/stdc++.h>
#define MAX 1000005
#define ll long long
using namespace std;
// Reads the next (optionally negative) decimal integer from stdin into n.
// Any characters before the first digit or '-' are skipped. If the input
// runs out before a digit is found, n is set to 0 instead of looping forever.
template<typename T>
void read(T &n){
    n = 0;
    T f = 1;
    // int, not char: getchar() returns EOF out of band, and isdigit() is only
    // defined for unsigned-char values and EOF.
    int c = getchar();
    while(c != EOF && !isdigit(c) && c != '-') c = getchar();
    if(c == '-') f = -1, c = getchar();
    while(c != EOF && isdigit(c)) n = n*10+c-'0', c = getchar();
    n *= f;
}
// Writes integer n to stdout in decimal (no trailing separator).
// Negative values are printed without negating n itself, so the minimum value
// of a signed type (e.g. INT_MIN, LLONG_MIN) does not overflow.
template<typename T>
void write(T n){
    if(n < 0){
        putchar('-');
        // n/10 truncates toward zero, so -(n/10) always fits in T.
        if(n <= -10) write(-(n/10));
        // n%10 is in (-9, 0] for negative n; '0' - n%10 is the last digit.
        putchar('0' - n%10);
        return;
    }
    if(n > 9) write(n/10);
    putchar(n%10+'0');
}
// One segment of the stack: every position in [l, r] of val is assigned the
// same value w. rt is the root of a leftist max-heap (see merge) holding sz
// of the segment's values, and w is kept equal to val[rt].
struct node{
    int rt;   // heap root: an index into val / son / dis (0 = empty heap)
    int l;    // first position covered by the segment
    int r;    // last position covered by the segment
    int sz;   // number of values currently kept in the heap
    int w;    // value assigned to every position in [l, r]
};
node st[MAX];        // segment stack, occupied entries st[1..top]
int n;               // number of input values
int top;             // current height of the segment stack
int son[MAX][2];     // son[x][0] / son[x][1]: left / right child of heap node x
int dis[MAX];        // leftist-heap distance of node x (index 0 is the empty heap)
int val[MAX];        // input values, 1-indexed
ll ans;              // final answer printed by main
// Merges the leftist max-heaps rooted at x and y (0 = empty heap) and returns
// the root of the result. Ties on val keep x as the root.
int merge(int x, int y){
    if(x == 0) return y;
    if(y == 0) return x;
    // The larger value must end up on top.
    if(val[y] > val[x]) std::swap(x, y);
    int &lc = son[x][0];
    int &rc = son[x][1];
    rc = merge(rc, y);          // always descend along the right spine
    if(dis[rc] > dis[lc]) std::swap(lc, rc);  // restore the leftist property
    dis[x] = dis[rc] + 1;
    return x;
}
// Minimum of sum |val[i] - b[i]| over all non-decreasing sequences b[1..n].
//
// Scan left to right, keeping a stack of segments that are each assigned a
// single value. A new element starts its own segment. While the top segment's
// value is smaller than the one below it, the two are merged. Each segment
// keeps a max-heap of its smallest ceil(len/2) values, so the heap root is the
// segment's lower median. A merge only happens when the newer segment's median
// is smaller, so the merged median never needs the larger values each heap
// has already discarded.
//
// All global state it uses (top, dis[0], son/dis of nodes 1..n) is reset
// here, so the function can be called any number of times.
static ll solve_non_decreasing(){
    top = 0;
    dis[0] = -1;                       // empty heap; gives leaves distance 0
    for(int i = 1; i <= n; i++){
        // Fresh single-node heap for element i.
        son[i][0] = son[i][1] = 0;
        dis[i] = 0;
        st[++top] = node{i, i, i, 1, val[i]};
        while(top > 1 && st[top].w < st[top-1].w){
            top--;
            node &seg = st[top];
            seg.rt = merge(seg.rt, st[top+1].rt);
            seg.sz += st[top+1].sz;
            seg.r = st[top+1].r;
            // Keep only the smallest ceil(len/2) values: sz*2 <= len+1.
            while(seg.sz*2 > seg.r-seg.l+2){
                seg.rt = merge(son[seg.rt][0], son[seg.rt][1]);  // pop the max
                seg.sz--;
            }
            seg.w = val[seg.rt];
        }
    }
    ll cost = 0;
    for(int i = 1; i <= top; i++){
        for(int j = st[i].l; j <= st[i].r; j++){
            // Subtract in 64 bits: two ints can differ by more than INT_MAX.
            cost += llabs((ll)st[i].w - val[j]);
        }
    }
    return cost;
}

// Minimum cost to turn val[1..n] into a non-decreasing sequence.
ll solve1(){
    return solve_non_decreasing();
}

// Minimum cost to turn val[1..n] into a non-increasing sequence.
// A sequence is non-increasing iff its reverse is non-decreasing, and
// reversing does not change the cost, so the non-decreasing solver is reused
// on the reversed array. Simply flipping the merge condition of that solver
// is NOT correct: merges would then raise the median, but every heap has
// already thrown away its larger values (e.g. 100 1 8 9 0 would give 9
// instead of the optimum 8 from 100 8 8 8 0).
ll solve2(){
    reverse(val+1, val+n+1);
    ll cost = solve_non_decreasing();
    reverse(val+1, val+n+1);           // restore the original input order
    return cost;
}
int main()
{
cin >> n;
for(int i = 1; i <= n; i++){
read(val[i]);
}
ans = solve1();
ans = min(ans, solve2());
cout << ans << endl;
return 0;
}