// u_numflow/collections/union_find.rs

/// Disjoint-set forest over the elements `0..n`, using union by rank and
/// path compression.
///
/// Every element starts in its own singleton set. [`UnionFind::union`] merges
/// two sets and [`UnionFind::find`] returns the representative (root) of an
/// element's set. Lookup methods take `&mut self` because they compress
/// paths as a side effect. With both heuristics combined, operations run in
/// amortized near-constant (inverse-Ackermann) time.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of `i` in the forest; a root is its own parent.
    parent: Vec<usize>,
    /// Union-by-rank height bound, only read for roots. A root of rank `r`
    /// has at least `2^r` elements in its tree, so the rank stays below 64
    /// on 64-bit targets and fits in a `u8`.
    rank: Vec<u8>,
    /// `size[r]` is the number of elements in the set rooted at `r`; only
    /// kept up to date for roots.
    size: Vec<usize>,
    /// Current number of disjoint sets.
    components: usize,
}

impl UnionFind {
    /// Creates a structure holding the `n` singleton sets `{0}, {1}, …, {n - 1}`.
    pub fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
            size: vec![1; n],
            components: n,
        }
    }

    /// Returns the number of elements (not sets) managed by this structure.
    #[must_use]
    pub fn len(&self) -> usize {
        self.parent.len()
    }

    /// Returns `true` if the structure manages no elements at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every node visited on the way up is re-pointed directly at the root
    /// (path compression), which speeds up later lookups.
    ///
    /// # Panics
    ///
    /// Panics if `x >= self.len()`.
    pub fn find(&mut self, x: usize) -> usize {
        // First pass: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: point every node on the path straight at the root.
        // Done iteratively so lookups never consume call-stack depth.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `true` if the elements were in different sets and a merge
    /// happened, or `false` if they were already in the same set (including
    /// when `x == y`).
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is `>= self.len()`.
    pub fn union(&mut self, x: usize, y: usize) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return false;
        }

        // Union by rank: hang the lower-rank tree under the higher-rank one.
        // On a tie, `root_x` becomes the new root and its rank grows by one.
        let (new_root, child) = if self.rank[root_x] < self.rank[root_y] {
            (root_y, root_x)
        } else {
            (root_x, root_y)
        };
        self.parent[child] = new_root;
        self.size[new_root] += self.size[child];
        if self.rank[new_root] == self.rank[child] {
            self.rank[new_root] += 1;
        }

        self.components -= 1;
        true
    }

    /// Returns `true` if `x` and `y` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is `>= self.len()`.
    pub fn connected(&mut self, x: usize, y: usize) -> bool {
        self.find(x) == self.find(y)
    }

    /// Returns the current number of disjoint sets.
    #[must_use]
    pub fn component_count(&self) -> usize {
        self.components
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x >= self.len()`.
    pub fn component_size(&mut self, x: usize) -> usize {
        let root = self.find(x);
        self.size[root]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let uf = UnionFind::new(5);
        assert_eq!(uf.len(), 5);
        assert!(!uf.is_empty());
        assert_eq!(uf.component_count(), 5);
    }

    #[test]
    fn test_new_empty() {
        let uf = UnionFind::new(0);
        assert_eq!(uf.len(), 0);
        assert!(uf.is_empty());
        assert_eq!(uf.component_count(), 0);
    }

    #[test]
    fn test_find_initial() {
        let mut uf = UnionFind::new(5);
        for i in 0..5 {
            assert_eq!(uf.find(i), i);
        }
    }

    #[test]
    fn test_union_basic() {
        let mut uf = UnionFind::new(5);
        assert!(uf.union(0, 1));
        assert!(uf.connected(0, 1));
        assert_eq!(uf.component_count(), 4);
    }

    #[test]
    fn test_union_same_set() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        assert!(!uf.union(0, 1));
        assert_eq!(uf.component_count(), 4);
    }

    /// Merging an element with itself is a no-op: it reports `false` and
    /// leaves the component count and sizes unchanged.
    #[test]
    fn test_union_with_self() {
        let mut uf = UnionFind::new(3);
        assert!(!uf.union(1, 1));
        assert_eq!(uf.component_count(), 3);
        assert_eq!(uf.component_size(1), 1);
    }

    #[test]
    fn test_transitivity() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(1, 2);
        assert!(uf.connected(0, 2));
    }

    #[test]
    fn test_not_connected() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(2, 3);
        assert!(!uf.connected(0, 2));
        assert!(!uf.connected(1, 3));
    }

    #[test]
    fn test_merge_components() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(2, 3);
        assert_eq!(uf.component_count(), 3);

        uf.union(1, 3);
        assert_eq!(uf.component_count(), 2);
        assert!(uf.connected(0, 2));
        assert!(uf.connected(0, 3));
    }

    #[test]
    fn test_component_size() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.component_size(0), 1);

        uf.union(0, 1);
        assert_eq!(uf.component_size(0), 2);
        assert_eq!(uf.component_size(1), 2);

        uf.union(0, 2);
        assert_eq!(uf.component_size(0), 3);
        assert_eq!(uf.component_size(2), 3);
    }

    #[test]
    fn test_all_in_one() {
        let mut uf = UnionFind::new(5);
        for i in 0..4 {
            uf.union(i, i + 1);
        }
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.component_size(0), 5);
        for i in 0..5 {
            for j in 0..5 {
                assert!(uf.connected(i, j));
            }
        }
    }

    #[test]
    fn test_single_element() {
        let mut uf = UnionFind::new(1);
        assert_eq!(uf.find(0), 0);
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.component_size(0), 1);
    }

    /// Indices are not bounds-checked up front; an out-of-range element
    /// panics through slice indexing.
    #[test]
    #[should_panic]
    fn test_find_out_of_bounds_panics() {
        let mut uf = UnionFind::new(3);
        uf.find(3);
    }
}

#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a `UnionFind` of `n` elements and applies every `(x, y)` pair
    /// in `ops` as a union, mapping indices into range with `% n`.
    ///
    /// Mapping instead of filtering means every generated operation is
    /// exercised, even when `n` is much smaller than the generator's range.
    /// Callers must pass `n >= 1`.
    fn build(n: usize, ops: &[(usize, usize)]) -> UnionFind {
        let mut uf = UnionFind::new(n);
        for &(x, y) in ops {
            uf.union(x % n, y % n);
        }
        uf
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        #[test]
        fn union_find_transitivity(
            n in 2_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..50),
        ) {
            let mut uf = build(n, &ops);

            for x in 0..n {
                for y in 0..n {
                    for z in 0..n {
                        if uf.connected(x, y) && uf.connected(y, z) {
                            prop_assert!(
                                uf.connected(x, z),
                                "transitivity violated: {x}~{y} and {y}~{z} but not {x}~{z}"
                            );
                        }
                    }
                }
            }
        }

        #[test]
        fn component_count_invariant(
            n in 1_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..50),
        ) {
            let mut uf = UnionFind::new(n);
            let mut expected_components = n;

            for &(x, y) in &ops {
                if uf.union(x % n, y % n) {
                    expected_components -= 1;
                }
            }
            prop_assert_eq!(uf.component_count(), expected_components);

            // Independent check: the count must also equal the number of roots.
            let roots = (0..n).filter(|&i| uf.find(i) == i).count();
            prop_assert_eq!(uf.component_count(), roots);
        }

        #[test]
        fn component_sizes_sum_to_n(
            n in 1_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..30),
        ) {
            let mut uf = build(n, &ops);

            // Each set is counted once, through its root.
            let mut total = 0;
            for i in 0..n {
                if uf.find(i) == i {
                    total += uf.component_size(i);
                }
            }
            prop_assert_eq!(total, n, "component sizes should sum to n");
        }

        #[test]
        fn component_size_matches_membership(
            n in 1_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..30),
        ) {
            let mut uf = build(n, &ops);

            // The reported size must equal the number of elements actually
            // connected to `i`, counted by brute force.
            for i in 0..n {
                let members = (0..n).filter(|&j| uf.connected(i, j)).count();
                prop_assert_eq!(uf.component_size(i), members);
            }
        }
    }
}