'''
每一条非树边都会和树上的边形成一个环,对环上的每一条边进行累加计数,任何一条树
边只要出现在环里面,就增加一个计数,对于计数是1的边,可以共享一种方案,那就是
删除这条树边和对应形成环的那条非树边,对于计数超过1的树边,删掉该树边,至少
还要删除两条非树边,不符合题目方案要求,对于计数为0的边,没有出现在任何环里面
删除该边可以贡献非树边数量这么多的方案数,因此枚举每一条非树边,两个非树边的
端点在树中形成路径上面每条边计数都增加1,最后统计每一条边的计数即可,注意使用
LCA进行树上差分快速增加路径上所有边计数
'''
import math
from collections import deque
class LCA:
# edges 为所有树中的无向边,max_node_num为节点数,节点编号为0, 1, 2, ...... (max_node_num-1)
# 0号节点是整个树的根
def __init__(self, edges, max_node_num):
self.max_node_num = max_node_num
link = [[] for i in range(max_node_num)]
self.__max_dep = 0
for a, b in edges:
link[a].append(b)
link[b].append(a)
self.dep = [0] * self.max_node_num # 每个节点的深度
fa = [0] * self.max_node_num # 每个节点的父节点
# bfs一次把每个节点的父节点和深度求出
def __bfs():
que = deque()
que.append(0)
visit = [0] * self.max_node_num
visit[0] = 1
depth = 0
while len(que) > 0:
node_num = len(que)
for _ in range(node_num):
cur = que.popleft()
self.dep[cur] = depth
self.__max_dep = max(self.__max_dep, depth)
for child in link[cur]:
if visit[child] == 0:
visit[child] = 1
fa[child] = cur
que.append(child)
depth += 1
__bfs()
# f[i][j] 表示节点向上跳跃2^j步后的祖先节点标号
self.max_j = int(math.log2(self.__max_dep)) + 1
self.f = [[0] * (self.max_j + 1) for _ in range(self.max_node_num)]
for i in range(self.max_node_num):
self.f[i][0] = fa[i]
for j in range(1, self.max_j+1):
for i in range(self.max_node_num):
self.f[i][j] = self.f[ self.f[i][j-1] ][j-1]
# 获取a, b 两个节点的最近公共祖先编号
def get_lca(self, a, b):
if self.dep[a] > self.dep[b]:
a, b = b, a
sub_dep = self.dep[b] - self.dep[a]
if sub_dep > 0:
# b 移动到和 a 同一层
j = 0
while sub_dep:
if sub_dep & 1 != 0:
b = self.f[b][j]
sub_dep >>= 1
j += 1
# a, b两个点本来就是祖孙关系,直接返回
if a == b:
return a
# 两个节点同时上移
j = min(int(math.log2(self.dep[a])) + 1, self.max_j)
while j >= 0:
aa, bb = self.f[a][j], self.f[b][j]
if aa != bb:
a, b = aa, bb
j -= 1
return self.f[a][0]
n, m = map(int, input().split())
edges = []
link = [[] for _ in range(n)]
for i in range(n-1):
a, b = map(int, input().split())
a, b = a-1, b-1
edges.append((a, b))
link[a].append(b)
link[b].append(a)
cnt = [0] * n # 每一个节点上的计数
prefix_cnt = [0] * n
lca = LCA(edges, max_node_num=n)
for i in range(m):
a, b = map(int, input().split())
a, b = a-1, b-1
# 树上差分
p = lca.get_lca(a, b)
cnt[a] += 1
cnt[b] += 1
cnt[p] -= 2
# dfs 一次算每个节点的原始值
def dfs(cur, prev = None):
ans = cnt[cur]
for child in link[cur]:
if child != prev:
ans += dfs(child, cur)
prefix_cnt[cur] = ans
return ans
dfs(0)
ans = 0
for i in range(1, n):
if prefix_cnt[i] == 0:
ans += m
elif prefix_cnt[i] == 1:
ans += 1
print(ans)