yui_core/misc/
union_find.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::rc::Rc;
4
5use itertools::Itertools;
6
/// A union-find (disjoint-set) structure over the elements `0..size()`.
///
/// `p[i]` is the parent of element `i`; an element with `p[i] == i` is the root
/// (representative) of its set.
#[derive(Debug, Clone, Default)]
pub struct UnionFind {
    p: Vec<usize>,
}
10
11impl UnionFind { 
12    pub fn new(n: usize) -> Self { 
13        Self { p: (0..n).collect() }
14    }
15
16    pub fn extend(&mut self, l: usize) { 
17        let n = self.p.len();
18        self.p.extend(n .. n + l);
19    }
20
21    pub fn size(&self) -> usize { 
22        self.p.len()
23    }
24
25    pub fn root(&self, i: usize) -> usize { 
26        let p = self.p[i];
27        if p == i { 
28            i
29        } else { 
30            self.root(p)
31        }
32    }
33    
34    pub fn is_same(&self, i: usize, j: usize) -> bool { 
35        self.root(i) == self.root(j)
36    }
37
38    pub fn union(&mut self, i: usize, j: usize) {
39        use std::cmp::Ordering::*;
40        let ri = self.root(i);
41        let rj = self.root(j);
42
43        match usize::cmp(&ri, &rj) {
44            Less    => self.p[rj] = ri,
45            Equal   => (),
46            Greater => self.p[ri] = rj,
47        }
48    }
49
50    pub fn group(&self) -> Vec<Vec<usize>> { 
51        let n = self.size();
52        (0..n).into_group_map_by(|&i| self.root(i)).into_iter().sorted_by_key(|&(i, _)| i).map(|(_, l)| l).collect()
53    }
54}
55
56pub struct KeyedUnionFind<X> where X: Eq + Hash { 
57    inner: UnionFind,
58    keys: Vec<Rc<X>>,
59    dict: HashMap<Rc<X>, usize>
60}
61
62impl<X> KeyedUnionFind<X> where X: Eq + Hash { 
63    pub fn new() -> Self { 
64        Self { inner: UnionFind::new(0), keys: vec![], dict: HashMap::new() }
65    }
66
67    pub fn insert(&mut self, x: X) -> usize { 
68        let i = self.size();
69        let x = Rc::new(x);
70
71        self.inner.extend(1);
72        self.keys.push(Rc::clone(&x));
73        self.dict.insert(x, i);
74
75        i
76    }
77
78    fn index_of(&self, x: &X) -> usize { 
79        self.dict[x]
80    }
81
82    fn element_at(&self, i: usize) -> &X { 
83        &self.keys[i]
84    }
85
86    pub fn size(&self) -> usize { 
87        self.inner.size()
88    }
89
90    pub fn contains(&self, x: &X) -> bool { 
91        self.dict.contains_key(x)
92    }
93
94    pub fn root(&self, x: &X) -> &X { 
95        let i = self.index_of(x);
96        let j = self.inner.root(i);
97        self.element_at(j)
98    }
99    
100    pub fn is_same(&self, x: &X, y: &X) -> bool { 
101        self.root(x) == self.root(y)
102    }
103
104    pub fn union(&mut self, x: &X, y: &X) {
105        let i = self.index_of(x);
106        let j = self.index_of(y);
107        self.inner.union(i, j);
108    }
109
110    pub fn group(&self) -> Vec<Vec<&X>> { 
111        self.inner.group().iter().map(|l| 
112            l.iter().map(|&i| 
113                self.element_at(i)
114            ).collect()
115        ).collect()
116    }
117
118    pub fn into_group(mut self) -> Vec<Vec<X>> { 
119        let group = self.inner.group();
120        let keys = std::mem::take(&mut self.keys);
121        
122        std::mem::drop(self);
123
124        let mut map = keys.into_iter().enumerate().collect::<HashMap<_, _>>();
125
126        group.iter().map(|l| 
127            l.iter().map(|i| {
128                let x = map.remove(i).unwrap();
129                let x = Rc::into_inner(x).unwrap();
130                x
131            }).collect()
132        ).collect()
133    }
134}
135
136impl<X> FromIterator<X> for KeyedUnionFind<X>
137where X: Hash + Eq {
138    fn from_iter<T: IntoIterator<Item = X>>(keys: T) -> Self {
139        let keys = keys.into_iter().map(|e| Rc::new(e)).collect_vec();
140        let dict = keys.iter().enumerate().map(|(i, e)| (Rc::clone(e), i)).collect();
141        let n = keys.len();
142        let inner = UnionFind::new(n);
143        Self { inner, keys, dict }
144    }
145}
146
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Keys used by the keyed test, in insertion order.
    const KEYS: [&str; 4] = ["a", "b", "c", "d"];

    /// Root of each element `0..4`, in index order.
    fn roots(u: &UnionFind) -> Vec<usize> {
        (0..4).map(|i| u.root(i)).collect_vec()
    }

    /// Membership of neighbouring pairs: `[0~1, 1~2, 2~3]`.
    fn neighbours_joined(u: &UnionFind) -> [bool; 3] {
        [u.is_same(0, 1), u.is_same(1, 2), u.is_same(2, 3)]
    }

    /// Root key of each entry of `KEYS`, in order.
    fn keyed_roots<'a>(u: &'a KeyedUnionFind<&'static str>) -> Vec<&'a &'static str> {
        KEYS.iter().map(|k| u.root(k)).collect_vec()
    }

    /// Membership of neighbouring key pairs: `[a~b, b~c, c~d]`.
    fn keyed_neighbours_joined(u: &KeyedUnionFind<&str>) -> [bool; 3] {
        [u.is_same(&"a", &"b"), u.is_same(&"b", &"c"), u.is_same(&"c", &"d")]
    }

    #[test]
    fn test() {
        let mut u = UnionFind::new(4);

        // Fresh structure: every element is its own root.
        assert_eq!(u.size(), 4);
        assert_eq!(&u.p, &vec![0, 1, 2, 3]);
        assert_eq!(roots(&u), vec![0, 1, 2, 3]);
        assert_eq!(neighbours_joined(&u), [false, false, false]);
        assert_eq!(u.group(), vec![vec![0], vec![1], vec![2], vec![3]]);

        // {0, 1}: the larger root 1 is attached under 0.
        u.union(0, 1);
        assert_eq!(&u.p, &vec![0, 0, 2, 3]);
        assert_eq!(roots(&u), vec![0, 0, 2, 3]);
        assert_eq!(neighbours_joined(&u), [true, false, false]);
        assert_eq!(u.group(), vec![vec![0, 1], vec![2], vec![3]]);

        // {2, 3}: the larger root 3 is attached under 2.
        u.union(2, 3);
        assert_eq!(&u.p, &vec![0, 0, 2, 2]);
        assert_eq!(roots(&u), vec![0, 0, 2, 2]);
        assert_eq!(neighbours_joined(&u), [true, false, true]);
        assert_eq!(u.group(), vec![vec![0, 1], vec![2, 3]]);

        // Joining through non-root members links root 2 under root 0; 3 still points at 2.
        u.union(1, 3);
        assert_eq!(&u.p, &vec![0, 0, 0, 2]);
        assert_eq!(roots(&u), vec![0, 0, 0, 0]);
        assert_eq!(neighbours_joined(&u), [true, true, true]);
        assert_eq!(u.group(), vec![vec![0, 1, 2, 3]]);
    }

    #[test]
    fn test_hash() {
        let mut u = KeyedUnionFind::from_iter(KEYS);

        // Fresh structure: every key is its own root.
        assert_eq!(u.size(), 4);
        assert_eq!(keyed_roots(&u), vec![&"a", &"b", &"c", &"d"]);
        assert_eq!(keyed_neighbours_joined(&u), [false, false, false]);
        assert_eq!(u.group(), vec![vec![&"a"], vec![&"b"], vec![&"c"], vec![&"d"]]);

        u.union(&"a", &"b");
        assert_eq!(keyed_roots(&u), vec![&"a", &"a", &"c", &"d"]);
        assert_eq!(keyed_neighbours_joined(&u), [true, false, false]);
        assert_eq!(u.group(), vec![vec![&"a", &"b"], vec![&"c"], vec![&"d"]]);

        u.union(&"c", &"d");
        assert_eq!(keyed_roots(&u), vec![&"a", &"a", &"c", &"c"]);
        assert_eq!(keyed_neighbours_joined(&u), [true, false, true]);
        assert_eq!(u.group(), vec![vec![&"a", &"b"], vec![&"c", &"d"]]);

        u.union(&"b", &"d");
        assert_eq!(keyed_roots(&u), vec![&"a", &"a", &"a", &"a"]);
        assert_eq!(keyed_neighbours_joined(&u), [true, true, true]);

        assert_eq!(u.group(), vec![vec![&"a", &"b", &"c", &"d"]]);
        assert_eq!(u.into_group(), vec![vec!["a", "b", "c", "d"]]);
    }
}