在题解区逛了一圈，没有找到完全按书上第二种方法写的题解，于是自己写了一个放上来 qwq。
这种写法比直接存 dis 的写法难写一些，但实测要快几十 ms。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity limits: N bounds the vertex count, M the number of stored directed arcs
// (each undirected edge is stored twice).
const int N = 10005;
const int M = 20005;
int n, k;  // number of vertices, distance bound of the current test case
int tot;   // index of the last arc written; main() resets it to 1 per test case
// Forward-star adjacency list: head[v] is v's first arc, nxt[e] chains arcs of the same tail,
// ver[e] is the arc's head vertex and edge[e] its weight. Arc index 0 terminates a list.
int head[N];
int nxt[M];
int ver[M];
int edge[M];

// Appends the directed arc x -> y with weight z to the front of x's adjacency list.
void add(int x, int y, int z)
{
    ++tot;
    ver[tot] = y;
    edge[tot] = z;
    nxt[tot] = head[x];  // new arc points at the previous list head
    head[x] = tot;
}
int minx, root, all;  // best max-part found so far, its vertex, size of the component being searched
int sz[N], vis[N];    // subtree sizes of the latest get_root DFS; vis[v] = 1 once v was a centroid

// Post-order DFS over the unvisited component containing x (fa is the DFS parent, 0 for none).
// Fills sz[] for the visited vertices. It also records in `root` the vertex whose largest remaining
// piece after removal is smallest, i.e. the centroid. Callers must set `all` to the component size
// and `minx` to at least that value before calling.
void get_root(int x, int fa)
{
    int heaviest = 0;  // largest piece hanging below x
    sz[x] = 1;
    for (int e = head[x]; e != 0; e = nxt[e])
    {
        const int to = ver[e];
        if (to != fa && !vis[to])
        {
            get_root(to, x);
            sz[x] += sz[to];
            if (sz[to] > heaviest) heaviest = sz[to];
        }
    }
    // The piece "above" x is everything in the component outside x's subtree.
    const int above = all - sz[x];
    if (above > heaviest) heaviest = above;
    if (heaviest < minx)
    {
        minx = heaviest;
        root = x;
    }
}
int ans, len;  // running answer for the test case; number of filled entries in seq[]
// cnt[b]: how many seq[] entries currently carry branch label b (zeroed again by cal()).
// d[v]: distance from the current centroid to v, filled by get_dis().
int cnt[N], d[N];

// One entry per vertex of the current component, as collected by get_dis().
struct Node
{
    int dis;  // distance from the current centroid
    int bel;  // branch label: the centroid itself, or the centroid's child whose subtree holds this vertex
    bool operator<(const Node &other) const { return dis < other.dis; }  // order by distance only
} seq[N];
// Pre-order DFS from the current centroid (the global `root`) over unvisited vertices.
// Appends one (distance, branch) entry per vertex to seq[1..len], bumps cnt[] for that branch,
// and propagates distances into d[]. The caller must set d[start] before the first call.
void get_dis(int x, int fa, int bel)
{
    // The centroid labels itself, and each direct child of the centroid opens a new branch labelled
    // by its own id. Deeper vertices inherit the label handed down by their parent.
    const int branch = (x == root || fa == root) ? x : bel;
    ++len;
    seq[len].dis = d[x];
    seq[len].bel = branch;
    ++cnt[branch];
    for (int e = head[x]; e; e = nxt[e])
    {
        const int to = ver[e];
        if (to == fa || vis[to]) continue;
        d[to] = d[x] + edge[e];
        get_dis(to, x, branch);
    }
}
// Counts unordered pairs among seq[1..len] whose distances sum to at most k and whose branch labels
// differ, i.e. pairs whose tree path passes through the current centroid.
// Two pointers over the entries sorted by distance. Invariant: cnt[b] is the number of entries in
// seq[lo..hi] labelled b. Every iteration consumes exactly one entry and decrements its cnt, so cnt[]
// is all zeros again on return; main() never clears it and relies on this.
// (An equivalent variant keeps cnt over seq[lo+1..hi] and pre-decrements for seq[1].)
int cal()
{
    sort(seq + 1, seq + len + 1);
    int pairs = 0;
    int lo = 1;
    int hi = len;
    while (lo <= hi)  // lo == hi adds nothing but still clears that entry's cnt
    {
        if (seq[lo].dis + seq[hi].dis > k)
        {
            // seq[hi] is too far even for the closest remaining entry: drop it.
            --cnt[seq[hi].bel];
            --hi;
            continue;
        }
        // Every entry in seq[lo..hi] fits with seq[lo]; discard those from seq[lo]'s own branch
        // (which includes seq[lo] itself).
        const int here = seq[lo].bel;
        pairs += (hi - lo + 1) - cnt[here];
        --cnt[here];
        ++lo;
    }
    return pairs;
}
// Adds to `ans` every pair whose path passes through centroid x, then marks x removed and
// recurses into each remaining branch via a fresh centroid search.
// Preconditions:
// - x == root, because get_dis() compares against the global root.
// - sz[] still holds the sizes from the get_root() DFS that selected x.
void solve(int x)
{
    len = 0;
    d[x] = 0;
    get_dis(x, 0, 0);
    // get_dis() appends each vertex of x's component exactly once, so len is the component size.
    // Save it now: the recursive calls below overwrite len.
    const int comp = len;
    ans += cal();
    vis[x] = 1;
    for (int i = head[x]; i; i = nxt[i])
    {
        int y = ver[i];
        if (vis[y]) continue;
        // sz[] is relative to the get_root() DFS that found x, not rooted at x. Branch roots below x in
        // that DFS have sz[y] < sz[x] and are exact. The neighbour that was x's DFS parent has
        // sz[y] > sz[x] (its count includes x's side), and its real branch size is comp - sz[x].
        // Recursive calls never touch sz[x] or sz of the other branches, so these reads stay valid.
        int branch_size = sz[y] < sz[x] ? sz[y] : comp - sz[x];
        all = minx = branch_size;
        get_root(y, x);
        solve(root);
    }
}
// Reads test cases "n k" followed by n-1 weighted edges until the "0 0" line, and prints for each
// the number of vertex pairs at tree distance <= k.
int main()
{
    // Only iostreams are used for I/O, so decoupling from C stdio is safe and speeds up input.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(0);
    // Also stop when input runs out. A failed read leaves n and k unchanged, so the old
    // `cin >> n >> k, n && k` form looped forever on input without the terminating "0 0" line.
    while (cin >> n >> k && n && k)
    {
        memset(vis, 0, sizeof vis);
        memset(head, 0, sizeof head);
        tot = 1;
        for (int i = 1; i < n; i++)
        {
            int x, y, z;
            cin >> x >> y >> z;
            // Shift ids up by one (input presumably 0-based) so the vertices are 1..n,
            // 1 is the DFS start and 0 can serve as the "no parent" sentinel.
            x++, y++;
            add(x, y, z), add(y, x, z);
        }
        ans = 0;
        minx = all = n;
        get_root(1, 0);
        solve(root);
        cout << ans << '\n';  // no per-case flush; the stream is flushed at exit
    }
    return 0;
}