Skip to main content

u_numflow/collections/
union_find.rs

//! Disjoint-set (Union-Find) data structure.
//!
//! Maintains a collection of disjoint sets over elements `0..n` with
//! near-constant-time union and find operations.
//!
//! # Algorithm
//!
//! Uses **path compression** during `find` and **union by rank** during
//! `union` to achieve amortized O(α(n)) per operation, where α is the
//! inverse Ackermann function.
//!
//! For all practical input sizes (n < 2^65536), α(n) ≤ 4, so operations
//! are effectively O(1).
//!
//! # References
//!
//! - Tarjan (1975), "Efficiency of a Good but Not Linear Set Union Algorithm"
//! - Tarjan & van Leeuwen (1984), "Worst-Case Analysis of Set Union Algorithms"
19
/// A disjoint-set forest over the elements `0..n`, combining path
/// compression (applied in `find`) with union by rank (applied in `union`).
///
/// # Examples
/// ```
/// use u_numflow::collections::UnionFind;
///
/// let mut sets = UnionFind::new(5);
/// assert_eq!(sets.component_count(), 5);
///
/// sets.union(0, 1);
/// sets.union(2, 3);
/// assert_eq!(sets.component_count(), 3);
///
/// assert!(sets.connected(0, 1));
/// assert!(!sets.connected(0, 2));
///
/// // Joining 1 and 3 links {0, 1} with {2, 3}, so 0 and 2 become connected.
/// sets.union(1, 3);
/// assert!(sets.connected(0, 2));
/// assert_eq!(sets.component_count(), 2);
/// ```
#[derive(Clone, Debug)]
pub struct UnionFind {
    /// Parent link of each element; a root is its own parent.
    parent: Vec<usize>,
    /// Union-by-rank rank of each element; only consulted at roots.
    rank: Vec<u8>,
    /// Number of elements in the set rooted at each index; only meaningful at roots.
    size: Vec<usize>,
    /// Current number of disjoint sets.
    components: usize,
}
47
48impl UnionFind {
49    /// Creates a new Union-Find with `n` disjoint singleton sets `{0}, {1}, ..., {n-1}`.
50    ///
51    /// # Complexity
52    /// O(n)
53    pub fn new(n: usize) -> Self {
54        Self {
55            parent: (0..n).collect(),
56            rank: vec![0; n],
57            size: vec![1; n],
58            components: n,
59        }
60    }
61
62    /// Returns the number of elements.
63    pub fn len(&self) -> usize {
64        self.parent.len()
65    }
66
67    /// Returns `true` if there are no elements.
68    pub fn is_empty(&self) -> bool {
69        self.parent.is_empty()
70    }
71
72    /// Finds the representative (root) of the set containing `x`.
73    ///
74    /// Applies **path compression**: every node on the path from `x` to
75    /// the root is made a direct child of the root.
76    ///
77    /// # Complexity
78    /// Amortized O(α(n))
79    ///
80    /// # Panics
81    /// Panics if `x >= len()`.
82    pub fn find(&mut self, x: usize) -> usize {
83        if self.parent[x] != x {
84            self.parent[x] = self.find(self.parent[x]);
85        }
86        self.parent[x]
87    }
88
89    /// Merges the sets containing `x` and `y`.
90    ///
91    /// Uses **union by rank**: the tree with smaller rank is attached
92    /// under the root of the tree with larger rank.
93    ///
94    /// # Returns
95    /// `true` if `x` and `y` were in different sets (and are now merged),
96    /// `false` if they were already in the same set.
97    ///
98    /// # Complexity
99    /// Amortized O(α(n))
100    ///
101    /// # Panics
102    /// Panics if `x >= len()` or `y >= len()`.
103    pub fn union(&mut self, x: usize, y: usize) -> bool {
104        let root_x = self.find(x);
105        let root_y = self.find(y);
106
107        if root_x == root_y {
108            return false;
109        }
110
111        // Union by rank
112        match self.rank[root_x].cmp(&self.rank[root_y]) {
113            std::cmp::Ordering::Less => {
114                self.parent[root_x] = root_y;
115                self.size[root_y] += self.size[root_x];
116            }
117            std::cmp::Ordering::Greater => {
118                self.parent[root_y] = root_x;
119                self.size[root_x] += self.size[root_y];
120            }
121            std::cmp::Ordering::Equal => {
122                self.parent[root_y] = root_x;
123                self.size[root_x] += self.size[root_y];
124                self.rank[root_x] += 1;
125            }
126        }
127
128        self.components -= 1;
129        true
130    }
131
132    /// Returns `true` if `x` and `y` are in the same set.
133    ///
134    /// # Complexity
135    /// Amortized O(α(n))
136    pub fn connected(&mut self, x: usize, y: usize) -> bool {
137        self.find(x) == self.find(y)
138    }
139
140    /// Returns the number of disjoint sets.
141    ///
142    /// # Complexity
143    /// O(1)
144    pub fn component_count(&self) -> usize {
145        self.components
146    }
147
148    /// Returns the size of the set containing `x`.
149    ///
150    /// # Complexity
151    /// Amortized O(α(n))
152    pub fn component_size(&mut self, x: usize) -> usize {
153        let root = self.find(x);
154        self.size[root]
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let uf = UnionFind::new(5);
        assert_eq!(uf.len(), 5);
        assert_eq!(uf.component_count(), 5);
    }

    #[test]
    fn test_new_empty() {
        let uf = UnionFind::new(0);
        assert_eq!(uf.len(), 0);
        assert!(uf.is_empty());
        assert_eq!(uf.component_count(), 0);
    }

    #[test]
    fn test_find_initial() {
        let mut uf = UnionFind::new(5);
        for i in 0..5 {
            assert_eq!(uf.find(i), i);
        }
    }

    #[test]
    fn test_union_basic() {
        let mut uf = UnionFind::new(5);
        assert!(uf.union(0, 1));
        assert!(uf.connected(0, 1));
        assert_eq!(uf.component_count(), 4);
    }

    #[test]
    fn test_union_same_set() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        assert!(!uf.union(0, 1)); // already same set
        assert_eq!(uf.component_count(), 4);
    }

    /// Uniting an element with itself is a no-op that reports `false`.
    #[test]
    fn test_union_self_is_noop() {
        let mut uf = UnionFind::new(3);
        assert!(!uf.union(1, 1));
        assert_eq!(uf.component_count(), 3);
        assert_eq!(uf.component_size(1), 1);
    }

    #[test]
    fn test_transitivity() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(1, 2);
        assert!(uf.connected(0, 2));
    }

    #[test]
    fn test_not_connected() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(2, 3);
        assert!(!uf.connected(0, 2));
        assert!(!uf.connected(1, 3));
    }

    #[test]
    fn test_merge_components() {
        let mut uf = UnionFind::new(5);
        uf.union(0, 1);
        uf.union(2, 3);
        assert_eq!(uf.component_count(), 3);

        uf.union(1, 3); // merge two components
        assert_eq!(uf.component_count(), 2);
        assert!(uf.connected(0, 2));
        assert!(uf.connected(0, 3));
    }

    #[test]
    fn test_component_size() {
        let mut uf = UnionFind::new(5);
        assert_eq!(uf.component_size(0), 1);

        uf.union(0, 1);
        assert_eq!(uf.component_size(0), 2);
        assert_eq!(uf.component_size(1), 2);

        uf.union(0, 2);
        assert_eq!(uf.component_size(0), 3);
        assert_eq!(uf.component_size(2), 3);
    }

    #[test]
    fn test_all_in_one() {
        let mut uf = UnionFind::new(5);
        for i in 0..4 {
            uf.union(i, i + 1);
        }
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.component_size(0), 5);
        for i in 0..5 {
            for j in 0..5 {
                assert!(uf.connected(i, j));
            }
        }
    }

    /// A long sequence of adjacent unions collapses everything into one set.
    #[test]
    fn test_long_chain() {
        let n = 10_000;
        let mut uf = UnionFind::new(n);
        for i in 0..n - 1 {
            assert!(uf.union(i, i + 1));
        }
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.component_size(n - 1), n);
        assert!(uf.connected(0, n - 1));
    }

    #[test]
    fn test_single_element() {
        let mut uf = UnionFind::new(1);
        assert_eq!(uf.find(0), 0);
        assert_eq!(uf.component_count(), 1);
        assert_eq!(uf.component_size(0), 1);
    }

    /// `find` documents a panic for `x >= len()`.
    #[test]
    #[should_panic]
    fn test_find_out_of_bounds_panics() {
        let mut uf = UnionFind::new(3);
        uf.find(3);
    }

    /// `union` documents a panic when either argument is out of range.
    #[test]
    #[should_panic]
    fn test_union_out_of_bounds_panics() {
        let mut uf = UnionFind::new(3);
        uf.union(0, 3);
    }
}
268
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        #[test]
        fn union_find_transitivity(
            n in 2_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..50),
        ) {
            let mut uf = UnionFind::new(n);
            // Reduce indices modulo `n` so every generated op is exercised,
            // rather than discarding the (often majority) out-of-range pairs.
            for &(x, y) in &ops {
                uf.union(x % n, y % n);
            }

            // Verify transitivity
            for x in 0..n {
                for y in 0..n {
                    for z in 0..n {
                        if uf.connected(x, y) && uf.connected(y, z) {
                            prop_assert!(
                                uf.connected(x, z),
                                "transitivity violated: {x}~{y} and {y}~{z} but not {x}~{z}"
                            );
                        }
                    }
                }
            }
        }

        #[test]
        fn component_count_invariant(
            n in 1_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..50),
        ) {
            let mut uf = UnionFind::new(n);
            // Naive reference model: `labels[i]` names the set containing `i`;
            // a merge relabels every member of one set in O(n).
            let mut labels: Vec<usize> = (0..n).collect();
            let mut expected_components = n;

            for &(x, y) in &ops {
                let (x, y) = (x % n, y % n);
                let (label_x, label_y) = (labels[x], labels[y]);
                let merged = uf.union(x, y);
                prop_assert_eq!(
                    merged,
                    label_x != label_y,
                    "union({}, {}) return value disagrees with the model",
                    x,
                    y
                );
                if label_x != label_y {
                    for label in labels.iter_mut() {
                        if *label == label_y {
                            *label = label_x;
                        }
                    }
                    expected_components -= 1;
                }
            }

            prop_assert_eq!(uf.component_count(), expected_components);
            for x in 0..n {
                for y in 0..n {
                    prop_assert_eq!(uf.connected(x, y), labels[x] == labels[y]);
                }
                let model_size = labels.iter().filter(|&&l| l == labels[x]).count();
                prop_assert_eq!(uf.component_size(x), model_size);
            }
        }

        #[test]
        fn component_sizes_sum_to_n(
            n in 1_usize..20,
            ops in proptest::collection::vec((0_usize..20, 0_usize..20), 0..30),
        ) {
            let mut uf = UnionFind::new(n);
            for &(x, y) in &ops {
                uf.union(x % n, y % n);
            }

            // Sum of sizes of all roots should equal n
            let mut total = 0;
            for i in 0..n {
                if uf.find(i) == i {
                    total += uf.component_size(i);
                }
            }
            prop_assert_eq!(total, n, "component sizes should sum to n");
        }
    }
}