scirs2_cluster/hierarchy/disjoint_set.rs
//! Disjoint Set (Union-Find) data structure for connectivity queries
//!
//! This module provides a disjoint set data structure that efficiently supports
//! union and find operations. It is particularly useful for clustering algorithms
//! that need to track connected components or merge clusters.
//!
//! The implementation uses path compression and union by rank, which together
//! give near-constant amortized time per `find`/`union` call (expected time,
//! since elements are stored in hash maps).
9
10use std::collections::HashMap;
11
/// Disjoint Set (Union-Find) data structure
///
/// Maintains a partition of the inserted elements into disjoint sets. Two sets
/// can be merged with `union`, and `find` returns the representative element of
/// the set an element belongs to. It is commonly used in clustering algorithms
/// for tracking connected components.
///
/// The element type must be `Clone + Hash + Eq` to use the methods; those
/// bounds are stated on the `impl` blocks rather than on the struct itself.
///
/// # Examples
///
/// ```
/// use scirs2_cluster::hierarchy::DisjointSet;
///
/// let mut ds = DisjointSet::new();
///
/// // Add some elements
/// ds.make_set(1);
/// ds.make_set(2);
/// ds.make_set(3);
/// ds.make_set(4);
///
/// // Union some sets
/// ds.union(1, 2);
/// ds.union(3, 4);
///
/// // Check connectivity
/// assert_eq!(ds.find(&1), ds.find(&2)); // 1 and 2 are connected
/// assert_eq!(ds.find(&3), ds.find(&4)); // 3 and 4 are connected
/// assert_ne!(ds.find(&1), ds.find(&3)); // 1 and 3 are in different sets
/// ```
#[derive(Debug, Clone)]
pub struct DisjointSet<T> {
    /// Parent pointer of each element; a root is its own parent
    parent: HashMap<T, T>,
    /// Rank of each element's tree; only the ranks of roots are consulted
    rank: HashMap<T, usize>,
    /// Number of disjoint sets
    num_sets: usize,
}
49
50impl<T: Clone + std::hash::Hash + Eq> DisjointSet<T> {
51 /// Create a new empty disjoint set
52 ///
53 /// # Examples
54 ///
55 /// ```
56 /// use scirs2_cluster::hierarchy::DisjointSet;
57 /// let ds: DisjointSet<i32> = DisjointSet::new();
58 /// ```
59 pub fn new() -> Self {
60 Self {
61 parent: HashMap::new(),
62 rank: HashMap::new(),
63 num_sets: 0,
64 }
65 }
66
67 /// Create a new disjoint set with a specified capacity
68 ///
69 /// This can improve performance when you know approximately how many
70 /// elements you'll be adding.
71 ///
72 /// # Arguments
73 ///
74 /// * `capacity` - Expected number of elements
75 pub fn with_capacity(capacity: usize) -> Self {
76 Self {
77 parent: HashMap::with_capacity(capacity),
78 rank: HashMap::with_capacity(capacity),
79 num_sets: 0,
80 }
81 }
82
83 /// Add a new element as its own singleton set
84 ///
85 /// If the element already exists, this operation has no effect.
86 ///
87 /// # Arguments
88 ///
89 /// * `x` - Element to add
90 ///
91 /// # Examples
92 ///
93 /// ```
94 /// use scirs2_cluster::hierarchy::DisjointSet;
95 /// let mut ds = DisjointSet::new();
96 /// ds.make_set(42);
97 /// assert!(ds.contains(&42));
98 /// ```
99 pub fn make_set(&mut self, x: T) {
100 if !self.parent.contains_key(&x) {
101 self.parent.insert(x.clone(), x.clone());
102 self.rank.insert(x, 0);
103 self.num_sets += 1;
104 }
105 }
106
107 /// Find the representative (root) of the set containing the given element
108 ///
109 /// Uses path compression for optimization: all nodes on the path to the root
110 /// are made to point directly to the root.
111 ///
112 /// # Arguments
113 ///
114 /// * `x` - Element to find the representative for
115 ///
116 /// # Returns
117 ///
118 /// * `Some(representative)` if the element exists in the structure
119 /// * `None` if the element doesn't exist
120 ///
121 /// # Examples
122 ///
123 /// ```
124 /// use scirs2_cluster::hierarchy::DisjointSet;
125 /// let mut ds = DisjointSet::new();
126 /// ds.make_set(1);
127 /// ds.make_set(2);
128 /// ds.union(1, 2);
129 ///
130 /// let root1 = ds.find(&1).unwrap();
131 /// let root2 = ds.find(&2).unwrap();
132 /// assert_eq!(root1, root2); // Same representative
133 /// ```
134 pub fn find(&mut self, x: &T) -> Option<T> {
135 if !self.parent.contains_key(x) {
136 return None;
137 }
138
139 // Path compression: make all nodes on path point to root
140 let mut current = x.clone();
141 let mut path = Vec::new();
142
143 // Find root
144 while self.parent[¤t] != current {
145 path.push(current.clone());
146 current = self.parent[¤t].clone();
147 }
148
149 // Compress path
150 for node in path {
151 self.parent.insert(node, current.clone());
152 }
153
154 Some(current)
155 }
156
157 /// Union two sets containing the given elements
158 ///
159 /// Uses union by rank: the root of the tree with smaller rank becomes
160 /// a child of the root with larger rank.
161 ///
162 /// # Arguments
163 ///
164 /// * `x` - Element from first set
165 /// * `y` - Element from second set
166 ///
167 /// # Returns
168 ///
169 /// * `true` if the sets were successfully unioned (they were different sets)
170 /// * `false` if the elements were already in the same set or don't exist
171 ///
172 /// # Examples
173 ///
174 /// ```
175 /// use scirs2_cluster::hierarchy::DisjointSet;
176 /// let mut ds = DisjointSet::new();
177 /// ds.make_set(1);
178 /// ds.make_set(2);
179 ///
180 /// assert!(ds.union(1, 2)); // Successfully unioned
181 /// assert!(!ds.union(1, 2)); // Already in same set
182 /// ```
183 pub fn union(&mut self, x: T, y: T) -> bool {
184 let root_x = match self.find(&x) {
185 Some(root) => root,
186 None => return false,
187 };
188
189 let root_y = match self.find(&y) {
190 Some(root) => root,
191 None => return false,
192 };
193
194 if root_x == root_y {
195 return false; // Already in same set
196 }
197
198 // Union by rank
199 let rank_x = self.rank[&root_x];
200 let rank_y = self.rank[&root_y];
201
202 match rank_x.cmp(&rank_y) {
203 std::cmp::Ordering::Less => {
204 self.parent.insert(root_x, root_y);
205 }
206 std::cmp::Ordering::Greater => {
207 self.parent.insert(root_y, root_x);
208 }
209 std::cmp::Ordering::Equal => {
210 // Same rank, make one root and increase its rank
211 self.parent.insert(root_y, root_x.clone());
212 self.rank.insert(root_x, rank_x + 1);
213 }
214 }
215
216 self.num_sets -= 1;
217 true
218 }
219
220 /// Check if two elements are in the same set
221 ///
222 /// # Arguments
223 ///
224 /// * `x` - First element
225 /// * `y` - Second element
226 ///
227 /// # Returns
228 ///
229 /// * `true` if both elements exist and are in the same set
230 /// * `false` if they're in different sets or don't exist
231 ///
232 /// # Examples
233 ///
234 /// ```
235 /// use scirs2_cluster::hierarchy::DisjointSet;
236 /// let mut ds = DisjointSet::new();
237 /// ds.make_set(1);
238 /// ds.make_set(2);
239 /// ds.make_set(3);
240 /// ds.union(1, 2);
241 ///
242 /// assert!(ds.connected(&1, &2)); // Connected
243 /// assert!(!ds.connected(&1, &3)); // Not connected
244 /// ```
245 pub fn connected(&mut self, x: &T, y: &T) -> bool {
246 match (self.find(x), self.find(y)) {
247 (Some(root_x), Some(root_y)) => root_x == root_y,
248 _ => false,
249 }
250 }
251
252 /// Check if an element exists in the disjoint set
253 ///
254 /// # Arguments
255 ///
256 /// * `x` - Element to check
257 ///
258 /// # Returns
259 ///
260 /// * `true` if the element exists
261 /// * `false` otherwise
262 pub fn contains(&self, x: &T) -> bool {
263 self.parent.contains_key(x)
264 }
265
266 /// Get the number of disjoint sets
267 ///
268 /// # Returns
269 ///
270 /// The number of disjoint sets currently in the structure
271 ///
272 /// # Examples
273 ///
274 /// ```
275 /// use scirs2_cluster::hierarchy::DisjointSet;
276 /// let mut ds = DisjointSet::new();
277 /// assert_eq!(ds.num_sets(), 0);
278 ///
279 /// ds.make_set(1);
280 /// ds.make_set(2);
281 /// assert_eq!(ds.num_sets(), 2);
282 ///
283 /// ds.union(1, 2);
284 /// assert_eq!(ds.num_sets(), 1);
285 /// ```
286 pub fn num_sets(&self) -> usize {
287 self.num_sets
288 }
289
290 /// Get the total number of elements
291 ///
292 /// # Returns
293 ///
294 /// The total number of elements in all sets
295 pub fn size(&self) -> usize {
296 self.parent.len()
297 }
298
299 /// Check if the disjoint set is empty
300 ///
301 /// # Returns
302 ///
303 /// * `true` if no elements have been added
304 /// * `false` otherwise
305 pub fn is_empty(&self) -> bool {
306 self.parent.is_empty()
307 }
308
309 /// Get all elements in the same set as the given element
310 ///
311 /// # Arguments
312 ///
313 /// * `x` - Element to find set members for
314 ///
315 /// # Returns
316 ///
317 /// * `Some(Vec<T>)` containing all elements in the same set
318 /// * `None` if the element doesn't exist
319 ///
320 /// # Examples
321 ///
322 /// ```
323 /// use scirs2_cluster::hierarchy::DisjointSet;
324 /// let mut ds = DisjointSet::new();
325 /// ds.make_set(1);
326 /// ds.make_set(2);
327 /// ds.make_set(3);
328 /// ds.union(1, 2);
329 ///
330 /// let set_members = ds.get_set_members(&1).unwrap();
331 /// assert_eq!(set_members.len(), 2);
332 /// assert!(set_members.contains(&1));
333 /// assert!(set_members.contains(&2));
334 /// assert!(!set_members.contains(&3));
335 /// ```
336 pub fn get_set_members(&mut self, x: &T) -> Option<Vec<T>> {
337 let target_root = self.find(x)?;
338
339 let mut members = Vec::new();
340 let elements_to_check: Vec<T> = self.parent.keys().cloned().collect();
341
342 for element in elements_to_check {
343 if let Some(root) = self.find(&element) {
344 if root == target_root {
345 members.push(element);
346 }
347 }
348 }
349
350 Some(members)
351 }
352
353 /// Get all disjoint sets as a vector of vectors
354 ///
355 /// # Returns
356 ///
357 /// A vector where each inner vector contains the elements of one set
358 ///
359 /// # Examples
360 ///
361 /// ```
362 /// use scirs2_cluster::hierarchy::DisjointSet;
363 /// let mut ds = DisjointSet::new();
364 /// ds.make_set(1);
365 /// ds.make_set(2);
366 /// ds.make_set(3);
367 /// ds.union(1, 2);
368 ///
369 /// let all_sets = ds.get_all_sets();
370 /// assert_eq!(all_sets.len(), 2); // Two disjoint sets
371 /// ```
372 pub fn get_all_sets(&mut self) -> Vec<Vec<T>> {
373 let mut sets_map: HashMap<T, Vec<T>> = HashMap::new();
374
375 // Group elements by their root
376 for element in self.parent.keys().cloned().collect::<Vec<_>>() {
377 if let Some(root) = self.find(&element) {
378 sets_map.entry(root).or_default().push(element);
379 }
380 }
381
382 sets_map.into_values().collect()
383 }
384
385 /// Clear all elements from the disjoint set
386 ///
387 /// # Examples
388 ///
389 /// ```
390 /// use scirs2_cluster::hierarchy::DisjointSet;
391 /// let mut ds = DisjointSet::new();
392 /// ds.make_set(1);
393 /// ds.make_set(2);
394 ///
395 /// assert_eq!(ds.size(), 2);
396 /// ds.clear();
397 /// assert_eq!(ds.size(), 0);
398 /// ```
399 pub fn clear(&mut self) {
400 self.parent.clear();
401 self.rank.clear();
402 self.num_sets = 0;
403 }
404}
405
406impl<T: Clone + std::hash::Hash + Eq> Default for DisjointSet<T> {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Size, set count and membership queries on a freshly populated structure.
    #[test]
    fn test_basic_operations() {
        let mut ds = DisjointSet::new();

        // Initially empty
        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());

        // Add elements
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);

        assert_eq!(ds.size(), 3);
        assert_eq!(ds.num_sets(), 3);
        assert!(!ds.is_empty());

        // Check individual sets
        assert!(ds.contains(&1));
        assert!(ds.contains(&2));
        assert!(ds.contains(&3));
        assert!(!ds.contains(&4));
    }

    /// Unions merge sets, update the set count and report redundant merges.
    #[test]
    fn test_union_find() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);

        // Initially all separate
        assert!(!ds.connected(&1, &2));
        assert!(!ds.connected(&3, &4));

        // Union 1 and 2
        assert!(ds.union(1, 2));
        assert_eq!(ds.num_sets(), 3);
        assert!(ds.connected(&1, &2));
        assert!(!ds.connected(&1, &3));

        // Union 3 and 4
        assert!(ds.union(3, 4));
        assert_eq!(ds.num_sets(), 2);
        assert!(ds.connected(&3, &4));
        assert!(!ds.connected(&1, &3));

        // Union the two sets
        assert!(ds.union(1, 3));
        assert_eq!(ds.num_sets(), 1);
        assert!(ds.connected(&1, &3));
        assert!(ds.connected(&2, &4));

        // Redundant union
        assert!(!ds.union(1, 2));
        assert_eq!(ds.num_sets(), 1);
    }

    /// After a sequence of unions every element resolves to the same root.
    #[test]
    fn test_path_compression() {
        let mut ds = DisjointSet::new();

        // Link 1-2, 2-3 and 3-4. Union by rank decides the actual tree shape,
        // so this is a single connected set rather than a literal chain.
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);

        ds.union(1, 2);
        ds.union(2, 3);
        ds.union(3, 4);
        assert_eq!(ds.num_sets(), 1);

        // Every element must report the same representative
        let root1 = ds.find(&1).unwrap();
        let root2 = ds.find(&2).unwrap();
        let root3 = ds.find(&3).unwrap();
        let root4 = ds.find(&4).unwrap();

        assert_eq!(root1, root2);
        assert_eq!(root2, root3);
        assert_eq!(root3, root4);

        // Repeating the query after compression gives the same answer
        assert_eq!(ds.find(&1), Some(root1));
    }

    /// `get_set_members` returns exactly the elements of the queried set.
    #[test]
    fn test_get_set_members() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);

        ds.union(1, 2);
        ds.union(3, 4);

        let members1 = ds.get_set_members(&1).unwrap();
        assert_eq!(members1.len(), 2);
        assert!(members1.contains(&1));
        assert!(members1.contains(&2));

        let members3 = ds.get_set_members(&3).unwrap();
        assert_eq!(members3.len(), 2);
        assert!(members3.contains(&3));
        assert!(members3.contains(&4));

        // Non-existent element
        assert!(ds.get_set_members(&5).is_none());
    }

    /// `get_all_sets` partitions every element into the expected groups.
    #[test]
    fn test_get_all_sets() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);
        ds.make_set(4);
        ds.make_set(5);

        ds.union(1, 2);
        ds.union(3, 4);
        // 5 remains alone

        let all_sets = ds.get_all_sets();
        assert_eq!(all_sets.len(), 3); // Three disjoint sets

        // The order of the sets is unspecified, so compare sorted sizes
        let mut set_sizes: Vec<usize> = all_sets.iter().map(|s| s.len()).collect();
        set_sizes.sort_unstable();
        assert_eq!(set_sizes, vec![1, 2, 2]);

        // Every element appears in exactly one set
        assert_eq!(set_sizes.iter().sum::<usize>(), ds.size());
    }

    /// Operations on missing elements and duplicate insertions are harmless.
    #[test]
    fn test_edge_cases() {
        let mut ds = DisjointSet::new();

        // Union with non-existent elements
        assert!(!ds.union(1, 2));

        // Find non-existent element
        assert!(ds.find(&1).is_none());

        // Connected with non-existent elements
        assert!(!ds.connected(&1, &2));

        // Add same element twice
        ds.make_set(1);
        ds.make_set(1); // Should have no effect
        assert_eq!(ds.size(), 1);
        assert_eq!(ds.num_sets(), 1);

        // Union of an element with itself, or with a missing one, changes nothing
        assert!(!ds.union(1, 1));
        assert!(!ds.union(1, 2));
        assert_eq!(ds.num_sets(), 1);
    }

    /// `clear` resets both the elements and the set count.
    #[test]
    fn test_clear() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.union(1, 2);

        assert_eq!(ds.size(), 2);
        assert_eq!(ds.num_sets(), 1);

        ds.clear();

        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());
    }

    /// The structure works with non-`Copy` element types such as `String`.
    #[test]
    fn test_with_strings() {
        let mut ds = DisjointSet::new();
        ds.make_set("alice".to_string());
        ds.make_set("bob".to_string());
        ds.make_set("charlie".to_string());

        ds.union("alice".to_string(), "bob".to_string());

        assert!(ds.connected(&"alice".to_string(), &"bob".to_string()));
        assert!(!ds.connected(&"alice".to_string(), &"charlie".to_string()));
    }

    /// Pairwise unions over 1000 elements leave exactly 500 sets.
    #[test]
    fn test_large_dataset() {
        let mut ds = DisjointSet::with_capacity(1000);

        // Add many elements
        for i in 0..1000 {
            ds.make_set(i);
        }

        assert_eq!(ds.size(), 1000);
        assert_eq!(ds.num_sets(), 1000);

        // Union them in pairs
        for i in (0..1000).step_by(2) {
            ds.union(i, i + 1);
        }

        assert_eq!(ds.num_sets(), 500);

        // Check some connections
        assert!(ds.connected(&0, &1));
        assert!(ds.connected(&998, &999));
        assert!(!ds.connected(&0, &2));
    }
}
622}