原题链接Count on a tree
题目描述
给定一棵 $n$ 个节点的树,每个点有一个权值。有 $m$ 个询问,每次给你 $u$,$v$,$k$,你需要回答 $u$xor $last$ 和 $v$ 这两个节点间第 $k$ 小的点权。
其中 $last$ 是上一个询问的答案,定义其初始为 $0$,即第一个询问的 $u$ 是明文。
输入格式
第一行两个整数 $n,m$。
第二行有 $n$ 个整数,其中第 $i$ 个整数表示点 $i$ 的权值。
后面 $n−1$ 行每行两个整数 $x,y$,表示点$x$ 到点 $y$ 有一条边。
最后 $m$ 行每行两个整数 $u,v,k$,表示一组询问。
题目分析
此题为区间第k小数 升级版且强制在线,树上两点路径之间的点可以看成一个区间,由此转化为区间第k小,考虑主席树。在一维数组中考虑对前缀建主席树,而在树中,考虑从根节点到该点的路径上的点建主席树(dfs过程中建主席树)。然后和区间第k小思路一样进行$query$。
对于$u->v$路径上的点可以表示为$root[u]+root[v]-root[puv]-root[ppuv]$这棵树
$puv$是$u$和$v$的最近公共祖先即puv=lca(u,v)
$ppuv$是$puv$的父亲节点即ppuv=fa[puv][0]
AC代码
注意:注释部分为debug时候用到的代码
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int N=100010,M=2*N;
int d[N];
int h[N],e[M],ne[M],idx;
int n,m,a[N];
vector<int> num;
struct node
{
int l,r;
int s;
}tree[N*40];
int root[N],cnt;
int depth[N],fa[N][17];
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
int find(int x)
{
return lower_bound(num.begin(),num.end(),x)-num.begin()+1;
}
void insert(int l,int r,int pre,int &now,int pos)
{
now=++cnt;
tree[now]=tree[pre];
tree[now].s++;
if(l==r) return;
int mid=l+r>>1;
if(pos<=mid) insert(l,mid,tree[pre].l,tree[now].l,pos);
else insert(mid+1,r,tree[pre].r,tree[now].r,pos);
}
int query(int l,int r,int u,int v,int puv,int ppuv,int k)
{
if(l==r) return num[l-1];
int mid=l+r>>1;
int tmp=tree[tree[u].l].s+tree[tree[v].l].s-tree[tree[puv].l].s-tree[tree[ppuv].l].s;
if(tmp>=k) return query(l,mid,tree[u].l,tree[v].l,tree[puv].l,tree[ppuv].l,k);
else return query(mid+1,r,tree[u].r,tree[v].r,tree[puv].r,tree[ppuv].r,k-tmp);
}
void dfs(int u,int p)
{
insert(1,num.size(),root[p],root[u],find(a[u]));
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(depth[u]+1>=depth[j]) continue;
depth[j]=depth[u]+1;
fa[j][0]=u;
for(int k=1;k<=16;k++)
fa[j][k]=fa[fa[j][k-1]][k-1];
dfs(j,u);
}
}
int lca(int a,int b)
{
if(depth[a]<depth[b]) swap(a,b);
for(int k=16;k>=0;k--)
if(depth[fa[a][k]]>=depth[b])
a=fa[a][k];
if(a==b) return a;
for(int k=16;k>=0;k--)
if(fa[a][k]!=fa[b][k])
{
a=fa[a][k];
b=fa[b][k];
}
return fa[a][0];
}
// void print(int l,int r,int u)
// {
// if(l==r)
// {
// cout<<l<<' '<<tree[u].s<<endl;
// return;
// }
// int mid=l+r>>1;
// print(l,mid,tree[u].l);
// print(mid+1,r,tree[u].r);
// }
int main()
{
memset(h,-1,sizeof h);
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>a[i];
num.push_back(a[i]);
}
sort(num.begin(),num.end());
num.erase(unique(num.begin(),num.end()),num.end());
for(int i=1;i<n;i++)
{
int a,b;
cin>>a>>b;
add(a,b),add(b,a);
}
memset(depth,0x3f,sizeof depth);
depth[0]=0;
depth[1]=1;
dfs(1,0);
int last=0;
while(m--)
{
int u,v,k;
cin>>u>>v>>k;
u=u^last;
int puv=lca(u,v);
int ppuv=fa[puv][0];
last=query(1,num.size(),root[u],root[v],root[puv],root[ppuv],k);
cout<<last<<endl;
}
//for(int i=1;i<=n;i++,cout<<"**********\n") print(1,num.size(),root[i]);
return 0;
}