题目描述
难度分:$2400$
你有一棵树,一开始有$4$个节点,编号为$1,2,3,4$,其中$2,3,4$都和$1$相连。
输入$q(1 \leq q \leq 5 \times 10^5)$表示有$q$次操作。每次操作,输入$v(1 \leq v \leq$ 当前树的大小$)$,保证$v$是叶子。
在叶子$v$的下面添加两个新的节点与$v$相连,编号分别为$n+1$和$n+2$,其中$n$是当前树的大小。
每次操作后,输出树的直径长度。
输入样例
5
2
3
4
8
5
输出样例
3
4
4
5
6
算法
倍增求LCA
如果加入一个新的节点后,树的直径变长了,那么新直径的一个端点肯定是新加入的这个节点。
初始化直径的长度$diameter=2$,两个端点分别为$end_1$和$end_2$。每次加入一个节点$cur$的时候,如果这个节点到$end_1$和$end_2$中任意一点的距离超过了$diameter$,就可以把直径长度更新为这个距离。比如$cur$到$end_1$的距离超过了$diameter$,那直径长度就更新为$cur$到$end_1$的距离,$end_2$就更新为$cur$。
可以证明任何直径外的点$y$都不会比$cur$到直径端点的距离更长。证明:假设$cur$到$end_1$的距离长于$cur$到$end_2$的距离。如果$cur$到一个直径外的点$y$距离比$cur$到$end_1$的距离还长,则$end_1$到$end_2$就不应该是直径,直径应该是$y$到$end_2$才对。
这样一来,先基于初始的$4$个节点预处理出倍增的$dist$、$depth$、$fa$三个数组。然后每加入一个新的节点$cur$,就更新这三个数组,计算新节点到上一轮直径端点$end_1$和$end_2$的距离,从而判断直径端点和长度该如何更新。
复杂度分析
时间复杂度
倍增初始化的时间复杂度为$O(nlog_2U)$,$U$为$q$个询问完成后的树节点总数,初始情况下$n=4$很小。接下来处理每个询问,新加入一个节点就要更新一次$fa$数组、$dist$数组和$depth$数组,时间复杂度为$O(log_2U)$,每次加入两个节点,时间复杂度仍然是这个级别。求新加入节点与老的直径端点$end1$和$end2$的距离,时间复杂度为$O(log_2U)$。因此,处理$q$个询问的时间复杂度为$O(qlog_2U)$。
综上,时间复杂度和树的节点个数并没有关系,只跟预估的树节点总数有关,时间复杂度为$O(nlog_2U+qlog_2U)$。
空间复杂度
空间消耗就是倍增求LCA
的辅助数组消耗,$dist$和$depth$数组是线性空间,为$O(U)$。$fa$数组还需要考虑到每个节点距离根节点的路径长度,空间消耗为$O(Ulog_2U)$。因此,整个算法的额外空间复杂度为$O(Ulog_2U)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1000010, M = 20;
int q, v, dist[N], depth[N], fa[N][M];
void bfs(unordered_map<int, vector<int>>& graph, int root, int d) {
depth[root] = 1;
dist[root] = 0;
fa[root][0] = 0;
queue<int> q;
q.push(root);
while(!q.empty()) {
int cur = q.front();
q.pop();
for(int nxt: graph[cur]) {
if(depth[nxt]) continue;
q.push(nxt);
depth[nxt] = depth[cur] + 1;
dist[nxt] = dist[cur] + 1;
fa[nxt][0] = cur;
for(int j = 1; j < M; j++) {
fa[nxt][j] = fa[fa[nxt][j - 1]][j - 1];
}
}
}
}
int lca(int a, int b) {
if(depth[a] < depth[b]) swap(a, b);
for(int i = M - 1; i >= 0; i--) {
if(depth[fa[a][i]] >= depth[b]) {
a = fa[a][i];
}
}
if(a == b) return a;
for(int i = M - 1; i >= 0; i--) {
if(fa[a][i] != fa[b][i]) {
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}
int get(int x, int y) {
return dist[x] + dist[y] - (dist[lca(x, y)]<<1);
}
int main() {
int end1 = 2, end2 = 3, diameter = 2;
scanf("%d", &q);
unordered_map<int, vector<int>> graph;
graph[1].push_back(2);
graph[2].push_back(1);
graph[1].push_back(3);
graph[3].push_back(1);
graph[1].push_back(4);
graph[4].push_back(1);
bfs(graph, 1, 0);
int n = 4;
for(int k = 1; k <= q; k++) {
scanf("%d", &v);
for(int i = 1; i <= 2; i++) {
int cur = n + i;
fa[cur][0] = v;
dist[cur] = dist[v] + 1;
depth[cur] = depth[v] + 1;
for(int j = 1; j < M; j++) {
fa[cur][j] = fa[fa[cur][j - 1]][j - 1];
}
int d1 = get(cur, end1), d2 = get(cur, end2);
if(d1 >= d2) {
if(d1 > diameter) {
end2 = cur;
diameter = d1;
}
}else {
if(d2 > diameter) {
end1 = cur;
diameter = d2;
}
}
}
n += 2;
printf("%d\n", diameter);
}
return 0;
}