题目描述
求树上点集(带修改)的最小覆盖线段集合。
LCA
LYD
老师告诉我们,如果要求最小覆盖线段集合,我们先求出树上点的DFN
,再对每对DFN
相邻的点求他们之间的线段长度,加起来就是答案的2
倍。
当然,动态维护DFN
值相邻的点的方法有许多种,比如Splay
等。我采用了STL
的SET
(因为懒得写Splay
了)
求树上两个点之间线段长度,可得公式:
设$D_i$为i
到根节点的长度,那么,可得:$dis(x,y)=D_x+D_y-2 \times D_{lca(x,y)}$。
减去一个点就直接把ans+=
改成ans-=
即可(反正是全反操作),然后再在SET
中清除它。
C++ 代码
#include <iostream>
#include <cstdio>
#include <set>
using namespace std;
struct edge{
int to,v,next;
}s[200100];
int ls=1,head[100100];
void link(int a,int b,int c){
s[ls].to=b;
s[ls].v=c;
s[ls].next=head[a];
head[a]=ls++;
}
int f[20][100100];//[0,17] lca的倍增数组
long long dis[2][100100];//dis[0]是深度,dis[1]是上面的 D 数组
int dfn[100100];//时间戳
int tmp;
void dfs(int now){
dfn[now]=++tmp;
for(int i=head[now];i;i=s[i].next)
if(s[i].to!=f[0][now]){
f[0][s[i].to]=now;
dis[0][s[i].to]=dis[0][now]+1;
dis[1][s[i].to]=dis[1][now]+s[i].v;
dfs(s[i].to);
}
}
void init(int root,int n){//预处理出f dis dfn 数组
dis[0][root]=dis[1][root]=1;
dfs(root);
for(int i=1;i<=17;i++)
for(int j=1;j<=n;j++)
f[i][j]=f[i-1][f[i-1][j]];
}
int lca(int a,int b){
//printf("**GETTING LCA...(%d,%d)=",a,b);
if(dis[0][b]>dis[0][a])
swap(a,b);
for(int i=17;i>=0;i--)
if(dis[0][f[i][a]]>=dis[0][b])
a=f[i][a];
if(a==b){/*printf("%d\n",a);*/return a;}
//printf("(%d,%d)=",a,b);
if(dis[0][a]!=dis[0][b]) printf("**WARNING: DIS IS NOT EQUAL\n");
for(int i=17;i>=0;i--)
if(f[i][a]!=f[i][b]){
a=f[i][a];b=f[i][b];
}
//printf("%d\n",f[0][a]);
return f[0][a];
}
long long getans(int x,int y){//获得两点之间距离
return (long long)dis[1][x]-dis[1][lca(x,y)]*2+dis[1][y];
}
struct point{
int u,v;
};
bool operator < (const point a,const point b){
return (a.v==b.v ? a.u<b.u : a.v<b.v);
}
set<point>se;
point getp(set <point> :: iterator x){//获得某个点DFN值的前驱后继
if(x==se.begin()) return (*(--se.end()));
return *(--x);
}
point getn(set <point> :: iterator x){
if(x==--se.end()) return *(se.begin());
return *(++x);
}
int root=1;
int main(){
char junk;
int n,a,b,c;
cin>>n;
for(int i=1;i<n;i++){
scanf("%d%d%d",&a,&b,&c);
link(a,b,c);
link(b,a,c);
}
long long ans=0;
init(root,n);/*
for(int i=1;i<=n;i++){
printf("F:");
for(int j=0;j<=17;j++)
printf("%d ",f[j][i]);
printf("\nDIS:(%d,%d)\n",dis[0][i],dis[1][i]);
}*/
cin>>n;
for(int i=0;i<n;i++){
do{junk=getchar();}while(junk!='+'&&junk!='?'&&junk!='-');
if(junk=='?'){
printf("%lld\n",ans/2);
continue;
}
scanf("%d",&a);
if(junk=='+'){
auto x=se.insert((point){a,dfn[a]}).first;
//printf("**%d\n",((point)*x).u);
int pre=(getp(x)).u;
int nxt=(getn(x)).u;
//printf("**GETTING %d pre=%d,nxt=%d\n",((point)*x).u,pre,nxt);
ans+=getans(((point)*x).u,pre)+getans(((point)*x).u,nxt)-getans(pre,nxt);//上面的公式
}
else{
auto x=se.find((point){a,dfn[a]});
//printf("**%d\n",((point)*x).u);
int pre=(getp(x)).u;
int nxt=(getn(x)).u;
//printf("**GETTING %d pre=%d,nxt=%d\n",((point)*x).u,pre,nxt);
ans-=getans(((point)*x).u,pre)+getans(((point)*x).u,nxt)-getans(pre,nxt);
se.erase(x);
}
//printf("**%lld\n",ans);
}
return 0;
}
大佬请问能不能帮我看一下为什么我写的代码会T。。。
大佬 获得先驱和后继哪里为什么是 –se.end() 呢 se.end()返回的是什么?
se.end()返回的是set中最后一个元素的后面一个,相当于在1~n数组中的n+1,所以用–se.end()来得到最后一个元素(因为第一个元素的前面是最后一个元素)
明白了 谢谢大佬