算法
总的时间复杂度: O(n + m·α(n))，其中 n 为点数、m 为操作数，α 为反阿克曼函数（实际中 α(n) ≤ 4，可近似视为 O(n + m)）
总的空间复杂度: O(n)
UnionFind的时间复杂度（需同时使用按大小合并与路径压缩，均摊分析）:
- UnionFind（构造）: O(n)
- toUnion: 均摊 O(α(n))
- isConnected: 均摊 O(α(n))
- size（C++ 中为 sizeSet）: O(1)
- count: 均摊 O(α(n))
UnionFind的空间复杂度: O(n)
C++代码
/// Disjoint-set union ("union-find") over the integers [0, capacity).
///
/// Uses union by size together with full path compression, which gives
/// amortized O(alpha(n)) time for toUnion / isConnected / count.
class UnionFind {
    /// One element of the forest.
    struct Node {
        int root = 0;  // parent index; equals the node's own index for a set representative
        int cnt = 1;   // number of elements in this tree; only meaningful on a representative
    };

public:
    /// Creates `capacity` singleton sets {0}, {1}, ..., {capacity - 1}.
    explicit UnionFind(int capacity) : nodeVec(capacity), size(capacity) {
        for (int x = 0; x < capacity; ++x) {
            nodeVec[x].root = x;
        }
    }

    /// Returns true if x and y are in the same set.
    /// Non-const because the lookups compress paths.
    bool isConnected(int x, int y) { return findRoot(x) == findRoot(y); }

    /// Merges the sets containing x and y.
    /// @return true if two distinct sets were merged, false if x and y were already connected.
    bool toUnion(int x, int y) {
        const int xRoot = findRoot(x);
        const int yRoot = findRoot(y);
        if (xRoot == yRoot) {
            return false;
        }
        // Union by size: hang the smaller tree under the larger one
        // (on a tie, y's tree goes under x's) to keep trees shallow.
        if (nodeVec[xRoot].cnt < nodeVec[yRoot].cnt) {
            nodeVec[xRoot].root = yRoot;
            nodeVec[yRoot].cnt += nodeVec[xRoot].cnt;
        } else {
            nodeVec[yRoot].root = xRoot;
            nodeVec[xRoot].cnt += nodeVec[yRoot].cnt;
        }
        --size;
        return true;
    }

    /// Number of disjoint sets currently present.
    int sizeSet() const noexcept { return size; }

    /// Number of elements in the set that contains x.
    int count(int x) { return nodeVec[findRoot(x)].cnt; }

private:
    /// Returns the representative of x's set and re-points every node on the
    /// path from x directly at that representative (path compression).
    int findRoot(int x) {
        int root = x;
        while (nodeVec[root].root != root) {
            root = nodeVec[root].root;
        }
        // Second pass: compress the path just walked.
        while (nodeVec[x].root != root) {
            const int next = nodeVec[x].root;
            nodeVec[x].root = root;
            x = next;
        }
        return root;
    }

    std::vector<Node> nodeVec;
    int size;  // number of disjoint sets
};
Java代码
/**
 * Disjoint-set union ("union-find") over the integers [0, n), using union by
 * size and full path compression.
 */
class UnionFind {
    /**
     * Creates n singleton sets {0}, {1}, ..., {n - 1}.
     *
     * @param n number of elements
     */
    public UnionFind(final int n) {
        roots = new int[n];
        cnts = new int[n];
        siz = n;
        for (int i = 0; i < n; ++i) {
            roots[i] = i;
            cnts[i] = 1;
        }
    }

    /**
     * Merges the sets containing x and y; does nothing if they are already connected.
     */
    public void toUnion(final int x, final int y) {
        final int xRoot = findRoot(x);
        final int yRoot = findRoot(y);
        if (xRoot == yRoot) {
            return;
        }
        // Union by size: hang the smaller tree under the larger one
        // (on a tie, y's tree goes under x's) so trees stay shallow.
        if (cnts[xRoot] < cnts[yRoot]) {
            roots[xRoot] = yRoot;
            cnts[yRoot] += cnts[xRoot];
        } else {
            roots[yRoot] = xRoot;
            cnts[xRoot] += cnts[yRoot];
        }
        --siz;
    }

    /** Returns true if x and y are in the same set. */
    public boolean isConnected(final int x, final int y) {
        return findRoot(x) == findRoot(y);
    }

    /** Returns the number of disjoint sets currently present. */
    public int size() {
        return siz;
    }

