树剖LCA它不香么
引理:所求即为将所有异象石按dfs序排成一圈(首位相接),相邻两点的树上距离的和的一半
不会证明
dfs序可以直接采用树剖的dfs序。
接下来,我们用set维护这个异象石构成的dfs序,顺便把点的编号作为第二关键字。
用res
表示当前所有相邻点的树上距离和。
插入的时候,在set中查找前驱pre
,后继nxt
,那么前驱点l=pre->second
,后继点r=nxt->second
新的$$res=res-dist(l,r)+dist(l,u)+dist(u,r)$$
dist是树上距离,树剖LCA求一下就好了。
删除也是类似的。
时间复杂度$O(nlogn)$
/**********/省略快读
#define MAXN 100011
struct Edge
{
ll v,w,nxt;
}e[MAXN<<1|1];
ll cnt=0,last[MAXN];
void adde(ll u,ll v,ll w)
{
++cnt;
e[cnt].v=v;e[cnt].w=w;
e[cnt].nxt=last[u],last[u]=cnt;
}
ll fa[MAXN],dep[MAXN],size[MAXN],mson[MAXN],dis[MAXN];
void dfs1(ll u,ll now)
{
dep[u]=now;
size[u]=1;
for(ll i=last[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(dep[v])continue;
fa[v]=u;dis[v]=dis[u]+e[i].w;
dfs1(v,now+1);
size[u]+=size[v];
if(size[v]>size[mson[u]])mson[u]=v;
}
}
ll top[MAXN],t[MAXN],tot=0;
void dfs2(ll u,ll cur)
{
top[u]=cur;t[u]=++tot;
if(mson[u])dfs2(mson[u],cur);
for(ll i=last[u];i;i=e[i].nxt)
{
ll v=e[i].v;
if(v==fa[u]||v==mson[u])continue;
dfs2(v,v);
}
}
ll LCA(ll u,ll v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]>=dep[top[v]])u=fa[top[u]];
else v=fa[top[v]];
}
if(dep[u]>=dep[v])return v;
else return u;
}
ll dist(ll u,ll v)
{
return dis[u]+dis[v]-2*dis[LCA(u,v)];
}
std::set<pll>s;
int main()
{
ll n=read();
for(ll i=1;i<n;++i)
{
ll u=read(),v=read(),w=read();
adde(u,v,w),adde(v,u,w);
}
dfs1(1,1),dfs2(1,1);
ll m=read(),res=0;
for(ll i=1;i<=m;++i)
{
char op=getchar();
while(op!='+'&&op!='-'&&op!='?')op=getchar();
if(op=='+')
{
ll u=read();
if(s.empty())
{
s.insert(pll(t[u],u));continue;
}
std::set<pll>::iterator itl=--s.lower_bound(pll(t[u],u));
std::set<pll>::iterator itr=s.lower_bound(pll(t[u],u));
if(itr==s.begin())itl=--s.end();
if(itr==s.end())itr=s.begin();
res-=dist(itl->second,itr->second);
res+=dist(itl->second,u)+dist(u,itr->second);
s.insert(pll(t[u],u));
}
else if(op=='-')
{
ll u=read();
s.erase(pll(t[u],u));
if(s.empty())continue;
std::set<pll>::iterator itl=--s.lower_bound(pll(t[u],u));
std::set<pll>::iterator itr=s.lower_bound(pll(t[u],u));
if(itr==s.begin())itl=--s.end();
if(itr==s.end())itr=s.begin();
res-=dist(itl->second,u)+dist(u,itr->second);
res+=dist(itl->second,itr->second);
}
else printf("%lld\n",res>>1);
}
return 0;
}