应用:
1. 求树上任意两个点之间的距离: 任取$root$为根节点, $dis[i]$表示节点$i$到根节点的距离
$dis(a, b) = dis[a] + dis[b] - 2 * dis[lca(a, b)]$
祖孙询问
$O(log_2n)$:倍增求$lca$
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 40040, M = N << 1;
int n, m;
int e[M], ne[M], h[N], idx;
int fa[N][16], depth[N];
int q[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void bfs(int root)
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1; // 初始化深度, depth[0] = 0表示跳过根节点后的标记深度
int hh = 0, tt = 0;
q[0] = root;
while (hh <= tt)
{
int u = q[hh ++];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[u] + 1)
{
depth[j] = depth[u] + 1;
q[++ tt] = j;
fa[j][0] = u;
// 求fa[][]数组
for (int k = 1; k <= 15; k ++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b); // 保证a在下面
// 先让a与b跳到相同高度
for (int k = 15; k >= 0; k --)
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
// 让a与b变成它们最近公共祖先的两个儿子节点
for (int k = 15; k >= 0; k --)
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
int main()
{
scanf ("%d", &n);
memset(h, -1, sizeof h);
int root = 0;
for (int i = 0; i < n; i ++)
{
int a, b;
scanf ("%d %d", &a, &b);
if (b == -1) root = a;
else add(a, b), add(b, a);
}
bfs(root); // 广搜求depth[] 和 fa[][]
scanf ("%d", &m);
while (m --)
{
int a, b;
scanf ("%d %d", &a, &b);
int p = lca(a, b);
if (p == a) puts("1");
else if (p == b) puts("2");
else puts("0");
}
return 0;
}
距离
$O(n + m)$ $Tarjan$算法离线求$lca$
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef pair<int, int> PII;
const int N = 10010, M = N << 1;
int n, m;
int e[M], w[M], ne[M], h[N], idx;
vector<PII> query[N]; // 记录下每个询问
int st[N]; // 标记每个点的搜索状态, 1表示正在遍历的分支, 2表示已经遍历过, 0表示还未遍历
int dis[N], p[N];
int res[M];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}
int find(int x)
{
return p[x] == x ? p[x] : p[x] = find(p[x]);
}
void dfs(int u, int fa)
{
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
dis[j] = dis[u] + w[i]; // 更新距离
dfs(j, u);
}
}
void Tarjan(int u)
{
st[u] = 1; // 表示当前正在遍历的分支
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (!st[j]) // 未遍历的点才需要遍历
{
Tarjan(j);
p[j] = u;
}
}
for (auto item : query[u])
{
int y = item.first, id = item.second;
if (st[y] == 2)
{
int lca = find(y);
res[id] = dis[u] + dis[y] - 2 * dis[lca];
}
}
st[u] = 2;
}
int main()
{
scanf ("%d %d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++)
{
int a, b, c;
scanf ("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
for (int i = 0; i < m ; i++)
{
int x, y;
scanf ("%d %d", &x, &y);
if (x != y) // 如果x == y, 则两点距离为0, 不用另求
{
query[x].push_back({y, i});
query[y].push_back({x, i});
}
}
for (int i = 1; i <= n; i ++) p[i] = i; // 并查集初始化
dfs(1, -1); // 任选一个点为根节点, 求各点到根节点的距离
Tarjan(1);
for (int i = 0; i < m; i ++) printf ("%d\n", res[i]);
return 0;
}
闇の連鎖
树上差分: 给以$a$和$b$为端点路径上的边权都加上$k$:($d[]$为差分数组)
$$ d[a] += k, d[b] += k, d[lca(a, b)] -= 2 * k$$
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100100, M = N << 1;
int n, m, res;
int e[M], ne[M], h[N], idx;
int depth[N], fa[N][18];
int d[N]; // 差分数组
int q[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void bfs(int root)
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
int hh = 0, tt = 0;
q[0] = root;
while (hh <= tt)
{
int u = q[hh ++];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[u] + 1)
{
depth[j] = depth[u] + 1;
q[++ tt] = j;
fa[j][0] = u;
for (int k = 1; k <= 17; k ++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = 17; k >= 0; k --)
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = 17; k >= 0; k --)
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
int dfs(int u, int fa)
{
int tmp = d[u];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
int s = dfs(j, u);
if (s == 0) res += m;
else if (s == 1) res += 1;
tmp += s;
}
return tmp;
}
int main()
{
scanf ("%d %d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++)
{
int a, b;
scanf ("%d %d", &a, &b);
add(a, b), add(b, a);
}
bfs(1);
for (int i = 0; i < m; i ++)
{
int a, b;
scanf ("%d %d", &a, &b);
int ans = lca(a, b);
d[a] ++, d[b] ++, d[ans] -= 2; // 树上差分: 修改两个叶子节点的值和其lca的值
}
dfs(1, -1);
printf ("%d", res);
return 0;
}