find函数
import java.util.*;
import java.io.*;
public class Main {
static final int N = 30010;
static int[] p = new int[N], size = new int[N], d = new int[N];
static {
/*
size[x]表示集合的大小
d[x]表示x到p[x]的距离
初始化三个数组,p[i] = i,
size[i] = 1(祖宗节点自己就是一个),
d[i] = 0(初始时d[x]表示x到p[x]的距离为零)
*/
for (int i = 1; i < N; i++) {
p[i] = i;
size[i] = 1;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter out = new PrintWriter(System.out);
int T = Integer.parseInt(br.readLine());
while (T -- > 0) {
String[] str = br.readLine().split(" ");
int a = Integer.parseInt(str[1]), b = Integer.parseInt(str[2]);
int pa = find(a), pb = find(b);
if (str[0].equals("M")) {
if (pa != pb) {
p[pa] = pb;//将a连b(a是子树)
//将a连向b这个边的权值就是size[pb],会在下一轮find进行更新a子节点的d数值
d[pa] = size[pb];
size[pb] += size[pa];//更新b树的大小
}
} else {//注意一下在查询前需要先find一下保证d[x]的结果正确
if (pa != pb) out.println(-1);
//根据题目要求,求出中间的船数量特别判断0,da = db时表示中间没有船,输出零即可不是-1
else out.println(Math.max(0, Math.abs(d[a] - d[b]) - 1));
}
}
out.flush();
}
/**
* 明确递归find函数:寻找x的祖宗结点(带有路径优化)
* 再明确d[x]:x到px的距离
* d[px]:px到ppx的距离
* 明确了上面几点我们就可以更新d[x]了,通过px作为桥梁来更新d[x]
* 由于路径优化d[px]会变为px到祖宗结点的距离
* 即d[x] = d[x] + d[px];
*/
static int find(int x) {
if (x != p[x]) {
int root = find(p[x]);//先找到祖宗节点
/*
更新当前节点到p[x]的距离(以p[x]为连接点计算x到root的距离)
该root就是真正的祖宗节点并没有进行连边,会在后面进行连边
(下一轮find的时候,这就是为什么需要在查询的时候先调用find函数)
连边时,将后面的节点加上前面树的大小即可(优化为只需要将root节点更
新为size大小,这一点会在下一轮find函数执行的时候进行更新距离)
总体上就是借助父节点来计算到祖结点
*/
d[x] += d[p[x]];
p[x] = root;//路径优化
}
return p[x];
}
}
实践题
import java.util.*;
public class Main {
static final int N = 30010;
static int[] p = new int[N], size = new int[N], d = new int[N];
static int n, q;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
q = sc.nextInt();
init(n);
while (q -- > 0) {
int ops = sc.nextInt();
if (ops == 1) {
int a = sc.nextInt(), b = sc.nextInt();
merge(a, b);
} else {
int a = sc.nextInt();
find(a);
System.out.println(d[a]);
}
}
sc.close();
}
//a->b
static void merge(int a, int b) {
int pa = find(a), pb = find(b);
if (pa == pb) return;
p[pa] = pb;
d[pa] = size[pb];
size[pb] += size[pa];
}
static int find(int x) {
if (p[x] != x) {
int px = find(p[x]);
d[x] = d[x] + d[p[x]];
p[x] = px;
}
return p[x];
}
static void init(int n) {
for (int i = 1; i <= n; i++) {
p[i] = i;
size[i] = 1;
// d[i] = 0;
}
}
}