#include <bits/stdc++.h>
using namespace std;
const int N=500010,M=2*N;//双向边存两倍
int h[N],e[M],ne[M],idx;
int deep[N],fa[N][32],lg[N];
//deep[i]是i号节点的深度
//lg是log数组
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs(int x,int y)
{
deep[x]=deep[y]+1;//x是y的儿子节点,所以要+1
fa[x][0]=y;//fa[x][0]表示x的父亲节点,而y是x的父亲节点.
for(int i=1; (1<<i)<=deep[x]; i++) //2^i<=deep[x]表示不能跳出去了,最多跳到根节点上面
fa[x][i]=fa[fa[x][i-1]][i-1];//状态转移 2^i=2^(i-1)+2^(i-1)
for(int i=h[x]; ~i; i=ne[i])
if(e[i]!=y)
dfs(e[i],x);
}
int lca(int x,int y)
{
if(deep[x]<deep[y])
swap(x,y);
while(deep[x]>deep[y])
x=fa[x][lg[deep[x]-deep[y]]];
if(x==y)
return x;
for(int k=lg[deep[x]]; k>=0; k--)
if(fa[x][k]!=fa[y][k])
{
x=fa[x][k];
y=fa[y][k];
}
return fa[x][0];
}
int dis(int x,int y){
return deep[x]+deep[y]-2*deep[lca(x,y)];
}
int main()
{
lg[0]=-1; //预处理lg数组
for(int i=1;i<N;i++){
lg[i]=lg[i>>1]+1;
}
int n,m;
scanf("%d%d",&n,&m);//n个节点,m次询问
memset(h,-1,sizeof(h));
for(int i=1; i<n; i++) //n-1条边
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);//根节点的父亲节点没有,故选择0
for(int i=1; i<=m; i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
int p1=lca(y,z),d1=dis(x,lca(y,z))+dis(y,z);
int p2=lca(x,y),d2=dis(z,lca(x,y))+dis(x,y);
int p3=lca(x,z),d3=dis(y,lca(x,z))+dis(x,z);
if(d1>d2){ //保持d1是最小的,不额外设置变量比较方便
p1=p2,d1=d2;
}
if(d1>d3){
p1=p3,d1=d3;
}
printf("%d %d\n",p1,d1);
}
return 0;
}
import java.io.;
import java.util.;
public class Main {
static int[] val,next,h,depth;
static int[][]fa;
static int idx;
static boolean[]vis;
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int n = scan.nextInt();
int m = scan.nextInt();
idx = 0;
val = new int[m2+10];
next = new int[m2+10];
h = new int[n+10];
fa = new int[n+1][20];
depth = new int[n+1];
Arrays.fill(h, -1);
int[] count = new int[n+1];//记录相连节点个数,用于找出根节点
for(int i = 0; i < n-1; i) {
int a = scan.nextInt();
int b = scan.nextInt();
count[a];
count[b];
add(a,b);
add(b,a);
}
dfs(1,0);
for(int i = 0; i < m; i) {
int x = scan.nextInt();
int y = scan.nextInt();
int z = scan.nextInt();
int p1=lca(y,z),d1=dis(x,lca(y,z))+dis(y,z);
int p2=lca(x,y),d2=dis(z,lca(x,y))+dis(x,y);
int p3=lca(x,z),d3=dis(y,lca(x,z))+dis(x,z);
}
大佬帮忙看一下我的为什么超时了?
作者您好,这里不使用ans=(d1+d2+d3)/2的做法是因为需要求出p的位置对吗?