rust_black_trees/
avltree.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3
4use super::node::{endpaint, paint};
5use super::tree::BaseTree;
6use super::tree::Tree;
7
8use super::node::Node;
9use super::node::*;
10
/// Sentinel written into a node's `ptr` field to flag it for removal.
///
/// `usize::MAX` is used because it can never be a real index into the node
/// arena (a `Vec` cannot hold that many elements), so a flagged node cannot be
/// confused with a live one. The previous literal `0xFFFFFFFF` is an ordinary,
/// reachable index on 64-bit targets and does not fit `usize` on 16-bit ones.
const TREE_END: usize = usize::MAX;
12
/// Convenience macro that builds an `AVLTree` pre-populated with the given
/// elements.
///
/// The elements are inserted one at a time, in the order written, by calling
/// `insert` on a freshly created tree. A trailing comma after the last element
/// is accepted, and an empty invocation yields an empty tree.
///
/// The expansion names `AVLTree` unqualified, so `AVLTree` (and whatever trait
/// supplies `new`/`insert`) must be in scope where the macro is invoked.
///
/// usage: `avl!{1, 2, 3, 4, 5, 6, 7, 8, 9, 0};`
#[macro_export]
macro_rules! avl {
    // No elements: just an empty tree (no `mut` binding needed).
    () => {
        AVLTree::new()
    };
    // One or more comma-separated elements, optionally followed by a comma.
    ( $( $x:expr ),+ $(,)? ) => {
        {
            let mut temp_tree = AVLTree::new();
            $(
                temp_tree.insert($x);
            )+
            temp_tree
        }
    };
}
28
/// One node of the AVL tree. Nodes live in a shared vector and refer to each
/// other by index rather than by pointer.
#[derive(Debug)]
pub struct AVLNode<T> {
    /// Value carried by this node.
    pub value: T,
    /// Index this node was given when it was constructed.
    pub ptr: usize,
    /// Index of the parent node, or `None` when no parent is attached.
    pub parent: Option<usize>,
    /// Index of the left child, if any.
    pub lchild: Option<usize>,
    /// Index of the right child, if any.
    pub rchild: Option<usize>,
    /// Shared handle to the vector that stores the nodes.
    data: Rc<RefCell<Vec<AVLNode<T>>>>,
    // AVL-specific bookkeeping follows.
    /// Cached height; a freshly constructed node starts at 1.
    pub height: usize,
    /// Cached balance factor; a freshly constructed node starts at 0.
    pub balance_factor: isize,
}

