题目描述
在 2016 年,佳媛姐姐刚刚学习了树,非常开心。
现在她想解决这样一个问题:
给定一颗有根树(根为 1),有以下两种操作:
标记操作:对某个结点打上标记(在最开始,只有结点 1 有标记,其他结点均无标记,而且对于某个结点,可以打多次标记。)
询问操作:询问某个结点最近的一个打了标记的祖先(这个结点本身也算自己的祖先)
你能帮帮她吗?
输入格式
第一行两个正整数 N 和 Q 分别表示节点个数和操作次数。
接下来 N−1 行,每行两个正整数 u,v 表示 u 到 v 有一条有向边。
接下来 Q 行,形如 oper num,oper为 C 时表示这是一个标记操作,oper 为 Q 时表示这是一个询问操作。
输出格式
输出一个正整数,表示结果。
数据范围
1≤N,Q≤105,
1≤u,v,num≤N
输入样例:
5 5
1 2
1 3
2 4
2 5
Q 2
C 2
Q 2
Q 5
Q 3
输出样例:
1
2
2
1
算法1
(二分+树链剖分) $O(nlognlogn)$
维护的信息:加标记,相当于+1,这用线段树很好维护,只是单点修改
询问的信息是某点u的祖先的最近标记祖先,可以在query_path函数里解决,但是要怎么解决呢?
我们知道树链剖分可以解决lca问题,这里类似,[1,id[u]]区间中去找距离id[u]最近的标记祖先,
,可以通过询问区间和来二分.具体是这么操作
int check(int l,int r)
{
if(l==r) return l;
int mid=l+r>>1;
int t=query(1,mid+1,r);//先看遍历右边==》一定是先右边
//因为是里r最近的标记祖先!!!
if(t)//右的区间和不为0,去右边
{
return check(mid+1,r);
}//右的区间和为0,去左边
else return check(l,mid);
}
ll query_path(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
int t=query(id[top[u]],id[u]);
if(!t)
{
//都没被标记,区间和为0
u=fa[top[u]];
}
else
{
//有被标记,找到最近的那一个节点
return check(id[top[u]],id[u]);
}
}
还有需要注意的是输出答案要输出线段树的节点的原来的编号,需要一个iid[]数组维护
id[u]=++cnt;//线段树上的编号
iid[cnt]=u;//原来的编号
C++ 代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N =100010,M=N*2;
int w[N],e[M],ne[M],h[N],idx;
int id[N],iid[N],nw[N],cnt;//dfs序的编号,新编号的权值,
int dep[N],sz[N],top[N],fa[N],son[N];//深度,以每个点为根的子树的大小
//重链的顶点,节点的重儿子
struct tree
{
int l,r;
ll add,sum;
}tr[N*4];
int n,m;
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs1(int u,int father,int depth)
{
dep[u]=depth,fa[u]=father,sz[u]=1;
for(int i=h[u];i!=-1;i=ne[i])//访问所有的出边
{
int j=e[i];
if(j==father) continue;//访问到唯一一个不是儿子节点的父节点去了,跳过
dfs1(j,u,depth+1);//继续往下遍历且深度加1
sz[u]+=sz[j];//计算当前节点的个数,是父节点加上所有儿子节点的个数
if(sz[son[u]]<sz[j]) son[u]=j;//son[u]表示u的重儿子,如果当前
//的儿子比当前的重儿子
}
}
void dfs2(int u,int t)
{
id[u]=++cnt,nw[cnt]=w[u],top[u]=t;
iid[cnt]=u;//原来的编号
if(!son[u]) return;//没有重链就退出
dfs2(son[u],t);//遍历重链
for(int i=h[u];i!=-1;i=ne[i])
{//遍历轻链
int j=e[i];
if(j==fa[u]||j==son[u]) continue;//搜到父节点或搜到重儿子 跳过
dfs2(j,j);
}
}
void pushup(int u)
{
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void build(int u,int l,int r)
{
if(l==r) tr[u]={l,r};
else
{
tr[u]={l,r};
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
}
void update(int u,int x,int k)
{
if(tr[u].l==x&&tr[u].r==x) tr[u].sum=k;
else
{
int mid=tr[u].l+tr[u].r>>1;
if(x>mid) update(u<<1|1,x,k);
else update(u<<1,x,k);
pushup(u);
}
}
ll query(int u,int l,int r)
{
if(tr[u].l>=l&&tr[u].r<=r) return tr[u].sum;
int mid=tr[u].l+tr[u].r>>1;
ll res=0;
if(l<=mid) res+=query(u<<1,l,r);
if(r>mid) res+=query(u<<1|1,l,r);
return res;
}
ll check(int l,int r)
{
if(l==r) return l;
int mid=l+r>>1;
ll t=query(1,mid+1,r);//先看遍历右边
if(!t)//右边没有,去左边
{
return check(l,mid);
}//右边有,去右边
else return check(mid+1,r);
}
ll query_path(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ll res=query(1,id[top[u]],id[u]);
if(!res)
{
u=fa[top[u]];
}
else
{
return check(id[top[u]],id[u]);
}
}
if(dep[u]>dep[v]) swap(u,v);
return check(id[u],id[v]);
}
int main()
{
scanf("%d%d",&n,&m);
memset(h,-1,sizeof h);
for(int i=0;i<n-1;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
update(1,1,1);
while(m--)
{
char opt;
int u;
cin>>opt>>u;
if(opt=='C')
{
update(1,id[u],1);
}
else if(opt=='Q')
{
printf("%d\n",iid[query_path(1,u)]);
}
}
return 0;
}