scirs2_cluster/hierarchy/
disjoint_set.rs

//! Disjoint Set (Union-Find) data structure for connectivity queries
//!
//! This module provides a disjoint set data structure that efficiently supports
//! union and find operations. It's particularly useful for clustering algorithms
//! that need to track connected components or merge clusters.
//!
//! The implementation uses path compression and union by rank, so each
//! operation runs in nearly O(1) amortized time (inverse Ackermann in the
//! number of elements), not counting the cost of hashing the keys.
9
10use std::collections::HashMap;
11
/// Disjoint Set (Union-Find) data structure
///
/// This data structure maintains a collection of disjoint sets and supports
/// efficient union and find operations. It's commonly used in clustering
/// algorithms for tracking connected components.
///
/// Elements are stored as keys of hash maps, so every operation requires
/// `T: Clone + Hash + Eq`. Those bounds live on the `impl` blocks rather than
/// on the type itself, so generic code that merely stores a `DisjointSet<T>`
/// does not have to repeat them.
///
/// # Examples
///
/// ```
/// use scirs2_cluster::hierarchy::DisjointSet;
///
/// let mut ds = DisjointSet::new();
///
/// // Add some elements
/// ds.make_set(1);
/// ds.make_set(2);
/// ds.make_set(3);
/// ds.make_set(4);
///
/// // Union some sets
/// ds.union(1, 2);
/// ds.union(3, 4);
///
/// // Check connectivity
/// assert_eq!(ds.find(&1), ds.find(&2)); // 1 and 2 are connected
/// assert_eq!(ds.find(&3), ds.find(&4)); // 3 and 4 are connected
/// assert_ne!(ds.find(&1), ds.find(&3)); // 1 and 3 are in different sets
/// ```
#[derive(Debug, Clone)]
pub struct DisjointSet<T> {
    /// Parent pointer of every element; a root is its own parent
    parent: HashMap<T, T>,
    /// Rank (upper bound on tree height) per element; only consulted for roots
    rank: HashMap<T, usize>,
    /// Number of disjoint sets
    num_sets: usize,
}

impl<T: Clone + std::hash::Hash + Eq> DisjointSet<T> {
    /// Create a new empty disjoint set
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let ds: DisjointSet<i32> = DisjointSet::new();
    /// ```
    #[must_use]
    pub fn new() -> Self {
        Self {
            parent: HashMap::new(),
            rank: HashMap::new(),
            num_sets: 0,
        }
    }