    /** Returns the number of elements in the set containing x. */
    public int count(final int x) {
        return cnts[findRoot(x)];
    }

    /**
     * Returns the representative of x's set and re-points every node on the
     * path from x directly at it (path compression).
     */
    private int findRoot(int x) {
        int root = x;
        while (roots[root] != root) {
            root = roots[root];
        }
        while (roots[x] != root) {
            final int next = roots[x];
            roots[x] = root;
            x = next;
        }
        return root;
    }

    private final int[] roots; // parent pointers; roots[r] == r for a representative
    private final int[] cnts;  // tree sizes, only meaningful at representatives
    private int siz;           // number of disjoint sets
}
/**
 * Reads n elements and m operations from standard input and answers them with
 * a UnionFind. Operations are "C a b" (merge), "Q1 a b" (same set? prints
 * Yes/No) and "Q2 a" (prints the size of a's set). Element indices are 1-based.
 */
public class Main {
    public static void main(String[] args) {
        final Scanner in = new Scanner(System.in);
        final int elementCount = in.nextInt();
        final int opCount = in.nextInt();
        final UnionFind dsu = new UnionFind(elementCount);
        for (int op = 0; op < opCount; ++op) {
            final String kind = in.next();
            switch (kind) {
                case "C": {
                    final int first = in.nextInt() - 1;
                    final int second = in.nextInt() - 1;
                    dsu.toUnion(first, second);
                    break;
                }
                case "Q1": {
                    final int first = in.nextInt() - 1;
                    final int second = in.nextInt() - 1;
                    System.out.println(dsu.isConnected(first, second) ? "Yes" : "No");
                    break;
                }
                default: {
                    // Any other token is treated as "Q2".
                    final int element = in.nextInt() - 1;
                    System.out.println(dsu.count(element));
                    break;
                }
            }
        }
    }
}
Python3代码
class UnionFind:
    """Disjoint-set union over the integers 0..n-1 with union by size and path compression."""

    def __init__(self, n):
        """Create n singleton sets {0}, {1}, ..., {n-1}."""
        self.roots = list(range(n))  # parent pointers; roots[r] == r for a representative
        self.cnts = [1] * n  # tree sizes, only meaningful at representatives
        self.siz = n  # number of disjoint sets

    def toUnion(self, x, y):
        """Merge the sets containing x and y; do nothing if they are already connected."""
        xRoot = self.findRoot(x)
        yRoot = self.findRoot(y)
        if xRoot == yRoot:
            return
        # Union by size: make xRoot the larger tree (ties keep x's root),
        # then hang the smaller tree under it so trees stay shallow.
        if self.cnts[xRoot] < self.cnts[yRoot]:
            xRoot, yRoot = yRoot, xRoot
        self.roots[yRoot] = xRoot
        self.cnts[xRoot] += self.cnts[yRoot]
        self.siz -= 1

    def isConnected(self, x, y):
        """Return True if x and y are in the same set."""
        return self.findRoot(x) == self.findRoot(y)

    def size(self):
        """Return the number of disjoint sets currently present."""
        return self.siz

    def count(self, x):
        """Return the number of elements in the set containing x."""
        return self.cnts[self.findRoot(x)]

    def findRoot(self, x):
        """Return x's representative, re-pointing every node on the path directly at it."""
        root = x
        while self.roots[root] != root:
            root = self.roots[root]
        # Second pass: path compression.
        while self.roots[x] != root:
            nxt = self.roots[x]
            self.roots[x] = root
            x = nxt
        return root
if __name__ == '__main__':
    import sys

    # Input: "n m" followed by m operations: "C a b" (merge), "Q1 a b" (same set?),
    # "Q2 a" (size of a's set); element indices are 1-based.
    # Reading everything at once and walking the tokens avoids the per-line
    # overhead of input(), which matters when m is large.
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    unionFind = UnionFind(n)
    out = []  # answers, written in a single call at the end
    pos = 2
    for _ in range(m):
        oper = tokens[pos]
        if oper == "C":
            a, b = int(tokens[pos + 1]), int(tokens[pos + 2])
            unionFind.toUnion(a - 1, b - 1)
            pos += 3
        elif oper == "Q1":
            a, b = int(tokens[pos + 1]), int(tokens[pos + 2])
            out.append("Yes" if unionFind.isConnected(a - 1, b - 1) else "No")
            pos += 3
        else:  # "Q2"
            a = int(tokens[pos + 1])
            out.append(str(unionFind.count(a - 1)))
            pos += 2
    if out:
        sys.stdout.write("\n".join(out) + "\n")