scopegraphs_prust_lib/
avl.rs

1use std::cmp::max;
2
3use crate::RefCounter;
4
/// A persistent (immutable) AVL tree, usable as an ordered map or, with the
/// default `V = ()`, as an ordered set.
///
/// Operations such as `put`, `insert` and `delete` take `&self` and return a new
/// tree; subtrees that are not on the modified path are reused by cloning their
/// `RefCounter` handles rather than copying nodes (presumably a reference-counted
/// pointer such as `Rc`/`Arc` — confirm against `crate::RefCounter`).
///
/// Nodes do not cache their height; heights are recomputed on demand.
pub enum AVL<K, V = ()> {
    /// The empty tree; also used for missing children of a node.
    Empty,
    /// A node holding one entry and its two subtrees.
    Node {
        /// Entry key; the tree is ordered by `K: Ord`.
        key: RefCounter<K>,
        /// Value associated with `key` (`()` when the tree is used as a set).
        value: RefCounter<V>,
        /// Subtree containing keys that compare less than `key`.
        left: RefCounter<AVL<K, V>>,
        /// Subtree containing keys that compare greater than `key`.
        right: RefCounter<AVL<K, V>>,
    },
}
14
/// Ordered map view of [`AVL`]: one value per key, keys ordered by `K: Ord`.
pub type OrderedMap<K, V> = AVL<K, V>;
/// Ordered set view of [`AVL`]: the value type is fixed to `()`.
pub type OrderedSet<K> = AVL<K>;
17
18impl<K, V> Clone for AVL<K, V> {
19    fn clone(&self) -> Self {
20        match self {
21            Self::Empty => Self::Empty,
22            Self::Node {
23                key,
24                value,
25                left,
26                right,
27            } => Self::Node {
28                key: key.clone(),
29                value: value.clone(),
30                left: left.clone(),
31                right: right.clone(),
32            },
33        }
34    }
35}
36
37impl<K: Ord> AVL<K> {
38    pub fn insert(&self, value: K) -> Self {
39        self.put(value, ())
40    }
41    pub fn search(&self, value: &K) -> bool {
42        self.find(value).is_some()
43    }
44}
45
46impl<K: Ord, V> AVL<K, V> {
47    pub fn empty() -> AVL<K, V> {
48        AVL::Empty
49    }
50    pub fn is_empty(&self) -> bool {
51        matches!(self, AVL::Empty)
52    }
53    fn height(&self) -> i64 {
54        match self {
55            AVL::Empty => 0,
56            AVL::Node {
57                key: _,
58                value: _,
59                left,
60                right,
61            } => 1 + max(&left.height(), &right.height()),
62        }
63    }
64    fn diff(&self) -> i64 {
65        match self {
66            AVL::Empty => 0,
67            AVL::Node {
68                key: _,
69                value: _,
70                left,
71                right,
72            } => left.height() - right.height(),
73        }
74    }
75    pub fn find(&self, target_value: &K) -> Option<&V> {
76        match self {
77            AVL::Empty => Option::None,
78            AVL::Node {
79                key,
80                value,
81                left,
82                right,
83            } => match target_value.cmp(key) {
84                std::cmp::Ordering::Less => left.find(target_value),
85                std::cmp::Ordering::Equal => Option::Some(value.as_ref()),
86                std::cmp::Ordering::Greater => right.find(target_value),
87            },
88        }
89    }
90    fn right_rotation(&self) -> AVL<K, V> {
91        if let AVL::Node {
92            key: x,
93            value: vx,
94            left: lt,
95            right: t3,
96        } = self
97        {
98            if let AVL::Node {
99                key: y,
100                value: vy,
101                left: t1,
102                right: t2,
103            } = (*lt).as_ref()
104            {
105                return AVL::Node {
106                    key: y.clone(),
107                    left: t1.clone(),
108                    value: vy.clone(),
109                    right: RefCounter::new(AVL::Node {
110                        key: x.clone(),
111                        value: vx.clone(),
112                        left: t2.clone(),
113                        right: t3.clone(),
114                    }),
115                };
116            }
117        }
118        self.clone()
119    }
120    fn right_fix(&self) -> AVL<K, V> {
121        if let AVL::Node {
122            key: x,
123            value: vx,
124            left: t1,
125            right: t2,
126        } = self
127        {
128            if t1.diff() == -1 {
129                return AVL::Node {
130                    key: x.clone(),
131                    value: vx.clone(),
132                    left: RefCounter::new(t1.left_rotation()),
133                    right: t2.clone(),
134                }
135                .right_rotation();
136            } else {
137                return self.right_rotation();
138            }
139        }
140        self.clone()
141    }
142    fn left_rotation(&self) -> AVL<K, V> {
143        if let AVL::Node {
144            key: x,
145            value: vx,
146            left: t1,
147            right: rt,
148        } = self
149        {
150            if let AVL::Node {
151                key: y,
152                value: vy,
153                left: t2,
154                right: t3,
155            } = (*rt).as_ref()
156            {
157                return AVL::Node {
158                    key: y.clone(),
159                    value: vy.clone(),
160                    left: RefCounter::new(AVL::Node {
161                        key: x.clone(),
162                        value: vx.clone(),
163                        left: t1.clone(),
164                        right: t2.clone(),
165                    }),
166                    right: t3.clone(),
167                };
168            }
169        }
170        self.clone()
171    }
172    fn left_fix(&self) -> AVL<K, V> {
173        if let AVL::Node {
174            key: x,
175            value: vx,
176            left: t1,
177            right: t2,
178        } = self
179        {
180            if t2.diff() == 1 {
181                return AVL::Node {
182                    key: x.clone(),
183                    value: vx.clone(),
184                    left: t1.clone(),
185                    right: RefCounter::new(t2.right_rotation()),
186                }
187                .left_rotation();
188            } else {
189                return self.left_rotation();
190            }
191        }
192        self.clone()
193    }
194    fn fix(&self) -> AVL<K, V> {
195        match self.diff() {
196            2 => self.right_fix(),
197            -2 => self.left_fix(),
198            _ => self.clone(),
199        }
200    }
201    pub fn put(&self, key: K, value: V) -> AVL<K, V> {
202        self.put_rc(RefCounter::new(key), RefCounter::new(value))
203    }
204    fn put_rc(&self, key_rc: RefCounter<K>, value_rc: RefCounter<V>) -> AVL<K, V> {
205        match self {
206            AVL::Empty => AVL::Node {
207                key: key_rc,
208                value: value_rc,
209                left: RefCounter::new(AVL::Empty),
210                right: RefCounter::new(AVL::Empty),
211            },
212            AVL::Node {
213                key,
214                value,
215                left,
216                right,
217            } => match key_rc.cmp(key) {
218                std::cmp::Ordering::Less => AVL::Node {
219                    key: key.clone(),
220                    value: value.clone(),
221                    left: RefCounter::new(left.put_rc(key_rc, value_rc)),
222                    right: right.clone(),
223                }
224                .fix(),
225                std::cmp::Ordering::Equal => AVL::Node {
226                    key: key_rc,
227                    value: value_rc,
228                    left: left.clone(),
229                    right: right.clone(),
230                },
231                std::cmp::Ordering::Greater => AVL::Node {
232                    key: key.clone(),
233                    value: value.clone(),
234                    left: left.clone(),
235                    right: RefCounter::new(right.put_rc(key_rc, value_rc)),
236                }
237                .fix(),
238            },
239        }
240    }
241    pub fn delete(&self, target_key: &K) -> AVL<K, V> {
242        match self {
243            AVL::Empty => AVL::Empty,
244            AVL::Node {
245                key,
246                value,
247                left,
248                right,
249            } => {
250                match target_key.cmp(key) {
251                    std::cmp::Ordering::Less => {
252                        let left_deleted = left.delete(target_key);
253                        AVL::Node {
254                            key: key.clone(),
255                            value: value.clone(),
256                            left: RefCounter::new(left_deleted),
257                            right: right.clone(),
258                        }
259                        .fix()
260                    }
261                    std::cmp::Ordering::Equal => {
262                        // Node with only one child or no child
263                        if left.is_empty() {
264                            return right.as_ref().clone();
265                        } else if right.is_empty() {
266                            return left.as_ref().clone();
267                        }
268
269                        // Node with two children, get the inorder predecessor (maximum value in the left subtree)
270                        let inorder_predecessor = left.find_max();
271                        if let Some((pred_key, pred_value)) = inorder_predecessor {
272                            let left_deleted = left.delete(&pred_key);
273                            AVL::Node {
274                                key: pred_key.clone(),
275                                value: pred_value.clone(),
276                                left: RefCounter::new(left_deleted),
277                                right: right.clone(),
278                            }
279                            .fix()
280                        } else {
281                            self.clone()
282                        }
283                    }
284                    std::cmp::Ordering::Greater => {
285                        let right_deleted = right.delete(target_key);
286                        AVL::Node {
287                            key: key.clone(),
288                            value: value.clone(),
289                            left: left.clone(),
290                            right: RefCounter::new(right_deleted),
291                        }
292                        .fix()
293                    }
294                }
295            }
296        }
297    }
298
299    fn find_max(&self) -> Option<(RefCounter<K>, RefCounter<V>)> {
300        match self {
301            AVL::Empty => None,
302            AVL::Node {
303                key,
304                value,
305                left: _,
306                right,
307            } => {
308                if right.is_empty() {
309                    Some((key.clone(), value.clone()))
310                } else {
311                    right.find_max()
312                }
313            }
314        }
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the AVL balance invariant (|height(left) - height(right)| <= 1)
    /// at every node of `tree`.
    fn assert_balanced<K: Ord, V>(tree: &AVL<K, V>) {
        if let AVL::Node { left, right, .. } = tree {
            assert!(
                tree.diff().abs() <= 1,
                "node violates the AVL balance invariant"
            );
            assert_balanced(&**left);
            assert_balanced(&**right);
        }
    }

    /// Appends the keys of `tree` to `out` in in-order (ascending) sequence.
    fn keys_in_order<K: Ord + Clone, V>(tree: &AVL<K, V>, out: &mut Vec<K>) {
        if let AVL::Node {
            key, left, right, ..
        } = tree
        {
            keys_in_order(&**left, out);
            out.push((**key).clone());
            keys_in_order(&**right, out);
        }
    }

    #[test]
    fn test_avl_set() {
        let l = AVL::empty().insert(1).insert(2).insert(3).insert(4);
        let l2 = l.clone().insert(5);
        for i in 1..=4 {
            assert!(l.search(&i));
            assert!(l2.search(&i));
        }
        assert!(!l.search(&5));
        assert!(l2.search(&5));
    }

    #[test]
    fn test_avl_map() {
        let l = AVL::empty().put(1, 999);
        let l2 = l.clone().put(1, 123).put(2, 3);
        assert_eq!(l.find(&1), Some(&999));
        assert_eq!(l2.find(&1), Some(&123));
        assert!(l.find(&2).is_none());
        assert!(l2.find(&2).is_some());
    }

    #[test]
    fn test_avl_delete() {
        let l = AVL::empty()
            .insert(1)
            .insert(2)
            .insert(3)
            .insert(4)
            .insert(5);
        let l = l.delete(&3);
        assert!(!l.search(&3));
        assert!(l.search(&1));
        assert!(l.search(&2));
        assert!(l.search(&4));
        assert!(l.search(&5));
    }

    #[test]
    fn test_avl_stays_balanced() {
        let mut tree: OrderedSet<i32> = AVL::empty();
        for i in 0..100 {
            tree = tree.insert(i);
            assert_balanced(&tree);
        }
        // An AVL tree with 100 nodes has height at most 9 (the sparsest AVL tree
        // of height 10 already needs 143 nodes).
        assert!(tree.height() <= 9);

        for i in (0..100).step_by(2) {
            tree = tree.delete(&i);
            assert_balanced(&tree);
        }
        let mut keys = Vec::new();
        keys_in_order(&tree, &mut keys);
        assert_eq!(keys, (1..100).step_by(2).collect::<Vec<_>>());
    }

    #[test]
    fn test_avl_delete_node_with_two_children() {
        let mut map: OrderedMap<i32, i32> = AVL::empty();
        for k in 1..=7 {
            map = map.put(k, k * 10);
        }
        // Deleting a middle key exercises the in-order-predecessor path.
        let deleted = map.delete(&4);
        assert!(deleted.find(&4).is_none());
        for &k in &[1, 2, 3, 5, 6, 7] {
            assert_eq!(deleted.find(&k), Some(&(k * 10)));
        }
        assert_balanced(&deleted);
        // The original tree is persistent and must be unaffected.
        assert_eq!(map.find(&4), Some(&40));
    }

    #[test]
    fn test_avl_delete_missing_key() {
        let empty: OrderedSet<i32> = AVL::empty();
        assert!(empty.delete(&1).is_empty());

        let set = AVL::empty().insert(1).insert(2).insert(3);
        let same = set.delete(&42);
        for i in 1..=3 {
            assert!(same.search(&i));
        }
        assert!(!same.search(&42));
        assert_balanced(&same);
    }
}
359}