    /// Create a new disjoint set with a specified capacity
    ///
    /// This can improve performance when you know approximately how many
    /// elements you'll be adding, because both internal maps are allocated
    /// up front.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Expected number of elements
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            parent: HashMap::with_capacity(capacity),
            rank: HashMap::with_capacity(capacity),
            num_sets: 0,
        }
    }

    /// Add a new element as its own singleton set
    ///
    /// If the element already exists, this operation has no effect.
    ///
    /// # Arguments
    ///
    /// * `x` - Element to add
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(42);
    /// assert!(ds.contains(&42));
    /// ```
    pub fn make_set(&mut self, x: T) {
        if self.parent.contains_key(&x) {
            return;
        }

        // A fresh element is the root of its own tree, with rank zero.
        self.rank.insert(x.clone(), 0);
        self.parent.insert(x.clone(), x);
        self.num_sets += 1;
    }

    /// Find the representative (root) of the set containing the given element
    ///
    /// Uses path compression for optimization: all nodes on the path to the root
    /// are made to point directly to the root. This is why the method takes
    /// `&mut self` even though it is logically a query.
    ///
    /// # Arguments
    ///
    /// * `x` - Element to find the representative for
    ///
    /// # Returns
    ///
    /// * `Some(representative)` if the element exists in the structure
    /// * `None` if the element doesn't exist
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    /// ds.union(1, 2);
    ///
    /// let root1 = ds.find(&1).unwrap();
    /// let root2 = ds.find(&2).unwrap();
    /// assert_eq!(root1, root2); // Same representative
    /// ```
    pub fn find(&mut self, x: &T) -> Option<T> {
        // Unknown elements bail out here, before anything is cloned.
        let mut next = self.parent.get(x)?;
        let mut current = x.clone();
        let mut path = Vec::new();

        // Walk up to the root with one map lookup per hop; every node that is
        // left behind is moved (not cloned) into `path`.
        while *next != current {
            let parent = next.clone();
            path.push(std::mem::replace(&mut current, parent));
            next = &self.parent[&current];
        }

        // Compress: re-point every visited node directly at the root.
        for node in path {
            self.parent.insert(node, current.clone());
        }

        Some(current)
    }

    /// Union two sets containing the given elements
    ///
    /// Uses union by rank: the root of the tree with smaller rank becomes
    /// a child of the root with larger rank. When both ranks are equal the
    /// root of `y`'s tree is attached below the root of `x`'s tree, whose
    /// rank then grows by one.
    ///
    /// # Arguments
    ///
    /// * `x` - Element from first set
    /// * `y` - Element from second set
    ///
    /// # Returns
    ///
    /// * `true` if the sets were successfully unioned (they were different sets)
    /// * `false` if the elements were already in the same set or don't exist
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    ///
    /// assert!(ds.union(1, 2)); // Successfully unioned
    /// assert!(!ds.union(1, 2)); // Already in same set
    /// ```
    pub fn union(&mut self, x: T, y: T) -> bool {
        let root_x = match self.find(&x) {
            Some(root) => root,
            None => return false,
        };
        let root_y = match self.find(&y) {
            Some(root) => root,
            None => return false,
        };

        if root_x == root_y {
            return false; // Already in same set
        }

        let rank_x = self.rank[&root_x];
        let rank_y = self.rank[&root_y];

        // Decide which root is attached below the other.
        let (child, new_root) = match rank_x.cmp(&rank_y) {
            std::cmp::Ordering::Less => (root_x, root_y),
            std::cmp::Ordering::Greater => (root_y, root_x),
            std::cmp::Ordering::Equal => {
                // Equal ranks: the merged tree may be one level taller.
                self.rank.insert(root_x.clone(), rank_x + 1);
                (root_y, root_x)
            }
        };
        self.parent.insert(child, new_root);

        // Two distinct roots existed, so the counter is at least 2 here.
        self.num_sets -= 1;
        true
    }

    /// Check if two elements are in the same set
    ///
    /// Both lookups go through [`find`](Self::find), so the paths of both
    /// elements are compressed as a side effect.
    ///
    /// # Arguments
    ///
    /// * `x` - First element
    /// * `y` - Second element
    ///
    /// # Returns
    ///
    /// * `true` if both elements exist and are in the same set
    /// * `false` if they're in different sets or don't exist
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    /// ds.make_set(3);
    /// ds.union(1, 2);
    ///
    /// assert!(ds.connected(&1, &2)); // Connected
    /// assert!(!ds.connected(&1, &3)); // Not connected
    /// ```
    pub fn connected(&mut self, x: &T, y: &T) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);
        // Two `None`s compare equal, so presence is checked explicitly.
        root_x.is_some() && root_x == root_y
    }

    /// Check if an element exists in the disjoint set
    ///
    /// # Arguments
    ///
    /// * `x` - Element to check
    ///
    /// # Returns
    ///
    /// * `true` if the element exists
    /// * `false` otherwise
    #[must_use]
    pub fn contains(&self, x: &T) -> bool {
        self.parent.contains_key(x)
    }

    /// Get the number of disjoint sets
    ///
    /// # Returns
    ///
    /// The number of disjoint sets currently in the structure
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// assert_eq!(ds.num_sets(), 0);
    ///
    /// ds.make_set(1);
    /// ds.make_set(2);
    /// assert_eq!(ds.num_sets(), 2);
    ///
    /// ds.union(1, 2);
    /// assert_eq!(ds.num_sets(), 1);
    /// ```
    #[must_use]
    pub fn num_sets(&self) -> usize {
        self.num_sets
    }

    /// Get the total number of elements
    ///
    /// # Returns
    ///
    /// The total number of elements in all sets
    #[must_use]
    pub fn size(&self) -> usize {
        self.parent.len()
    }

    /// Check if the disjoint set is empty
    ///
    /// # Returns
    ///
    /// * `true` if no elements have been added
    /// * `false` otherwise
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.parent.is_empty()
    }

    /// Get all elements in the same set as the given element
    ///
    /// Every stored element is visited and resolved through
    /// [`find`](Self::find), so the cost grows with the total number of
    /// elements, not with the size of the requested set. The order of the
    /// returned elements is unspecified.
    ///
    /// # Arguments
    ///
    /// * `x` - Element to find set members for
    ///
    /// # Returns
    ///
    /// * `Some(Vec<T>)` containing all elements in the same set
    /// * `None` if the element doesn't exist
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    /// ds.make_set(3);
    /// ds.union(1, 2);
    ///
    /// let set_members = ds.get_set_members(&1).unwrap();
    /// assert_eq!(set_members.len(), 2);
    /// assert!(set_members.contains(&1));
    /// assert!(set_members.contains(&2));
    /// assert!(!set_members.contains(&3));
    /// ```
    pub fn get_set_members(&mut self, x: &T) -> Option<Vec<T>> {
        let target_root = self.find(x)?;

        // Snapshot the keys first: `find` needs `&mut self`, which cannot
        // coexist with a live iterator over `self.parent`.
        let candidates: Vec<T> = self.parent.keys().cloned().collect();
        let members: Vec<T> = candidates
            .into_iter()
            .filter(|element| self.find(element).as_ref() == Some(&target_root))
            .collect();

        Some(members)
    }

    /// Get all disjoint sets as a vector of vectors
    ///
    /// # Returns
    ///
    /// A vector where each inner vector contains the elements of one set.
    /// Neither the order of the sets nor the order of the elements within a
    /// set is specified.
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    /// ds.make_set(3);
    /// ds.union(1, 2);
    ///
    /// let all_sets = ds.get_all_sets();
    /// assert_eq!(all_sets.len(), 2); // Two disjoint sets
    /// ```
    pub fn get_all_sets(&mut self) -> Vec<Vec<T>> {
        // Snapshot the keys so that `find` can borrow `self` mutably below.
        let elements: Vec<T> = self.parent.keys().cloned().collect();

        // One bucket per root; the number of roots is already known.
        let mut sets_by_root: HashMap<T, Vec<T>> = HashMap::with_capacity(self.num_sets);
        for element in elements {
            if let Some(root) = self.find(&element) {
                sets_by_root.entry(root).or_default().push(element);
            }
        }

        sets_by_root.into_values().collect()
    }

    /// Clear all elements from the disjoint set
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_cluster::hierarchy::DisjointSet;
    /// let mut ds = DisjointSet::new();
    /// ds.make_set(1);
    /// ds.make_set(2);
    ///
    /// assert_eq!(ds.size(), 2);
    /// ds.clear();
    /// assert_eq!(ds.size(), 0);
    /// ```
    pub fn clear(&mut self) {
        self.parent.clear();
        self.rank.clear();
        self.num_sets = 0;
    }
}

