并查集

作用:

  1. 将两个集合合并
  2. 询问两个元素是否在一个集合当中
  3. 可以在近乎 O(1)O_{(1)} 的时间复杂度上完成以上的两个操作

基本原理:每个集合用一棵树来表示,树根的编号就是整个集合的编号。每个节点存储它的父节点,用p[x] 来表示x的父节点(注意这里的树并不是二叉树)

  1. 判断树根节点:p[x] == x
  2. 求 x 的集合编号:while(p[x] != x) x = p[x]
  3. 合并两个集合:合并的方法就是保留一个根节点,将里一个作为根节点的子树,p[x] 是 x 的集合编号, p[y] 是 y 的集合编号,直接令 p[x] = y
    img
  4. 优化之路径压缩:将一次搜索得到的路径上的所有节点的父节点都转化为根节点,这样在多次转化查找后就会实现树的高度近似为1
    img
  5. 优化之按秩合并:并查集 - OI Wiki

注意:简单的并查集并不支持高效的集合的分离

(图片源自 OI-wiki)

核心代码:

1
2
3
4
5
6
7
8
9
10
//返回x的祖宗节点 + 路径优化
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

//合并集合a b, 将a作为b的子集插入
void unionSet(int a, int b){
p[find(a)] = find(b);
}

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include<iostream>
using namespace std;

const int N = 100010;

int n, m;
int p[N];

//返回x的祖宗节点 + 路径优化
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

int main(){
scanf("%d%d",&n, &m);
for(int i = 0; i <= n; i++) p[i] = i;
while(m--){
char op[2];
int a, b;
scanf("%s%d%d", op, &a, &b);
//合并
if(op[0] == 'M') p[find(a)] = find(b);
//查找
else{
if(find(a) == find(b)) puts("Yes");
else puts("No");
}
}
return 0;
}

动态维护每个集合的元素个数:

  1. 使用并查集进行集合合并前的整体本身就应该是一个集合即没有重复数据,故并查集合并操作并不需要考虑查重的问题;
  2. 解决这样一个不需要查重的集合合并问题只需要考虑合并后集合元素的数量相加即可;
  3. 维护时确保每个节点的祖先节点的集合个数才是有意义的,然后在集合合并时将集合元素相加即可。

核心代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int find(int x){
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}

void unionSet(int x, int y){
int a = find(x), b = find(y);

//主要特判元素初始是否在同一集合
if(a == b) return;

node_size[b] += node_size[a];
p[a] = b;
}

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<iostream>
#include<cstdio>
using namespace std;

const int N = 100010;
int p[N], node_size[N];

int find(int x){
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}

void unionSet(int x, int y){
int a = find(x), b = find(y);
if(a == b) return;
node_size[b] += node_size[a];
p[a] = b;
}

int main(){
int n, m;
cin>>n>>m;
for(int i = 0; i <= n; i++) p[i] = i, node_size[i] = 1;
while(m--){
char op[5];
int a, b;
scanf("%s", op);
if(op[0] == 'C'){
scanf("%d%d", &a, &b);
unionSet(a, b);
}
else if(op[1] == '1'){
scanf("%d%d",&a, &b);
if(find(a) == find(b)) puts("Yes");
else puts("No");
}
else{
scanf("%d", &a);
printf("%d\n", node_size[find(a)]);
}

}
}