博客食用更佳 https://www.cnblogs.com/czyty114/p/14449737.html
引入
$\;$
假如现在我们得到了一棵$n$个节点的树,每条边都有长度。
现在我们要求这棵树中两个点之间距离小于$k$的点对个数。
$n\leq 4×10^4$
朴素做法
$\;$
先预处理好距离,再$O(n^2)$枚举点对。
重心
$\;$
我们找到这棵树的重心G,把这棵树分为若干个子树,那么发现满足条件的点对只有3种情况:
1.点对在某个子树中(直接递归求解)
2.两个点所构成的路径经过了重心G,但你会发现这两个点一定不能在同一个子树中。
所以我们处理出当前这棵树中每个点的d值,$d_i$表示点$i$到重心G的距离。
那么只需要用$d_i+d_j\leq k$这样$(i,j)$的数量减去$d_i+d_j\leq k$且满足$i,j$在同一个子树中的数量
而你会发现,后者可以在递归子树中处理。
3.这条路经的一个端点是G,那么实质上和2.是一种情况,再加入一个$d_G=0$即可
$\;$
时间复杂度
$\;$
选重心来分割整棵树的目的:
你会发现,这若干棵子树中不会有子树的大小超过原树的一半(否则就与重心的定义不符)
所以最多只会递归$log(n)$层,每一层也是$n$个点。但在递归中还要将处理好的d排序。
总复杂度$O(n log^2 n)$
Code
$\;$
一定要注意:如果在函数里单独开变量而是开全局变量,一定要注意随时清空,防止上一层的答案对下面有影响。
#include <bits/stdc++.h>
const int N = 40010;
int n, k, head[N], tot, f[N], mn, W, vis[N], d[N], ans, q[N], cnt, sz[N];
struct node {
int to, nxt, val;
}E[N << 1];
void add(int u, int v, int w) {
E[++tot].to = v; E[tot].nxt = head[u]; E[tot].val = w; head[u] = tot;
}
void dfs(int total, int u, int fa) {
f[u] = 0; // 一定注意初始化
for(int i=head[u];i;i=E[i].nxt) {
int v = E[i].to;
if(v == fa || vis[v]) continue;
dfs(total, v, u);
f[u] = std::max(f[u], sz[v]);
}
f[u] = std::max(f[u], total - sz[u]);
if(f[u] < mn) {
mn = f[u]; W = u;
}
}
void dfs0(int u, int fa) {
sz[u] = 1; q[++cnt] = d[u]; // 这是在减去2那一部分的时候的d值
for(int i=head[u];i;i=E[i].nxt) {
int v = E[i].to;
if(v == fa || vis[v]) continue;
dfs0(v, u);
sz[u] += sz[v];
}
}
void getG(int rt) {
cnt = 0; // 随时清空
dfs0(rt, 0); // 预处理好每个点的子树大小(因为随着划分重心,树的形态会变化)
mn = 1e9; // 注意初始化
dfs(sz[rt], rt, 0); // DP计算重心
vis[W] = 1; // 这个点作为重心,将其打上标记(相当于一个边界条件)
}
void getd(int u, int fa) {
q[++cnt] = d[u]; // 在求d值的过程中将其存入q数组中,这里是以这棵树为重心是的d值,与上面的d不一样
for(int i=head[u];i;i=E[i].nxt) {
int v = E[i].to;
if(vis[v] || v == fa) continue;
d[v] = d[u] + E[i].val;
getd(v, u);
}
}
void solve(int rt) {
getG(rt); // 得到重心
if(sz[rt] != n) {
// 对于2.情况,要减去在相同子树内(i,j)<=k的个数。这个过程是在递归到这个子树中时进行的
//但对于一开始整棵树的情况就没必要减了
std::sort(q + 1, q + cnt + 1);
// q里存储的是这棵子树内的d值
// 因为树内点的编号不一定是连续的,所以需要开q这个数组存它
int e1 = 1, e2 = cnt;
for(int e1=1;e1<=cnt;e1++) {
while(e2 > e1 && q[e1] + q[e2] > k) e2 --; // 双指针枚举,复杂度是线性的
if(e2 <= e1) break;
ans -= (e2 - e1);
}
}
d[W] = 0; // 重心的d当然是0
cnt = 0; // 一定注意要随时清空
getd(W, 0);
std::sort(q + 1, q + cnt + 1);
int e1 = 1, e2 = cnt;
for(int e1=1;e1<=cnt;e1++) {
while(e2 > e1 && q[e1] + q[e2] > k) e2 --;
if(e2 <= e1) break;
ans += (e2 - e1);
}
for(int i=head[W];i;i=E[i].nxt) {
int v = E[i].to;
if(!vis[v]) solve(v); // 如果这个点没被打上标记,一定要向下递归
}
}
int main() {
scanf("%d", &n);
for(int i=1;i<n;i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w); add(v, u, w);
}
scanf("%d", &k);
solve(1);
printf("%d", ans);
return 0;
}