impl<T: Clone + std::hash::Hash + Eq> Default for DisjointSet<T> {
    /// Create an empty disjoint set; equivalent to [`DisjointSet::new`]
    fn default() -> Self {
        Self::new()
    }
}
411
#[cfg(test)]
mod tests {
    //! Unit tests for [`DisjointSet`].
    //!
    //! Being a child module, this code may read and write the private fields
    //! of `DisjointSet`; the path-compression test relies on that to set up a
    //! tree shape that the public API never produces.

    use super::*;

    /// Builds a set holding the singletons `1..=n`.
    fn singletons(n: i32) -> DisjointSet<i32> {
        let mut ds = DisjointSet::new();
        for i in 1..=n {
            ds.make_set(i);
        }
        ds
    }

    #[test]
    fn test_basic_operations() {
        let mut ds = DisjointSet::new();

        // Initially empty
        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());

        // Add elements
        ds.make_set(1);
        ds.make_set(2);
        ds.make_set(3);

        assert_eq!(ds.size(), 3);
        assert_eq!(ds.num_sets(), 3);
        assert!(!ds.is_empty());

        // Check individual sets
        assert!(ds.contains(&1));
        assert!(ds.contains(&2));
        assert!(ds.contains(&3));
        assert!(!ds.contains(&4));
    }

    #[test]
    fn test_union_find() {
        let mut ds = singletons(4);

        // Initially all separate
        assert!(!ds.connected(&1, &2));
        assert!(!ds.connected(&3, &4));

        // Union 1 and 2
        assert!(ds.union(1, 2));
        assert_eq!(ds.num_sets(), 3);
        assert!(ds.connected(&1, &2));
        assert!(!ds.connected(&1, &3));

        // Union 3 and 4
        assert!(ds.union(3, 4));
        assert_eq!(ds.num_sets(), 2);
        assert!(ds.connected(&3, &4));
        assert!(!ds.connected(&1, &3));

        // Union the two sets
        assert!(ds.union(1, 3));
        assert_eq!(ds.num_sets(), 1);
        assert!(ds.connected(&1, &3));
        assert!(ds.connected(&2, &4));

        // Redundant union
        assert!(!ds.union(1, 2));
        assert_eq!(ds.num_sets(), 1);
    }

    #[test]
    fn test_transitive_unions() {
        let mut ds = singletons(4);

        // Link the elements pairwise: 1-2, 2-3, 3-4
        assert!(ds.union(1, 2));
        assert!(ds.union(2, 3));
        assert!(ds.union(3, 4));
        assert_eq!(ds.num_sets(), 1);

        // Every element must report the same representative
        let root1 = ds.find(&1).unwrap();
        let root2 = ds.find(&2).unwrap();
        let root3 = ds.find(&3).unwrap();
        let root4 = ds.find(&4).unwrap();

        assert_eq!(root1, root2);
        assert_eq!(root2, root3);
        assert_eq!(root3, root4);
    }

    #[test]
    fn test_path_compression() {
        let mut ds = singletons(4);

        // Union by rank never yields a chain, so build 1 -> 2 -> 3 -> 4 by
        // writing the parent pointers directly.
        ds.parent.insert(1, 2);
        ds.parent.insert(2, 3);
        ds.parent.insert(3, 4);
        ds.num_sets = 1;

        // A single find from the deepest node returns the root ...
        assert_eq!(ds.find(&1), Some(4));

        // ... and leaves every node on the path pointing straight at it.
        for node in 1..=4 {
            assert_eq!(ds.parent[&node], 4);
        }
        assert_eq!(ds.num_sets(), 1);
    }

    #[test]
    fn test_get_set_members() {
        let mut ds = singletons(4);

        ds.union(1, 2);
        ds.union(3, 4);

        let members1 = ds.get_set_members(&1).unwrap();
        assert_eq!(members1.len(), 2);
        assert!(members1.contains(&1));
        assert!(members1.contains(&2));

        let members3 = ds.get_set_members(&3).unwrap();
        assert_eq!(members3.len(), 2);
        assert!(members3.contains(&3));
        assert!(members3.contains(&4));

        // Non-existent element
        assert!(ds.get_set_members(&5).is_none());
    }

    #[test]
    fn test_get_all_sets() {
        let mut ds = singletons(5);

        ds.union(1, 2);
        ds.union(3, 4);
        // 5 remains alone

        let all_sets = ds.get_all_sets();
        assert_eq!(all_sets.len(), 3); // Three disjoint sets
        assert_eq!(all_sets.len(), ds.num_sets());

        // The order of the sets is unspecified, so compare sorted sizes
        let mut set_sizes: Vec<usize> = all_sets.iter().map(Vec::len).collect();
        set_sizes.sort_unstable();
        assert_eq!(set_sizes, vec![1, 2, 2]);
    }

    #[test]
    fn test_edge_cases() {
        let mut ds = DisjointSet::new();

        // Union with non-existent elements
        assert!(!ds.union(1, 2));
        assert_eq!(ds.num_sets(), 0);

        // Find non-existent element
        assert!(ds.find(&1).is_none());

        // Connected with non-existent elements
        assert!(!ds.connected(&1, &2));

        // Add same element twice
        ds.make_set(1);
        ds.make_set(1); // Should have no effect
        assert_eq!(ds.size(), 1);
        assert_eq!(ds.num_sets(), 1);

        // Union where only one side exists must not change anything
        assert!(!ds.union(1, 2));
        assert!(!ds.union(2, 1));
        assert_eq!(ds.num_sets(), 1);
        assert_eq!(ds.find(&1), Some(1));
    }

    #[test]
    fn test_clear() {
        let mut ds = DisjointSet::new();
        ds.make_set(1);
        ds.make_set(2);
        ds.union(1, 2);

        assert_eq!(ds.size(), 2);
        assert_eq!(ds.num_sets(), 1);

        ds.clear();

        assert_eq!(ds.size(), 0);
        assert_eq!(ds.num_sets(), 0);
        assert!(ds.is_empty());
        assert!(!ds.contains(&1));
    }

    #[test]
    fn test_with_strings() {
        let mut ds = DisjointSet::new();
        ds.make_set("alice".to_string());
        ds.make_set("bob".to_string());
        ds.make_set("charlie".to_string());

        assert!(ds.union("alice".to_string(), "bob".to_string()));

        assert!(ds.connected(&"alice".to_string(), &"bob".to_string()));
        assert!(!ds.connected(&"alice".to_string(), &"charlie".to_string()));
    }

    #[test]
    fn test_large_dataset() {
        let mut ds = DisjointSet::with_capacity(1000);

        // Add many elements
        for i in 0..1000 {
            ds.make_set(i);
        }

        assert_eq!(ds.size(), 1000);
        assert_eq!(ds.num_sets(), 1000);

        // Union them in pairs
        for i in (0..1000).step_by(2) {
            assert!(ds.union(i, i + 1));
        }

        assert_eq!(ds.num_sets(), 500);

        // Check some connections
        assert!(ds.connected(&0, &1));
        assert!(ds.connected(&998, &999));
        assert!(!ds.connected(&0, &2));
    }
}