impl<T> AVLNode<T> {
    /// Builds a detached node: no parent, no children, height 1 and balance
    /// factor 0.
    ///
    /// * `val` - value to store in the node.
    /// * `selfptr` - index recorded in the node's `ptr` field.
    /// * `data` - shared handle to the node storage, kept inside the node.
    fn new(val: T, selfptr: usize, data: Rc<RefCell<Vec<AVLNode<T>>>>) -> Self {
        AVLNode {
            value: val,
            ptr: selfptr,
            data,
            // Not linked to anything yet.
            parent: None,
            lchild: None,
            rchild: None,
            // A lone node is a balanced subtree of height one.
            height: 1,
            balance_factor: 0,
        }
    }
}
56
57impl<T: std::fmt::Debug + std::cmp::PartialOrd> Node<T> for AVLNode<T> {
58    fn to_self_string(&self) -> String {
59        format!(
60            "[V:{:?} H:{:?} BF:{:?}]",
61            self.value, self.height, self.balance_factor
62        )
63    }
64    fn to_self_string_display(&self) -> (String, usize) {
65        const GRN: usize = 2;
66        const YEL: usize = 3;
67        const BLU: usize = 4;
68        const BLK: usize = 0;
69        const WHT: usize = 7;
70        const FG: usize = 30;
71        const BG: usize = 40;
72        let col = match self.balance_factor {
73            -1 => BLU,
74            1 => YEL,
75            0 => GRN,
76            _ => WHT,
77        };
78        (
79            format!(
80                "{}{:?}{}",
81                paint(FG + BLK, BG + col),
82                self.value,
83                endpaint()
84            ),
85            format!("{:?}", self.value).len(),
86        )
87    }
88    fn get_value(&self) -> &T {
89        &self.value
90    }
91
92    fn is(&self, val: &T) -> bool {
93        &self.value == val
94    }
95    fn greater(&self, val: &T) -> bool {
96        &self.value > val
97    }
98    fn lesser(&self, val: &T) -> bool {
99        &self.value < val
100    }
101
102    /**
103     * In order to return a reference to a value of a vector contained within a
104     * refcell, a raw pointer is used. The unsafe code could be avoided by
105     * replacing each call to self.get(n) with &self.data.borrow()[n] and each call
106     * to self.get_mut(n) with &mut self.data.borrow()[n]
107     */
108    fn get(&self, ptr: usize) -> &AVLNode<T> {
109        unsafe { &(*self.data.as_ptr())[ptr] }
110    }
111
112    fn get_mut(&self, ptr: usize) -> &mut AVLNode<T> {
113        unsafe { &mut (*self.data.as_ptr())[ptr] }
114    }
115
116    fn get_child(&self, side: Side) -> Option<usize> {
117        match side {
118            Side::Left => self.lchild,
119            Side::Right => self.rchild,
120        }
121    }
122
123    fn set_child(&mut self, child: usize, side: Side) {
124        self.set_child_opt(Some(child), side)
125    }
126
127    fn set_child_opt(&mut self, c: Option<usize>, side: Side) {
128        match side {
129            Side::Left => self.lchild = c,
130            Side::Right => self.rchild = c,
131        };
132        if let Some(child) = c {
133            self.get_mut(child).parent = Some(self.location());
134        }
135    }
136    fn set_parent(&mut self, p: Option<usize>) {
137        self.parent = p;
138    }
139
140    fn get_parent(&self) -> Option<usize> {
141        self.parent
142    }
143
144    fn location(&self) -> usize {
145        self.ptr
146    }
147}
148
149/**
150 * Arena based memory tree structure
151*/
152#[derive(Debug)]
153pub struct AVLTree<T> {
154    root: Option<usize>,
155    size: usize,
156    data: Rc<RefCell<Vec<AVLNode<T>>>>,
157    free: Vec<usize>,
158}
159
160impl<T> Tree<T> for AVLTree<T>
161where
162    T: PartialOrd,
163    T: PartialEq,
164    T: std::fmt::Debug,
165{
166    fn new() -> Self {
167        Self {
168            root: None,
169            data: Rc::new(RefCell::new(Vec::new())),
170            size: 0,
171            free: Vec::new(),
172        }
173    }
174}
175
176impl<T> BaseTree<T> for AVLTree<T>
177where
178    T: PartialOrd,
179    T: PartialEq,
180    T: std::fmt::Debug,
181{
182    type MNode = AVLNode<T>;
183    /**
184     * In order to return a reference to a value of a vector contained within a refcell, a raw
185     * pointer is used. The unsafe code could be avoided by replacing each call to self.get(n) with
186     * &self.data.borrow()[n] and each call to self.get_mut(n) with &mut self.data.borrow()[n]. This
187     * allows us to do the same thing with less keystrokes. It does make the program not
188     * thread-safe, but a this data structure is a pretty terrible choice for a multi-threaded data
189     * structure anyways, since re-balancing can require that most of the tree be locked to one
190     * thread during an insertion or deletion
191     */
192    fn get(&self, val: usize) -> &Self::MNode {
193        unsafe { &(*self.data.as_ptr())[val] }
194    }
195
196    fn get_mut(&self, val: usize) -> &mut Self::MNode {
197        unsafe { &mut (*self.data.as_ptr())[val] }
198    }
199
200    fn get_root(&self) -> Option<usize> {
201        self.root
202    }
203
204    fn set_root(&mut self, new_root: Option<usize>) {
205        self.root = new_root
206    }
207
208    fn crement_size(&mut self, amount: isize) {
209        self.size = (self.size as isize + amount) as usize;
210    }
211
212    fn attach_child(&self, p: usize, c: usize, side: Side) {
213        self.get_mut(p).set_child(c, side)
214    }
215
216    fn rebalance_ins(&mut self, n: usize) {
217        self.retrace(n);
218    }
219
220    fn rebalance_del(&mut self, n: usize, _child: usize) {
221        if self.get_mut(n).ptr == TREE_END {
222            self.slow_delete();
223        } else {
224            self.del_retrace(n);
225            if let Some(r) = self.root {
226                self.traverse_to_fix(r);
227            }
228        }
229    }
230
231    fn delete_replace(&mut self, n: usize) -> usize {
232        self.get_mut(n).ptr = TREE_END;
233        n
234    }
235
236    fn replace_node(&mut self, _to_delete: usize, _to_attach: Option<usize>) {
237    }
238
239    fn get_size(&self) -> usize {
240        return self.size;
241    }
242
243    fn create_node(&mut self, val: T) -> usize {
244        let loc = self.data.borrow().len();
245        self.data
246            .borrow_mut()
247            .push(AVLNode::new(val, loc, self.data.clone()));
248        loc
249    }
250
251    fn delete_node(&mut self, index: usize) {
252        self.free.push(index);
253    }
254}
255
256impl<T> AVLTree<T>
257where
258    T: PartialOrd,
259    T: PartialEq,
260    T: std::fmt::Debug,
261{
262    fn del_retrace(&mut self, n: usize) {
263        loop {
264            let x = self.get(n).parent;
265            if !x.is_some() {
266                return;
267            }
268            let x: usize = x.expect("Deletion retrace get z parent");
269            //println!("n v:{:?}", self.get(n).value);
270            //println!("{}", self.to_pretty_string());
271            if self.get(n).is_child(Side::Left) {
272                if self.is_heavy_on_side(Side::Right, x) {
273                    // Sibling of N (higher by 2)
274                    if let Some(z) = self.get(n).get_sibling() {
275                        if self.is_heavy_on_side(Side::Left, z) {
276                            self.avl_rotate(Side::Right, z);
277                            self.avl_rotate(Side::Left, x);
278                        } else {
279                            self.avl_rotate(Side::Left, x);
280                        }
281                    } else {
282                        //println!("THIS IS SKETCHY");
283                        //self.del_retrace(x);
284                        self.avl_rotate(Side::Left, x);
285                    }
286                } else {
287                    if self.calc_bal_fac(x) == 0 {
288                        self.set_balance_factor(x, 1);
289                        break;
290                    }
291                    self.set_balance_factor(n, 0);
292                    //N = X; //
293                    self.del_retrace(x);
294                }
295            } else {
296                if self.is_heavy_on_side(Side::Left, x) {
297                    // Sibling of N (higher by 2)
298                    if let Some(z) = self.get(n).get_sibling() {
299                        if self.is_heavy_on_side(Side::Right, z) {
300                            self.avl_rotate(Side::Left, z);
301                            self.avl_rotate(Side::Right, x);
302                        } else {
303                            self.avl_rotate(Side::Right, x);
304                        }
305                    } else {
306                        //println!("THIS IS SKETCHY");
307                        //self.del_retrace(x);
308                        self.avl_rotate(Side::Right, x);
309                    }
310                } else {
311                    if self.calc_bal_fac(x) == 0 {
312                        self.set_balance_factor(x, -1);
313                        break; // Leave the loop
314                    }
315                    self.set_balance_factor(n, 0);
316                    //N = X;
317                    self.del_retrace(x);
318                }
319            }
320            break;
321        }
322    }
323
324    fn retrace(&mut self, z: usize) {
325        //println!("Z= {:?}", self.get(z).value);
326        //println!("X= {:?}", self.get(x).value);
327        // get the parent of current node
328        let x = self.get(z).parent;
329        if !x.is_some() {
330            // current node z is the root of the tree
331            // nothing to do, return?
332            return;
333        }
334        let x: usize = x.expect("Retrace get z parent");
335
336        if self.get(z).is_child(Side::Right) {
337            // The right subtree increases
338            if self.is_heavy_on_side(Side::Right, x) {
339                if self.is_heavy_on_side(Side::Left, z) {
340                    self.avl_rotate(Side::Right, z);
341                    self.avl_rotate(Side::Left, x);
342                } else {
343                    // TODO: rotates panic rn
344                    // wiki has a differnet definiton of
345                    // rotate than we do I think
346                    self.avl_rotate(Side::Left, x);
347                    //self.rotate(Side::Left, z);
348                }
349            } else {
350                if self.is_heavy_on_side(Side::Left, x) {
351                    self.set_balance_factor(x, 0);
352                    return;
353                }
354                self.set_balance_factor(x, 1);
355                //Z = X; // Height(Z) increases by 1
356                //z = x;
357                self.retrace(x);
358                //continue;
359            }
360        } else {
361            if self.is_heavy_on_side(Side::Left, x) {
362                if self.is_heavy_on_side(Side::Right, z) {
363                    self.avl_rotate(Side::Left, z);
364                    self.avl_rotate(Side::Right, x);
365                } else {
366                    self.avl_rotate(Side::Right, x);
367                }
368            } else {
369                if self.is_heavy_on_side(Side::Right, x) {
370                    self.set_balance_factor(x, 0);
371                    return; // Leave the loop
372                }
373                self.set_balance_factor(x, -1);
374                //Z = X; // Height(Z) increases by 1
375                //z = x;
376                self.retrace(x);
377                //continue;
378            }
379        }
380        //self.retrace(x);
381        return;
382        // Unless loop is left via break, the height of the total tree increases by 1.
383    }
384
385    fn avl_rotate(&mut self, side: Side, n: usize) {
386        // make an adjustment to account for differnt rotate
387        // algorithm off wiki than implemented in tree...
388        // ALSO adjust the balance factors
389        //        println!("Pre-rotate on n={:?} for\n {}",
390        //            self.get(n).value,
391        //            self.to_pretty_string());
392        if let Some(z) = self.get(n).get_child(!side) {
393            self.rotate(side, z);
394            //self.traverse_to_fix(self.root.unwrap());
395            self.traverse_to_fix(z);
396        //            match self.calc_bal_fac(z) {
397        //                0 => {
398        //                    self.set_balance_factor(n, 1);
399        //                    self.set_balance_factor(z, -1);
400        //                }
401        //                _ => {
402        //                    self.set_balance_factor(n, 0);
403        //                    self.set_balance_factor(z, 0);
404        //                }
405        //            }
406        } else {
407            //panic!("avl rotate unwrap");
408            println!("tried to rotate on None");
409        }
410    }
411
412    fn get_balance_factor(&self, n: usize) -> isize {
413        self.get(n).balance_factor
414    }
415
416    fn set_balance_factor(&mut self, n: usize, bf: isize) {
417        self.get_mut(n).balance_factor = bf;
418    }
419
420    fn calc_bal_fac(&self, n: usize) -> isize {
421        let rc = self.get(n).get_child(Side::Right);
422        let lc = self.get(n).get_child(Side::Left);
423        let safe_get_bf = |x| match x {
424            Some(y) => self.get_balance_factor(y),
425            None => 0,
426        };
427        let bf_rc = safe_get_bf(rc);
428        let bf_lc = safe_get_bf(lc);
429        bf_rc - bf_lc
430    }
431
432    fn is_heavy_on_side(&self, side: Side, n: usize) -> bool {
433        // check the balance factor on side of node n
434        match side {
435            Side::Right => self.get_balance_factor(n) > 0,
436            Side::Left => self.get_balance_factor(n) < 0,
437        }
438    }
439
440    fn slow_delete(&mut self) {
441        let mut t = AVLTree::new();
442        let mut v = self.data.borrow_mut().pop();
443        while v.is_some() {
444            let n = v.unwrap();
445            if n.ptr != TREE_END {
446                t.insert(n.value);
447            }
448            v = self.data.borrow_mut().pop();
449        }
450
451        *self = t;
452        self.size += 1;
453    }
454
455    fn fix_bf(&mut self, n: usize) {
456        let rc = self.get(n).get_child(Side::Right);
457        let lc = self.get(n).get_child(Side::Left);
458        // get height and BF of each child
459
460        //        let rcbf = match rc {
461        //            Some(c) =>  self.get_balance_factor(c),
462        //            None => 0,
463        //        };
464        //        let lcbf = match lc {
465        //            Some(c) => self.get_balance_factor(c),
466        //            None => 0,
467        //        };
468        let rch = match rc {
469            Some(c) => self.get(c).height,
470            None => 0,
471        };
472        let lch = match lc {
473            Some(c) => self.get(c).height,
474            None => 0,
475        };
476        self.get_mut(n).height = std::cmp::max(lch, rch) + 1;
477        self.set_balance_factor(n, rch as isize - lch as isize);
478    }
479
480    fn traverse_to_fix(&mut self, n: usize) {
481        /*if !self.get(n).is_some() {
482            return;
483        }*/
484        if let Some(c) = self.get(n).get_child(Side::Left) {
485            self.traverse_to_fix(c);
486        }
487
488        if let Some(c) = self.get(n).get_child(Side::Right) {
489            self.traverse_to_fix(c);
490        }
491        self.fix_bf(n);
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty tree can be constructed.
    #[test]
    fn new() {
        let _tree = AVLTree::<i32>::new();
    }

    /// A single inserted element becomes a root that leans to neither side.
    #[test]
    fn insert_one() {
        let mut tree = AVLTree::<i32>::new();
        tree.insert(1);
        let root = tree.root.expect("tree root");
        assert_eq!(tree.get_balance_factor(root), 0);
        assert!(!tree.is_heavy_on_side(Side::Right, root));
        assert!(!tree.is_heavy_on_side(Side::Left, root));
    }

    /// `set_balance_factor` and `is_heavy_on_side` agree on the sign convention.
    #[test]
    fn balance_factor_helpers() {
        let mut tree = AVLTree::<i32>::new();
        tree.insert(1);
        let root = tree.root.expect("tree root");

        tree.set_balance_factor(root, 1);
        assert!(tree.is_heavy_on_side(Side::Right, root));

        tree.set_balance_factor(root, -1);
        assert!(tree.is_heavy_on_side(Side::Left, root));
    }

    /// Puts the smallest tree through all the combos of rebalance rotations:
    /// each listed insertion order must end in the same balanced shape.
    #[test]
    fn insert_few() {
        let orders: [(&str, [i32; 3]); 4] = [
            ("123", [1, 2, 3]),
            ("132", [1, 3, 2]),
            ("321", [3, 2, 1]),
            ("312", [3, 1, 2]),
        ];

        for (label, values) in orders.iter() {
            let mut tree = AVLTree::<i32>::new();
            for &value in values.iter() {
                tree.insert(value);
            }
            println!("{}", label);
            assert_eq!(
                tree.to_string(),
                "([V:2 H:2 BF:0] ([V:1 H:1 BF:0] () ()) ([V:3 H:1 BF:0] () ()))"
            );
        }
    }

    /// Adding a leaf and deleting it again leaves the three-node tree as it was.
    #[test]
    fn avl_del() {
        let mut tree = AVLTree::<i32>::new();
        for &value in [2, 4, 6].iter() {
            tree.insert(value);
        }

        for &i in [1, 3, 5, 7].iter() {
            println!("Adding and removing leaf v={}", i);
            tree.insert(i);
            tree.delete(i);
            assert_eq!(
                tree.to_string(),
                "([V:4 H:2 BF:0] ([V:2 H:1 BF:0] () ()) ([V:6 H:1 BF:0] () ()))"
            );
        }
    }
}
587}