题目描述
给定一个序列$a[n]$,同时给定一个序列$b[n]$,其中$bi$表示清楚$ai$中的第$bi$个数,并输出清除后的对应最大连续子段和。
样例
输入:
4
1 3 2 5
3 4 1 2
输出:
5
4
3
0
算法1
(线段树) $O(nlogn)$
单点修改,区间查询最大子段和,很容易联想到线段树维护。
C++ 代码
#include<bits/stdc++.h>
using i64=long long;
const i64 inf=-2e14;
const int N=1e5+5;
struct node{
int l,r;
i64 ls,rs,mx,sum;
}tr[N<<2];
int a[N],n;
void pushup(int u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
tr[u].ls=std::max(tr[u<<1].ls,tr[u<<1].sum+tr[u<<1|1].ls);
tr[u].rs=std::max(tr[u<<1|1].rs,tr[u<<1].rs+tr[u<<1|1].sum);
tr[u].mx=std::max({tr[u<<1].mx,tr[u<<1|1].mx,tr[u<<1].rs+tr[u<<1|1].ls});
}
void build(int u,int l,int r){
tr[u]={l,r};
if(l==r){
tr[u]={l,l,a[l],a[l],a[l],a[l]};
return;
}
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void modify(int u,int x){
if(tr[u].l==x&&tr[u].r==x){
tr[u].ls=inf,tr[u].rs=inf;
tr[u].mx=inf,tr[u].sum=inf;
return;
}
int mid=tr[u].l+tr[u].r>>1;
if(x<=mid){
modify(u<<1,x);
}
else{
modify(u<<1|1,x);
}
pushup(u);
}
signed main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin>>n;
for(int i=1;i<=n;i++){
std::cin>>a[i];
}
build(1,1,n);
for(int i=1;i<=n;i++){
int x;
std::cin>>x;
modify(1,x);
std::cout<<std::max(tr[1].mx,(i64)0)<<"\n";
}
return 0;
}
算法2
(并查集维护) $O(n)$
由于并查集只能合并,不能删除,我们可以考虑将删除操作倒过来转变为插入操作,对于每一个位置,我们都可以将与他相邻的位置合并,放在同一个集合内,再对每个集合中的元素和求最大值即可。
C++ 代码
#include<bits/stdc++.h>
using i64=long long;
const int N=1e5+5;
int a[N],b[N];
i64 s[N],ans[N];
struct DSU{
std::vector<int>f,siz;
DSU(){};
DSU(int n){
init(n);
}
void init(int n){
f.resize(n);
std::iota(f.begin(),f.end(),0);
siz.assign(n,1);
for(int i=0;i<n;i++){
s[i]=a[i];
}
}
int find(int x){
return f[x]==x?x:find(f[x]);
}
bool same(int x,int y){
return find(x)==find(y);
}
bool merge(int x,int y){
x=find(x),y=find(y);
if(x==y){
return false;
}
if(siz[x]<siz[y]){
std::swap(x,y);
}
siz[x]+=siz[y];
s[x]+=s[y];
f[y]=x;
return true;
}
int size(int x){
return siz[find(x)];
}
};
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n;
std::cin>>n;
for(int i=1;i<=n;i++){
std::cin>>a[i];
}
for(int i=1;i<=n;i++){
std::cin>>b[i];
}
DSU dsu(n+1);
i64 sum=0;
std::vector<int>vis(n+1,0);
for(int i=n;i;i--){
int cur=b[i];
vis[cur]=1;
if(cur>1&&vis[cur-1]){
dsu.merge(cur,cur-1);
}
if(cur<n&&vis[cur+1]){
dsu.merge(cur,cur+1);
}
sum=std::max(sum,s[dsu.find(cur)]);
ans[i-1]=sum;
}
ans[n]=0;
for(int i=1;i<=n;i++){
std::cout<<ans[i]<<"\n";
}
return 0;
}