树上两点间距离（倍增 LCA，BFS 预处理版本）：dist(x, y) = dist[x] + dist[y] - 2 * dist[lca(x, y)]，即两点到根节点的距离之和减去 2 倍的 LCA 到根节点的距离。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
using LL = long long;  // 64-bit alias (not referenced by the code shown here)

// Capacity limits: N node slots (with slack), M directed edge slots
// (every undirected edge is stored once in each direction).
constexpr int N = 40010;
constexpr int M = 2 * N;

int n, m;  // number of nodes, number of queries

// Adjacency lists in linked "forward-star" form:
//   h[u]  - index of the first edge leaving u (main() fills it with -1 = empty)
//   e[i]  - target node of edge i
//   w[i]  - weight of edge i
//   ne[i] - next edge leaving the same source node
//   idx   - next free edge slot
int h[N], e[M], ne[M], w[M], idx;

// fa[u][k]: the 2^k-th ancestor of u, or 0 when that would lie above the root.
int fa[N][16];
// depth[u]: depth of u (root = 1, sentinel node 0 = 0).
int depth[N];
// dist[u]: total edge weight on the path from the root to u.
int dist[N];

std::queue<int> q;  // FIFO work queue for bfs()
// Appends the directed edge a -> b with weight c to a's adjacency list.
// The new edge occupies slot idx and becomes the head of a's list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // link to the previous head of a's list
    h[a] = idx;
    ++idx;
}
/*
 * Breadth-first traversal from `root` that fills, for every node reachable from it:
 *   depth[u]  - depth with depth[root] = 1 (node 0 is a sentinel with depth 0),
 *   dist[u]   - weighted distance from root,
 *   fa[u][k]  - 2^k-th ancestor for k = 0..15 (0 when it would lie above root).
 * A node's row fa[t][*] is filled when t is discovered, i.e. before t is dequeued,
 * so each child's row can be derived from its parent's complete row.
 */
void bfs(int root)
{
    memset(depth, 0x3f, sizeof depth);  // 0x3f3f3f3f means "not visited yet"
    // Sentinel: jumps past the root land on node 0, which is shallower than every real node.
    depth[0] = 0;
    depth[root] = 1;

    // Initialise the root explicitly instead of relying on zero-initialised globals,
    // so the tables stay correct even if bfs() is re-run with a different root.
    dist[root] = 0;
    for (int k = 0; k <= 15; k++) fa[root][k] = 0;

    q.push(root);
    while (q.size())
    {
        int t = q.front();
        q.pop();
        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            // In a tree this is true exactly once per non-root node: on discovery.
            if (depth[j] > depth[t] + 1)
            {
                depth[j] = depth[t] + 1;
                dist[j] = dist[t] + w[i];
                q.push(j);
                fa[j][0] = t;
                // 2^k ancestor = 2^(k-1) ancestor of the 2^(k-1) ancestor.
                for (int k = 1; k <= 15; k++)
                {
                    fa[j][k] = fa[fa[j][k - 1]][k - 1];
                }
            }
        }
    }
}
/*
 * Lowest common ancestor of x and y by binary lifting over fa[][] (levels 0..15).
 * Relies on depth[0] == 0 for the sentinel node 0. A jump that would pass above
 * the root lands on a node shallower than every real node, so it is rejected.
 */
int lca(int x,int y)
{
    // Work with x as the deeper (or equally deep) node.
    if (depth[y] > depth[x]) swap(x, y);

    // Phase 1: raise x to y's depth, trying the largest jumps first.
    for (int k = 15; k >= 0; --k)
    {
        const int up = fa[x][k];
        if (depth[up] >= depth[y]) x = up;
    }

    // y was an ancestor of x.
    if (x == y) return x;

    // Phase 2: lift both while their ancestors differ; they end as children of the LCA.
    for (int k = 15; k >= 0; --k)
    {
        const int ux = fa[x][k];
        const int uy = fa[y][k];
        if (ux != uy)
        {
            x = ux;
            y = uy;
        }
    }
    return fa[x][0];
}
/*
 * Reads a weighted tree with n nodes (n - 1 undirected edges "a b k"), then m
 * queries "x y". For each query, prints the length of the tree path between
 * x and y on its own line.
 *
 * Path length: dist(x, y) = dist[x] + dist[y] - 2 * dist[lca(x, y)],
 * where dist[u] is the weighted distance from the root (node 1) to u.
 */
