LCA + 树状前缀和
$首先有n个节点,n-1条边,所以没有环是一颗树,而求任意两点之间的最短距离在算法提高课的图论中讲过$
$方法是LCA,最近公共祖先,先求每个点到根节点的距离,\\ 那么两个点a, b之间的最短距离就是分别到最近公共祖先的距离之和\\ 而这题就是上面简单扩展一下,加上经过了多少种类的糖果,\\ 由于糖果的种类较少,所以可以枚举每种糖果是否出现\\ 事先求得每个节点到根节点的糖果i出现的次数(前缀和),\\ 那么如果a, b分别到最近公共祖先之间糖果出现次数大于0的话就包含这个糖果$
$时间复杂度:\\ 1. lca初始化O(nlogn)\\ 2. lca查询logn$
总的时间复杂度$O(nlogn)$
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 100010, M = N * 2;
int h[N], e[M], ne[M], idx;
int w[N], cnt[N][21];// cnt[i][j]表示从根节点到i的糖果j的数量
int depth[N], f[N][17]; //log2(100010) = 16.6
int n, m;
void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void bfs(int root){
memset(depth, 0x3f, sizeof depth);
queue<int> q;
q.push(root);
depth[0] = 0, depth[root] = 1;
while (q.size()){
int u = q.front();
q.pop();
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (depth[j] > depth[u] + 1){
depth[j] = depth[u] + 1;
f[j][0] = u;
q.push(j);
for (int k = 1; k <= 16; k ++){
f[j][k] = f[f[j][k - 1]][k - 1];
}
}
}
}
}
int lca(int a, int b){
if (depth[a] < depth[b]) return lca(b, a);
for (int i = 16; i >= 0; i --){
if (depth[f[a][i]] >= depth[b])
a = f[a][i];
}
if (a == b) return a;
for (int i = 16; i >= 0; i --){
if (f[a][i] != f[b][i]){
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
void dfs(int u, int fa){
cnt[u][w[u]] ++;
for (int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if (j == fa) continue;
for (int k = 1; k <= 20; k ++){
cnt[j][k] = cnt[u][k] + (k == w[j]? 1 : 0);
}
dfs(j, u);
}
}
int main(){
memset(h, -1, sizeof h);
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++) scanf("%d", &w[i]);
for (int i = 0; i < n - 1; i ++){
int a, b;
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
}
bfs(1);
dfs(1, -1);
while (m --){
int a, b;
scanf("%d%d", &a, &b);
int p = lca(a, b);
int res = 0;
for (int i = 1; i <= 20; i ++){
if (cnt[a][i] + cnt[b][i] - 2 * cnt[f[p][0]][i] > 0) res ++;
}
printf("%d\n", res);
}
return 0;
}