算法1
(倍增法) $\mathcal{O}((n+m)*log_2n)$
倍增法求解$\mathbb{LCA}$问题分为两步。
$\mathcal 1$. 预处理出每个节点的father数组,对于fa[i][k],其含义为i号节点向上的第$2^k$个祖先节点的编号,从其定义不难看出可以利用递推求出。在具体求解时,还需要记录一个depth数组,保存i号节点的深度,规定根节点深度为1,利用bfs遍历树,对于每个出队的节点,可以利用该节点的深度更新所有其儿子节点的深度,同时还可以更新fa数组,假如i号节点有儿子j,则不难看出fa[j][0] = i,然后可以利用fa[j][0]继续往上更新fa[j][1],fa[j][2]…,规则为fa[j][k] = fa[fa[j][k-1]][k-1]。遍历完所有节点则depth和fa数组均预处理完毕。
$\mathcal 2$. 查询阶段,对于任意给定的两个节点a和b,思路是先看两个节点深度是否一致,不一致则将更深的节点上浮到浅的那一层,然后判断同深度的这两个节点是否是同一个,是则直接结束,否则还需要两个节点齐头并进上浮,直到两个节点到达同一个位置。具体而言,上浮是通过fa数组实现的,并且上浮的尺度是从大到小,逐位逼近,Tips是可以将depth[0]设为0,这样在更深节点上浮时,如果大尺度的上浮导致浮到了0号节点(无效节点),也可以继续缩小尺度循环(因为depth[0]一定不大于depth[b]);当a和b深度一致后,通过判断fa[a][k]是否等于fa[b][k],判断是否能结束循环,如果不等,则说明可以继续上浮$2^k$,最终结束时fa[a][0]就是答案了。
$\mathcal 3$. 本题就是在以上的基础上套了个距离,因为还是树,所以可以直接保存一个dist数组,在bfs时保存所有节点到root的距离,最终a和b和最短距离就是dist[a] + dist[b] - 2*dist[lca(a, b)]
Golang 代码
package main
import (
"os"
"bufio"
"strings"
"strconv"
"fmt"
)
func main() {
sc := bufio.NewScanner(os.Stdin)
buf := make([]byte, 1024*1024)
sc.Buffer(buf, len(buf))
sc.Scan()
str := strings.Fields(sc.Text())
n, _ := strconv.Atoi(str[0])
m, _ := strconv.Atoi(str[1])
head, next, to, w := make([]int, n+1), make([]int, n<<1), make([]int, n<<1), make([]int, n<<1)
cnt := 0
add := func(a, b, c int) {
cnt++
next[cnt], to[cnt], w[cnt], head[a] = head[a], b, c, cnt
}
for i := 1; i < n; i++ {
sc.Scan()
str = strings.Fields(sc.Text())
a, _ := strconv.Atoi(str[0])
b, _ := strconv.Atoi(str[1])
c, _ := strconv.Atoi(str[2])
add(a, b, c)
add(b, a, c)
}
dist := make([]int, n+1)
depth := make([]int, n+1)
fa := make([][16]int, n+1)
for i := range dist {
depth[i] = 1<<31 - 1
}
depth[0], depth[1] = 0, 1
bfs := func(root int) {
q := []int{root}
for len(q) > 0 {
cur := q[0]
q = q[1:]
for i := head[cur]; i != 0; i = next[i] {
tot := to[i]
if depth[tot] > depth[cur] + 1 {
depth[tot] = depth[cur] + 1
dist[tot] = dist[cur] + w[i]
fa[tot][0] = cur
q = append(q, tot)
for k := 1; k <= 15; k++ {
fa[tot][k] = fa[fa[tot][k-1]][k-1]
}
}
}
}
}
bfs(1)
lca := func(a, b int) int {
if depth[a] < depth[b] {
a, b = b, a
}
for i := 15; i >= 0; i-- {
if depth[fa[a][i]] >= depth[b] {
a = fa[a][i]
}
}
if a == b {
return a
}
for i := 15; i >= 0; i-- {
if fa[a][i] != fa[b][i] {
a, b = fa[a][i], fa[b][i]
}
}
return fa[a][0]
}
for i := 1; i <= m; i++ {
sc.Scan()
str = strings.Fields(sc.Text())
a, _ := strconv.Atoi(str[0])
b, _ := strconv.Atoi(str[1])
fmt.Println(dist[a]+dist[b]-2*dist[lca(a, b)])
}
}
Go爱了