int main()
{
    // Only iostreams are used (no C stdio), so unsyncing is safe and speeds up large inputs.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(h, -1, sizeof h);  // -1 marks an empty adjacency list
    cin >> n >> m;
    for (int i = 0; i < n - 1; i++)
    {
        int a, b, k;
        cin >> a >> b >> k;
        // Undirected edge: store both directions.
        add(a, b, k);
        add(b, a, k);
    }

    bfs(1);  // root the tree at node 1 and build depth / dist / fa

    while (m--)
    {
        int x, y;
        cin >> x >> y;
        const int p = lca(x, y);
        // '\n' rather than endl: endl would flush the stream for every one of the m answers.
        cout << dist[x] + dist[y] - 2 * dist[p] << '\n';
    }
    return 0;
}
DFS 版本（用递归 DFS 代替 BFS 预处理 depth、dist 和倍增数组 fa，其余代码与上面相同）
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
using LL = long long;  // 64-bit alias (not referenced by the code shown here)

// Capacity limits: N node slots (with slack), M directed edge slots
// (every undirected edge is stored once in each direction).
constexpr int N = 40010;
constexpr int M = 2 * N;

int n, m;  // number of nodes, number of queries

// Adjacency lists in linked "forward-star" form:
//   h[u]  - index of the first edge leaving u (main() fills it with -1 = empty)
//   e[i]  - target node of edge i
//   w[i]  - weight of edge i
//   ne[i] - next edge leaving the same source node
//   idx   - next free edge slot
int h[N], e[M], ne[M], w[M], idx;

// fa[u][k]: the 2^k-th ancestor of u, or 0 when that would lie above the root.
int fa[N][16];
// depth[u]: depth of u (root = 1, sentinel node 0 stays 0).
int depth[N];
// dist[u]: total edge weight on the path from the root to u.
int dist[N];

// Not referenced by any function visible in this program (the DFS traversal needs no queue).
std::queue<int> q;
// Appends the directed edge a -> b with weight c to a's adjacency list.
// The new edge occupies slot idx and becomes the head of a's list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // link to the previous head of a's list
    h[a] = idx;
    ++idx;
}
/*
 * Recursive DFS that roots the tree. For node u it records:
 *   - its depth (one more than its parent f's),
 *   - its parent,
 *   - its 2^k-th ancestors.
 * It then sets each child's root distance and descends into every neighbour except f.
 * Called as dfs(1, 0): node 0 acts as a sentinel whose depth and ancestors are never
 * written and so stay 0. This assumes real nodes are numbered from 1.
 */
void dfs(int u,int f)
{
    depth[u] = depth[f] + 1;

    // Binary-lifting table: level 0 is the parent; level k is two level-(k-1) jumps.
    fa[u][0] = f;
    for (int k = 1; k < 16; ++k)
    {
        const int mid = fa[u][k - 1];
        fa[u][k] = fa[mid][k - 1];
    }

    for (int edge = h[u]; edge != -1; edge = ne[edge])
    {
        const int child = e[edge];
        if (child != f)  // skip the edge leading back to the parent
        {
            dist[child] = dist[u] + w[edge];
            dfs(child, u);
        }
    }
}
/*
 * Lowest common ancestor of x and y by binary lifting over fa[][] (levels 0..15).
 * Relies on depth[0] == 0 for the sentinel node 0. A jump that would pass above
 * the root lands on a node shallower than every real node, so it is rejected.
 */
int lca(int x,int y)
{
    // Work with x as the deeper (or equally deep) node.
    if (depth[y] > depth[x]) swap(x, y);

    // Phase 1: raise x to y's depth, trying the largest jumps first.
    for (int k = 15; k >= 0; --k)
    {
        const int up = fa[x][k];
        if (depth[up] >= depth[y]) x = up;
    }

    // y was an ancestor of x.
    if (x == y) return x;

    // Phase 2: lift both while their ancestors differ; they end as children of the LCA.
    for (int k = 15; k >= 0; --k)
    {
        const int ux = fa[x][k];
        const int uy = fa[y][k];
        if (ux != uy)
        {
            x = ux;
            y = uy;
        }
    }
    return fa[x][0];
}
/*
 * Reads a weighted tree with n nodes (n - 1 undirected edges "a b k"), then m
 * queries "x y". For each query, prints the length of the tree path between
 * x and y on its own line.
 *
 * Path length: dist(x, y) = dist[x] + dist[y] - 2 * dist[lca(x, y)],
 * where dist[u] is the weighted distance from the root (node 1) to u.
 */
int main()
{
    // Only iostreams are used (no C stdio), so unsyncing is safe and speeds up large inputs.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(h, -1, sizeof h);  // -1 marks an empty adjacency list
    cin >> n >> m;
    for (int i = 0; i < n - 1; i++)
    {
        int a, b, k;
        cin >> a >> b >> k;
        // Undirected edge: store both directions.
        add(a, b, k);
        add(b, a, k);
    }

    dfs(1, 0);  // root the tree at node 1 (parent = sentinel 0) and build depth / dist / fa

    while (m--)
    {
        int x, y;
        cin >> x >> y;
        const int p = lca(x, y);
        // '\n' rather than endl: endl would flush the stream for every one of the m answers.
        cout << dist[x] + dist[y] - 2 * dist[p] << '\n';
    }
    return 0;
}