题意
有一个长度为 n 的序列 a,定义序列的权值为一直进行以下操作直到所有数字变成 0 的操作次数。
取出序列最大数 ai,将 ai−1,ai,ai+1 都减去 1。如果有数字小于 0,将其变成 0。
有 m 次操作,要支持单点修改权值,每次修改之后查询序列的权值。
分析
首先容易发现的就是,如果我们选了 ai,那么之后就一定不会选 ai−1 和 ai+1。因为根据定义我们得知,ai>max,而一次修改之后,这些数字的相对大小不变,所以我们只要知道所有选出的数字的和即可。
考虑线段树。合并的时候我们希望可以快速求出合并的值,但是显然,这个值会受到选出的数不能相邻的干扰。常见办法就是维护 0/1/2/3 分别表示考虑左右端点,不考虑左端点,不考虑右端点,左右端点都不考虑的情况。大力分讨即可。
注意,分讨的时候我们要关心考虑的端点是否有选上。没选上就是小丑了,所以我们要额外维护一个 l(0/1/2/3),r(0/1/2/3) 表示在 (0/1/2/3) 的情况下是否选了左或右端点。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
#define N 100005
#define int int
il int rd(){
int s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
int n, a[N], x, p;
struct ST{
ll p[4];
bool ls[4], rs[4];
ll& operator [](const int& x){return p[x];}
#define l(x) ls[x]
#define r(x) rs[x]
}tr[N << 2];
void push_up(int p, int l, int r){
ST& ans = tr[p], x = tr[p << 1], y = tr[p << 1 | 1];
int mid = (l + r) >> 1;
if (x.r(0) && y.l(0)){
if (a[mid] >= a[mid + 1]) ans.l(0) = x.l(0), ans.r(0) = y.r(1), ans[0] = x[0] + y[1];
else ans.l(0) = x.l(2), ans.r(0) = y.r(0), ans[0] = x[2] + y[0];
}
else ans.l(0) = x.l(0), ans.r(0) = y.r(0), ans[0] = x[0] + y[0];
ans.l(1) = 0;
if (x.r(1) && y.l(0)){
if (a[mid] >= a[mid + 1]) ans.r(1) = y.r(1), ans[1] = x[1] + y[1];
else ans.r(1) = y.r(0), ans[1] = x[3]+ y[0];
}
else ans.r(1) = y.r(0), ans[1] = x[1] + y[0];
ans.r(2) = 0;
if (x.r(0) && y.l(2)){
if (a[mid] >= a[mid + 1]) ans.l(2) = x.l(0), ans[2] = x[0] + y[3];
else ans.l(2) = x.l(2), ans[2] = x[2] + y[2];
}
else ans.l(2) = x.l(0), ans[2] = x[0] + y[2];
ans.l(3) = ans.r(3) = 0;
if (x.r(1) && y.l(2)){
if (a[mid] >= a[mid + 1]) ans[3] = x[1] + y[3];
else ans[3] = x[3] + y[2];
}
else ans[3] = x[1] + y[2];
}
void build(int p, int l, int r){
if (l == r) return tr[p][0] = a[l], tr[p].l(0) = tr[p].r(0) = 1, void(0);
int mid = (l + r) >> 1;
build(p << 1, l, mid), build(p << 1 | 1, mid + 1, r);
push_up(p, l, r);
}
void add(int p, int l, int r, int x, int k){
if (l == r) return tr[p][0] = k, tr[p].l(0) = tr[p].r(0) = 1, void(0);
int mid = (l + r) >> 1;
if (x <= mid) add(p << 1, l, mid, x, k);
else add(p << 1 | 1, mid + 1, r, x, k);
push_up(p, l, r);
}
signed main(){
n = rd();
for (int i = 1; i <= n; i++) a[i] = rd();
build(1, 1, n);
for (int T = rd(); T--;){
x = rd(), p = rd();
a[x] = p, add(1, 1, n, x, p);
printf ("%lld\n", tr[1][0]);
}
return 0;
}