unc_sdk/collections/
tree_map.rs

1use borsh::{BorshDeserialize, BorshSerialize};
2use std::ops::Bound;
3
4use crate::collections::LookupMap;
5use crate::collections::{append, Vector};
6use crate::{env, IntoStorageKey};
7use unc_sdk_macros::unc;
8
/// TreeMap based on AVL-tree
///
/// Keys are kept in a balanced tree of [`Node`]s stored in a `Vector`, while the values
/// live in a separate `LookupMap` keyed by `K`.
///
/// Runtime complexity (worst case):
/// - `get`/`contains_key`:     O(1) - LookupMap lookup
/// - `insert`/`remove`:        O(log(N))
/// - `min`/`max`:              O(log(N))
/// - `higher`/`lower`:         O(log(N))
/// - `range` of K elements:    O(Klog(N))
///

#[unc(inside_uncsdk)]
pub struct TreeMap<K, V> {
    // index of the root node inside `tree`; set to 0 by `new` and `clear`
    root: u64,
    // key -> value storage
    // ser/de is independent of `K`,`V` ser/de, `BorshSerialize`/`BorshDeserialize`/`BorshSchema` bounds removed
    #[cfg_attr(not(feature = "abi"), borsh(bound(serialize = "", deserialize = "")))]
    #[cfg_attr(
        feature = "abi",
        borsh(bound(serialize = "", deserialize = ""), schema(params = ""))
    )]
    val: LookupMap<K, V>,
    // tree nodes, addressed by their position in the vector (`Node::id`)
    // ser/de is independent of `K` ser/de, `BorshSerialize`/`BorshDeserialize`/`BorshSchema` bounds removed
    #[cfg_attr(not(feature = "abi"), borsh(bound(serialize = "", deserialize = "")))]
    #[cfg_attr(
        feature = "abi",
        borsh(bound(serialize = "", deserialize = ""), schema(params = ""))
    )]
    tree: Vector<Node<K>>,
}
37
/// A single node of the AVL tree backing [`TreeMap`].
///
/// Nodes refer to each other by `id`, which is the index used by `TreeMap` to read and
/// write the node in its node vector.
#[unc(inside_uncsdk)]
#[derive(Clone, Debug)]
pub struct Node<K> {
    id: u64,          // index of this node in the tree's node vector
    key: K,           // key stored in a node
    lft: Option<u64>, // left link of a node
    rgt: Option<u64>, // right link of a node
    ht: u64,          // height of a subtree at a node (1 for a freshly created node)
}
47
48impl<K> Node<K>
49where
50    K: Ord + Clone + BorshSerialize + BorshDeserialize,
51{
52    fn of(id: u64, key: K) -> Self {
53        Self { id, key, lft: None, rgt: None, ht: 1 }
54    }
55}
56
57impl<K, V> TreeMap<K, V>
58where
59    K: Ord + Clone + BorshSerialize + BorshDeserialize,
60    V: BorshSerialize + BorshDeserialize,
61{
62    /// Makes a new, empty TreeMap
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// use unc_sdk::collections::TreeMap;
68    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
69    /// ```
70    pub fn new<S>(prefix: S) -> Self
71    where
72        S: IntoStorageKey,
73    {
74        let prefix = prefix.into_storage_key();
75        Self {
76            root: 0,
77            val: LookupMap::new(append(&prefix, b'v')),
78            tree: Vector::new(append(&prefix, b'n')),
79        }
80    }
81
    /// Returns the number of elements in the tree, also referred to as its size.
    ///
    /// This is the length of the internal node vector: every stored key owns exactly one node.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// tree.insert(&1, &10);
    /// tree.insert(&2, &20);
    /// assert_eq!(tree.len(), 2);
    /// ```
    pub fn len(&self) -> u64 {
        self.tree.len()
    }
97
    /// Returns `true` if the tree holds no elements.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// assert!(tree.is_empty());
    /// tree.insert(&1, &10);
    /// assert!(!tree.is_empty());
    /// ```
    pub fn is_empty(&self) -> bool {
        self.tree.is_empty()
    }
101
102    /// Clears the tree, removing all elements.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// use unc_sdk::collections::TreeMap;
108    ///
109    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
110    /// tree.insert(&1, &10);
111    /// tree.insert(&2, &20);
112    /// tree.clear();
113    /// assert_eq!(tree.len(), 0);
114    /// ```
115    pub fn clear(&mut self) {
116        self.root = 0;
117        for n in self.tree.iter() {
118            self.val.remove(&n.key);
119        }
120        self.tree.clear();
121    }
122
    /// Loads the node stored at index `id` of the node vector, or `None` if there is no such node.
    fn node(&self, id: u64) -> Option<Node<K>> {
        self.tree.get(id)
    }
126
127    fn save(&mut self, node: &Node<K>) {
128        if node.id < self.len() {
129            self.tree.replace(node.id, node);
130        } else {
131            self.tree.push(node);
132        }
133    }
134
    /// Returns true if the map contains a given key.
    ///
    /// The check is a lookup in the value map; the tree itself is not traversed.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// assert_eq!(tree.contains_key(&1), false);
    /// tree.insert(&1, &10);
    /// assert_eq!(tree.contains_key(&1), true);
    /// ```
    pub fn contains_key(&self, key: &K) -> bool {
        self.val.get(key).is_some()
    }
150
    /// Returns the value corresponding to the key, or `None` if the key is not stored.
    ///
    /// The value is read directly from the value map; the tree itself is not traversed.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// assert_eq!(tree.get(&1), None);
    /// tree.insert(&1, &10);
    /// assert_eq!(tree.get(&1), Some(10));
    /// ```
    pub fn get(&self, key: &K) -> Option<V> {
        self.val.get(key)
    }
166
167    /// Inserts a key-value pair into the tree.
168    /// If the tree did not have this key present, `None` is returned. Otherwise returns
169    /// a value. Note, the keys that have the same hash value are undistinguished by
170    /// the implementation.
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// use unc_sdk::collections::TreeMap;
176    ///
177    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
178    /// assert_eq!(tree.insert(&1, &10), None);
179    /// assert_eq!(tree.insert(&1, &20), Some(10));
180    /// assert_eq!(tree.contains_key(&1), true);
181    /// ```
182    pub fn insert(&mut self, key: &K, val: &V) -> Option<V> {
183        if !self.contains_key(key) {
184            self.root = self.insert_at(self.root, self.len(), key);
185        }
186        self.val.insert(key, val)
187    }
188
189    /// Removes a key from the tree, returning the value at the key if the key was previously in the
190    /// tree.
191    ///
192    /// # Examples
193    ///
194    /// ```
195    /// use unc_sdk::collections::TreeMap;
196    ///
197    /// let mut tree: TreeMap<u32, u32> = TreeMap::new(b"t");
198    /// assert_eq!(tree.remove(&1), None);
199    /// tree.insert(&1, &10);
200    /// assert_eq!(tree.remove(&1), Some(10));
201    /// assert_eq!(tree.contains_key(&1), false);
202    /// ```
203    pub fn remove(&mut self, key: &K) -> Option<V> {
204        if self.contains_key(key) {
205            self.root = self.do_remove(key);
206            self.val.remove(key)
207        } else {
208            // no such key, nothing to do
209            None
210        }
211    }
212
    /// Returns the smallest stored key from the tree, or `None` if the tree has no nodes.
    ///
    /// Walks the left links down from the root.
    pub fn min(&self) -> Option<K> {
        self.min_at(self.root, self.root).map(|(n, _)| n.key)
    }
217
    /// Returns the largest stored key from the tree, or `None` if the tree has no nodes.
    ///
    /// Walks the right links down from the root.
    pub fn max(&self) -> Option<K> {
        self.max_at(self.root, self.root).map(|(n, _)| n.key)
    }
222
    /// Returns the smallest key that is strictly greater than key given as the parameter,
    /// or `None` if no stored key is greater. The given key does not have to be stored.
    pub fn higher(&self, key: &K) -> Option<K> {
        self.above_at(self.root, key)
    }
227
    /// Returns the largest key that is strictly less than key given as the parameter,
    /// or `None` if no stored key is smaller. The given key does not have to be stored.
    pub fn lower(&self, key: &K) -> Option<K> {
        self.below_at(self.root, key)
    }
232
233    /// Returns the smallest key that is greater or equal to key given as the parameter
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use unc_sdk::collections::TreeMap;
239    ///
240    /// let mut map: TreeMap<u32, u32> = TreeMap::new(b"t");
241    /// let vec: Vec<u32> = vec![10, 20, 30, 40, 50];
242    ///
243    /// for x in vec.iter() {
244    ///     map.insert(x, &1);
245    /// }
246    ///
247    /// assert_eq!(map.ceil_key(&5), Some(10));
248    /// assert_eq!(map.ceil_key(&10), Some(10));
249    /// assert_eq!(map.ceil_key(&11), Some(20));
250    /// assert_eq!(map.ceil_key(&20), Some(20));
251    /// assert_eq!(map.ceil_key(&49), Some(50));
252    /// assert_eq!(map.ceil_key(&50), Some(50));
253    /// assert_eq!(map.ceil_key(&51), None);
254    /// ```
255    pub fn ceil_key(&self, key: &K) -> Option<K> {
256        if self.contains_key(key) {
257            Some(key.clone())
258        } else {
259            self.higher(key)
260        }
261    }
262
263    /// Returns the largest key that is less or equal to key given as the parameter
264    ///
265    /// # Examples
266    ///
267    /// ```
268    /// use unc_sdk::collections::TreeMap;
269    ///
270    /// let mut map: TreeMap<u32, u32> = TreeMap::new(b"t");
271    /// let vec: Vec<u32> = vec![10, 20, 30, 40, 50];
272    /// for x in vec.iter() {
273    ///     map.insert(x, &1);
274    /// }
275    ///
276    /// assert_eq!(map.floor_key(&5), None);
277    /// assert_eq!(map.floor_key(&10), Some(10));
278    /// assert_eq!(map.floor_key(&11), Some(10));
279    /// assert_eq!(map.floor_key(&20), Some(20));
280    /// assert_eq!(map.floor_key(&49), Some(40));
281    /// assert_eq!(map.floor_key(&50), Some(50));
282    /// assert_eq!(map.floor_key(&51), Some(50));
283    /// ```
284    pub fn floor_key(&self, key: &K) -> Option<K> {
285        if self.contains_key(key) {
286            Some(key.clone())
287        } else {
288            self.lower(key)
289        }
290    }
291
    /// Iterate all entries in ascending order: min to max, both inclusive
    ///
    /// The iterator starts at [`TreeMap::min`] and yields `(key, value)` pairs.
    pub fn iter(&self) -> impl Iterator<Item = (K, V)> + '_ {
        Cursor::asc(self)
    }
296
    /// Iterate entries in ascending order: given key (exclusive) to max (inclusive)
    ///
    /// The given key does not have to be stored in the map; iteration starts at the first
    /// stored key strictly greater than it.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut map: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// let one: Vec<u32> = vec![10, 20, 30, 40, 50,45, 35, 25, 15, 5];
    /// for x in &one {
    ///     map.insert(x, &42);
    /// }
    /// assert_eq!(
    ///     map.iter_from(29).collect::<Vec<(u32, u32)>>(),
    ///     vec![(30, 42), (35, 42), (40, 42), (45, 42), (50, 42)]
    /// )
    /// ```
    pub fn iter_from(&self, key: K) -> impl Iterator<Item = (K, V)> + '_ {
        Cursor::asc_from(self, key)
    }
317
    /// Iterate all entries in descending order: max to min, both inclusive
    ///
    /// The iterator starts at [`TreeMap::max`] and yields `(key, value)` pairs.
    pub fn iter_rev(&self) -> impl Iterator<Item = (K, V)> + '_ {
        Cursor::desc(self)
    }
322
    /// Iterate entries in descending order: given key (exclusive) to min (inclusive)
    ///
    /// The given key does not have to be stored in the map; iteration starts at the first
    /// stored key strictly less than it.
    ///
    /// # Examples
    ///
    /// ```
    /// use unc_sdk::collections::TreeMap;
    ///
    /// let mut map: TreeMap<u32, u32> = TreeMap::new(b"t");
    /// let one: Vec<u32> = vec![10, 20, 30, 40, 50,45, 35, 25, 15, 5];
    /// for x in &one {
    ///     map.insert(x, &42);
    /// }
    /// assert_eq!(
    ///     map.iter_rev_from(45).collect::<Vec<(u32, u32)>>(),
    ///     vec![(40, 42), (35, 42), (30, 42), (25, 42), (20, 42), (15, 42), (10, 42), (5, 42)]
    /// );
    /// ```
    pub fn iter_rev_from(&self, key: K) -> impl Iterator<Item = (K, V)> + '_ {
        Cursor::desc_from(self, key)
    }
343
344    /// Iterate entries in ascending order according to specified bounds.
345    ///
346    /// # Panics
347    ///
348    /// Panics if range start > end.
349    /// Panics if range start == end and both bounds are Excluded.
350    ///
351    /// # Examples
352    ///
353    /// ```
354    /// use unc_sdk::collections::TreeMap;
355    /// use std::ops::Bound;
356    ///
357    /// let mut map: TreeMap<u32, u32> = TreeMap::new(b"t");
358    /// let one: Vec<u32> = vec![10, 20, 30, 40, 50];
359    /// let two: Vec<u32> = vec![45, 35, 25, 15];
360    /// for x in &one {
361    ///     map.insert(x, &0);
362    /// }
363    /// for x in &two {
364    ///     map.insert(x, &0);
365    /// }
366    /// assert_eq!(
367    ///     map.range((Bound::Included(20), Bound::Excluded(30))).collect::<Vec<(u32, u32)>>(),
368    ///     vec![(20, 0), (25, 0)]
369    /// );
370    /// ```
371    pub fn range(&self, r: (Bound<K>, Bound<K>)) -> impl Iterator<Item = (K, V)> + '_ {
372        let (lo, hi) = match r {
373            (Bound::Included(a), Bound::Included(b)) if a > b => env::panic_str("Invalid range."),
374            (Bound::Excluded(a), Bound::Included(b)) if a > b => env::panic_str("Invalid range."),
375            (Bound::Included(a), Bound::Excluded(b)) if a > b => env::panic_str("Invalid range."),
376            (Bound::Excluded(a), Bound::Excluded(b)) if a >= b => env::panic_str("Invalid range."),
377            (lo, hi) => (lo, hi),
378        };
379
380        Cursor::range(self, lo, hi)
381    }
382
    /// Helper function which creates a [`Vec<(K, V)>`] of all items in the [`TreeMap`].
    /// This function collects elements from [`TreeMap::iter`], so the pairs come out in
    /// ascending key order.
    pub fn to_vec(&self) -> Vec<(K, V)> {
        self.iter().collect()
    }
388
389    //
390    // Internal utilities
391    //
392
393    /// Returns (node, parent node) of left-most lower (min) node starting from given node `at`.
394    /// As min_at only traverses the tree down, if a node `at` is the minimum node in a subtree,
395    /// its parent must be explicitly provided in advance.
396    fn min_at(&self, mut at: u64, p: u64) -> Option<(Node<K>, Node<K>)> {
397        let mut parent: Option<Node<K>> = self.node(p);
398        loop {
399            let node = self.node(at);
400            match node.as_ref().and_then(|n| n.lft) {
401                Some(lft) => {
402                    at = lft;
403                    parent = node;
404                }
405                None => {
406                    return node.and_then(|n| parent.map(|p| (n, p)));
407                }
408            }
409        }
410    }
411
412    /// Returns (node, parent node) of right-most lower (max) node starting from given node `at`.
413    /// As min_at only traverses the tree down, if a node `at` is the minimum node in a subtree,
414    /// its parent must be explicitly provided in advance.
415    fn max_at(&self, mut at: u64, p: u64) -> Option<(Node<K>, Node<K>)> {
416        let mut parent: Option<Node<K>> = self.node(p);
417        loop {
418            let node = self.node(at);
419            match node.as_ref().and_then(|n| n.rgt) {
420                Some(rgt) => {
421                    parent = node;
422                    at = rgt;
423                }
424                None => {
425                    return node.and_then(|n| parent.map(|p| (n, p)));
426                }
427            }
428        }
429    }
430
431    fn above_at(&self, mut at: u64, key: &K) -> Option<K> {
432        let mut seen: Option<K> = None;
433        loop {
434            let node = self.node(at);
435            match node.as_ref().map(|n| &n.key) {
436                Some(k) => {
437                    if k.le(key) {
438                        match node.and_then(|n| n.rgt) {
439                            Some(rgt) => at = rgt,
440                            None => break,
441                        }
442                    } else {
443                        seen = Some(k.clone());
444                        match node.and_then(|n| n.lft) {
445                            Some(lft) => at = lft,
446                            None => break,
447                        }
448                    }
449                }
450                None => break,
451            }
452        }
453        seen
454    }
455
456    fn below_at(&self, mut at: u64, key: &K) -> Option<K> {
457        let mut seen: Option<K> = None;
458        loop {
459            let node = self.node(at);
460            match node.as_ref().map(|n| &n.key) {
461                Some(k) => {
462                    if k.lt(key) {
463                        seen = Some(k.clone());
464                        match node.and_then(|n| n.rgt) {
465                            Some(rgt) => at = rgt,
466                            None => break,
467                        }
468                    } else {
469                        match node.and_then(|n| n.lft) {
470                            Some(lft) => at = lft,
471                            None => break,
472                        }
473                    }
474                }
475                None => break,
476            }
477        }
478        seen
479    }
480
481    fn insert_at(&mut self, at: u64, id: u64, key: &K) -> u64 {
482        match self.node(at) {
483            None => {
484                self.save(&Node::of(id, key.clone()));
485                at
486            }
487            Some(mut node) => {
488                if key.eq(&node.key) {
489                    at
490                } else {
491                    if key.lt(&node.key) {
492                        let idx = match node.lft {
493                            Some(lft) => self.insert_at(lft, id, key),
494                            None => self.insert_at(id, id, key),
495                        };
496                        node.lft = Some(idx);
497                    } else {
498                        let idx = match node.rgt {
499                            Some(rgt) => self.insert_at(rgt, id, key),
500                            None => self.insert_at(id, id, key),
501                        };
502                        node.rgt = Some(idx);
503                    };
504
505                    self.update_height(&mut node);
506                    self.enforce_balance(&mut node)
507                }
508            }
509        }
510    }
511
512    // Calculate and save the height of a subtree at node `at`:
513    // height[at] = 1 + max(height[at.L], height[at.R])
514    fn update_height(&mut self, node: &mut Node<K>) {
515        let lft = node.lft.and_then(|id| self.node(id).map(|n| n.ht)).unwrap_or_default();
516        let rgt = node.rgt.and_then(|id| self.node(id).map(|n| n.ht)).unwrap_or_default();
517
518        node.ht = 1 + std::cmp::max(lft, rgt);
519        self.save(node);
520    }
521
522    // Balance = difference in heights between left and right subtrees at given node.
523    fn get_balance(&self, node: &Node<K>) -> i64 {
524        let lht = node.lft.and_then(|id| self.node(id).map(|n| n.ht)).unwrap_or_default();
525        let rht = node.rgt.and_then(|id| self.node(id).map(|n| n.ht)).unwrap_or_default();
526
527        lht as i64 - rht as i64
528    }
529
530    // Left rotation of an AVL subtree with at node `at`.
531    // New root of subtree is returned, caller is responsible for updating proper link from parent.
532    fn rotate_left(&mut self, node: &mut Node<K>) -> u64 {
533        let mut lft = node.lft.and_then(|id| self.node(id)).unwrap();
534        let lft_rgt = lft.rgt;
535
536        // at.L = at.L.R
537        node.lft = lft_rgt;
538
539        // at.L.R = at
540        lft.rgt = Some(node.id);
541
542        // at = at.L
543        self.update_height(node);
544        self.update_height(&mut lft);
545
546        lft.id
547    }
548
549    // Right rotation of an AVL subtree at node in `at`.
550    // New root of subtree is returned, caller is responsible for updating proper link from parent.
551    fn rotate_right(&mut self, node: &mut Node<K>) -> u64 {
552        let mut rgt = node.rgt.and_then(|id| self.node(id)).unwrap();
553        let rgt_lft = rgt.lft;
554
555        // at.R = at.R.L
556        node.rgt = rgt_lft;
557
558        // at.R.L = at
559        rgt.lft = Some(node.id);
560
561        // at = at.R
562        self.update_height(node);
563        self.update_height(&mut rgt);
564
565        rgt.id
566    }
567
568    // Check balance at a given node and enforce it if necessary with respective rotations.
569    fn enforce_balance(&mut self, node: &mut Node<K>) -> u64 {
570        let balance = self.get_balance(node);
571        if balance > 1 {
572            let mut lft = node.lft.and_then(|id| self.node(id)).unwrap();
573            if self.get_balance(&lft) < 0 {
574                let rotated = self.rotate_right(&mut lft);
575                node.lft = Some(rotated);
576            }
577            self.rotate_left(node)
578        } else if balance < -1 {
579            let mut rgt = node.rgt.and_then(|id| self.node(id)).unwrap();
580            if self.get_balance(&rgt) > 0 {
581                let rotated = self.rotate_left(&mut rgt);
582                node.rgt = Some(rotated);
583            }
584            self.rotate_right(node)
585        } else {
586            node.id
587        }
588    }
589
590    // Returns (node, parent node) for a node that holds the `key`.
591    // For root node, same node is returned for node and parent node.
592    fn lookup_at(&self, mut at: u64, key: &K) -> Option<(Node<K>, Node<K>)> {
593        let mut p: Node<K> = self.node(at).unwrap();
594        while let Some(node) = self.node(at) {
595            if node.key.eq(key) {
596                return Some((node, p));
597            } else if node.key.lt(key) {
598                match node.rgt {
599                    Some(rgt) => {
600                        p = node;
601                        at = rgt;
602                    }
603                    None => break,
604                }
605            } else {
606                match node.lft {
607                    Some(lft) => {
608                        p = node;
609                        at = lft;
610                    }
611                    None => break,
612                }
613            }
614        }
615        None
616    }
617
618    // Navigate from root to node holding `key` and backtrace back to the root
619    // enforcing balance (if necessary) along the way.
620    fn check_balance(&mut self, at: u64, key: &K) -> u64 {
621        match self.node(at) {
622            Some(mut node) => {
623                if !node.key.eq(key) {
624                    if node.key.gt(key) {
625                        if let Some(l) = node.lft {
626                            let id = self.check_balance(l, key);
627                            node.lft = Some(id);
628                        }
629                    } else if let Some(r) = node.rgt {
630                        let id = self.check_balance(r, key);
631                        node.rgt = Some(id);
632                    }
633                }
634                self.update_height(&mut node);
635                self.enforce_balance(&mut node)
636            }
637            None => at,
638        }
639    }
640
    // Removes the node for `key` from the tree structure (the value map is not touched here)
    // and returns the id of the tree's root after rebalancing.
    //
    // Node holding the key is not removed from the tree - instead the substitute node is found,
    // the key is copied to 'removed' node from substitute node, and then substitute node gets
    // removed from the tree.
    //
    // The substitute node is either:
    // - right-most (max) node of the left subtree (containing smaller keys) of node holding `key`
    // - or left-most (min) node of the right subtree (containing larger keys) of node holding `key`
    //
    // The node that is physically dropped is always moved to the end of the node vector first
    // (see `swap_with_last`), so node ids stay dense.
    fn do_remove(&mut self, key: &K) -> u64 {
        // r_node - node containing key of interest
        // p_node - immediate parent node of r_node
        let (mut r_node, mut p_node) = match self.lookup_at(self.root, key) {
            Some(x) => x,
            None => return self.root, // cannot remove a missing key, no changes to the tree needed
        };

        let lft_opt = r_node.lft;
        let rgt_opt = r_node.rgt;

        if lft_opt.is_none() && rgt_opt.is_none() {
            // remove leaf
            // (when r_node is the root, p_node is the same node and clearing its link is harmless)
            if p_node.key.lt(key) {
                p_node.rgt = None;
            } else {
                p_node.lft = None;
            }
            self.update_height(&mut p_node);

            self.swap_with_last(r_node.id);

            // removing node might have caused an imbalance - balance the tree up to the root,
            // starting from lowest affected key - the parent of a leaf node in this case
            self.check_balance(self.root, &p_node.key)
        } else {
            // non-leaf node, select subtree to proceed with
            // (the taller side donates the substitute; ties go to the left subtree)
            let b = self.get_balance(&r_node);
            if b >= 0 {
                // proceed with left subtree
                let lft = lft_opt.unwrap();

                // k - max key from left subtree
                // n - node that holds key k, p - immediate parent of n
                let (n, mut p) = self.max_at(lft, r_node.id).unwrap();
                let k = n.key.clone();

                // unlink n from p; n has no right child, so its left subtree takes its place
                if p.rgt.as_ref().map(|&id| id == n.id).unwrap_or_default() {
                    // n is on right link of p
                    p.rgt = n.lft;
                } else {
                    // n is on left link of p
                    p.lft = n.lft;
                }

                self.update_height(&mut p);

                if r_node.id == p.id {
                    // r_node.id and p.id can overlap on small trees (2 levels, 2-3 nodes)
                    // that leads to nasty lost update of the key, refresh below fixes that
                    r_node = self.node(r_node.id).unwrap();
                }
                r_node.key = k;
                self.save(&r_node);

                self.swap_with_last(n.id);

                // removing node might have caused an imbalance - balance the tree up to the root,
                // starting from the lowest affected key (max key from left subtree in this case)
                self.check_balance(self.root, &p.key)
            } else {
                // proceed with right subtree
                let rgt = rgt_opt.unwrap();

                // k - min key from right subtree
                // n - node that holds key k, p - immediate parent of n
                let (n, mut p) = self.min_at(rgt, r_node.id).unwrap();
                let k = n.key.clone();

                // unlink n from p; n has no left child, so its right subtree takes its place
                if p.lft.map(|id| id == n.id).unwrap_or_default() {
                    // n is on left link of p
                    p.lft = n.rgt;
                } else {
                    // n is on right link of p
                    p.rgt = n.rgt;
                }

                self.update_height(&mut p);

                if r_node.id == p.id {
                    // r_node.id and p.id can overlap on small trees (2 levels, 2-3 nodes)
                    // that leads to nasty lost update of the key, refresh below fixes that
                    r_node = self.node(r_node.id).unwrap();
                }
                r_node.key = k;
                self.save(&r_node);

                self.swap_with_last(n.id);

                // removing node might have caused an imbalance - balance the tree up to the root,
                // starting from the lowest affected key (min key from right subtree in this case)
                self.check_balance(self.root, &p.key)
            }
        }
    }
744
745    // Move content of node with id = `len - 1` (parent left or right link, left, right, key, height)
746    // to node with given `id`, and remove node `len - 1` (pop the vector of nodes).
747    // This ensures that among `n` nodes in the tree, max `id` is `n-1`, so when new node is inserted,
748    // it gets an `id` as its position in the vector.
749    fn swap_with_last(&mut self, id: u64) {
750        if id == self.len() - 1 {
751            // noop: id is already last element in the vector
752            self.tree.pop();
753            return;
754        }
755
756        let key = self.node(self.len() - 1).map(|n| n.key).unwrap();
757        let (mut n, mut p) = self.lookup_at(self.root, &key).unwrap();
758
759        if n.id != p.id {
760            if p.lft.map(|id| id == n.id).unwrap_or_default() {
761                p.lft = Some(id);
762            } else {
763                p.rgt = Some(id);
764            }
765            self.save(&p);
766        }
767
768        if self.root == n.id {
769            self.root = id;
770        }
771
772        n.id = id;
773        self.save(&n);
774        self.tree.pop();
775    }
776}
777
778impl<K, V> std::fmt::Debug for TreeMap<K, V>
779where
780    K: std::fmt::Debug + Ord + Clone + BorshSerialize + BorshDeserialize,
781    V: std::fmt::Debug + BorshSerialize + BorshDeserialize,
782{
783    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
784        f.debug_struct("TreeMap").field("root", &self.root).field("tree", &self.tree).finish()
785    }
786}
787
/// Allows `for (key, value) in &tree_map`, visiting entries in ascending key order.
impl<'a, K, V> IntoIterator for &'a TreeMap<K, V>
where
    K: Ord + Clone + BorshSerialize + BorshDeserialize,
    V: BorshSerialize + BorshDeserialize,
{
    // owned (key, value) pairs loaded from storage
    type Item = (K, V);
    type IntoIter = Cursor<'a, K, V>;

    // same cursor as `TreeMap::iter`: ascending, unbounded on both ends
    fn into_iter(self) -> Self::IntoIter {
        Cursor::asc(self)
    }
}
800
801impl<K, V> Iterator for Cursor<'_, K, V>
802where
803    K: Ord + Clone + BorshSerialize + BorshDeserialize,
804    V: BorshSerialize + BorshDeserialize,
805{
806    type Item = (K, V);
807
808    fn next(&mut self) -> Option<Self::Item> {
809        <Self as Iterator>::nth(self, 0)
810    }
811
812    fn size_hint(&self) -> (usize, Option<usize>) {
813        // Constrains max count. Not worth it to cause storage reads to make this more accurate.
814        (0, Some(self.map.len() as usize))
815    }
816
817    fn count(mut self) -> usize {
818        // Because this Cursor allows for bounded/starting from a key, there is no way of knowing
819        // how many elements are left to iterate without loading keys in order. This could be
820        // optimized in the case of a standard iterator by having a separate type, but this would
821        // be a breaking change, so there will be slightly more reads than necessary in this case.
822        let mut count = 0;
823        while self.key.is_some() {
824            count += 1;
825            self.progress_key();
826        }
827        count
828    }
829
830    fn nth(&mut self, n: usize) -> Option<Self::Item> {
831        for _ in 0..n {
832            // Skip over elements not iterated over to get to `nth`. This avoids loading values
833            // from storage.
834            self.progress_key();
835        }
836
837        let key = self.progress_key()?;
838        let value = self.map.get(&key)?;
839
840        Some((key, value))
841    }
842
843    fn last(mut self) -> Option<Self::Item> {
844        if self.asc && matches!(self.hi, Bound::Unbounded) {
845            self.map.max().and_then(|k| self.map.get(&k).map(|v| (k, v)))
846        } else if !self.asc && matches!(self.lo, Bound::Unbounded) {
847            self.map.min().and_then(|k| self.map.get(&k).map(|v| (k, v)))
848        } else {
849            // Cannot guarantee what the last is within the range, must load keys until last.
850            let key = core::iter::from_fn(|| self.progress_key()).last();
851            key.and_then(|k| self.map.get(&k).map(|v| (k, v)))
852        }
853    }
854}
855
// Marker impl: once the cursor's current key is `None`, `progress_key` leaves it `None`,
// so `next` keeps returning `None` after the first `None`.
impl<K, V> std::iter::FusedIterator for Cursor<'_, K, V>
where
    K: Ord + Clone + BorshSerialize + BorshDeserialize,
    V: BorshSerialize + BorshDeserialize,
{
}
862
/// Tells whether `key` lies within the interval described by the lower bound `lo` and the
/// upper bound `hi`. An `Unbounded` end accepts every key on that side.
fn fits<K: Ord>(key: &K, lo: &Bound<K>, hi: &Bound<K>) -> bool {
    let above_lo = match lo {
        Bound::Unbounded => true,
        Bound::Included(start) => key >= start,
        Bound::Excluded(start) => key > start,
    };
    if !above_lo {
        return false;
    }

    match hi {
        Bound::Unbounded => true,
        Bound::Included(end) => key <= end,
        Bound::Excluded(end) => key < end,
    }
}
874
/// Iterator state over a [`TreeMap`]: walks the keys in one direction, optionally limited
/// to a range of keys.
pub struct Cursor<'a, K, V> {
    // direction: `true` walks towards higher keys, `false` towards lower keys
    asc: bool,
    // lower limit every yielded key must satisfy
    lo: Bound<K>,
    // upper limit every yielded key must satisfy
    hi: Bound<K>,
    // key that will be yielded next; `None` once the cursor is exhausted
    key: Option<K>,
    // the map being iterated
    map: &'a TreeMap<K, V>,
}
882
883impl<'a, K, V> Cursor<'a, K, V>
884where
885    K: Ord + Clone + BorshSerialize + BorshDeserialize,
886    V: BorshSerialize + BorshDeserialize,
887{
888    fn asc(map: &'a TreeMap<K, V>) -> Self {
889        let key: Option<K> = map.min();
890        Self { asc: true, key, lo: Bound::Unbounded, hi: Bound::Unbounded, map }
891    }
892
893    fn asc_from(map: &'a TreeMap<K, V>, key: K) -> Self {
894        let key = map.higher(&key);
895        Self { asc: true, key, lo: Bound::Unbounded, hi: Bound::Unbounded, map }
896    }
897
898    fn desc(map: &'a TreeMap<K, V>) -> Self {
899        let key: Option<K> = map.max();
900        Self { asc: false, key, lo: Bound::Unbounded, hi: Bound::Unbounded, map }
901    }
902
903    fn desc_from(map: &'a TreeMap<K, V>, key: K) -> Self {
904        let key = map.lower(&key);
905        Self { asc: false, key, lo: Bound::Unbounded, hi: Bound::Unbounded, map }
906    }
907
908    fn range(map: &'a TreeMap<K, V>, lo: Bound<K>, hi: Bound<K>) -> Self {
909        let key = match &lo {
910            Bound::Included(k) if map.contains_key(k) => Some(k.clone()),
911            Bound::Included(k) | Bound::Excluded(k) => map.higher(k),
912            _ => None,
913        };
914        let key = key.filter(|k| fits(k, &lo, &hi));
915
916        Self { asc: true, key, lo, hi, map }
917    }
918
919    /// Progresses the key one index, will return the previous key
920    fn progress_key(&mut self) -> Option<K> {
921        let new_key = self
922            .key
923            .as_ref()
924            .and_then(|k| if self.asc { self.map.higher(k) } else { self.map.lower(k) })
925            .filter(|k| fits(k, &self.lo, &self.hi));
926        core::mem::replace(&mut self.key, new_key)
927    }
928}
929
930#[cfg(not(target_arch = "wasm32"))]
931#[cfg(test)]
932mod tests {
933    use super::*;
934    use crate::test_utils::{next_trie_id, test_env};
935
936    extern crate rand;
937    use self::rand::RngCore;
938    use quickcheck::QuickCheck;
939    use std::collections::BTreeMap;
940    use std::collections::HashSet;
941
942    /// Return height of the tree - number of nodes on the longest path starting from the root node.
943    fn height<K, V>(tree: &TreeMap<K, V>) -> u64
944    where
945        K: Ord + Clone + BorshSerialize + BorshDeserialize,
946        V: BorshSerialize + BorshDeserialize,
947    {
948        tree.node(tree.root).map(|n| n.ht).unwrap_or_default()
949    }
950
951    fn random(n: u64) -> Vec<u32> {
952        let mut rng = rand::thread_rng();
953        let mut vec = Vec::with_capacity(n as usize);
954        (0..n).for_each(|_| {
955            vec.push(rng.next_u32() % 1000);
956        });
957        vec
958    }
959
960    fn log2(x: f64) -> f64 {
961        std::primitive::f64::log(x, 2.0f64)
962    }
963
964    fn max_tree_height(n: u64) -> u64 {
965        // h <= C * log2(n + D) + B
966        // where:
967        // C =~ 1.440, D =~ 1.065, B =~ 0.328
968        // (source: https://en.wikipedia.org/wiki/AVL_tree)
969        const B: f64 = -0.328;
970        const C: f64 = 1.440;
971        const D: f64 = 1.065;
972
973        let h = C * log2(n as f64 + D) + B;
974        h.ceil() as u64
975    }
976
977    #[test]
978    fn test_empty() {
979        let map: TreeMap<u8, u8> = TreeMap::new(b't');
980        assert_eq!(map.len(), 0);
981        assert_eq!(height(&map), 0);
982        assert_eq!(map.get(&42), None);
983        assert!(!map.contains_key(&42));
984        assert_eq!(map.min(), None);
985        assert_eq!(map.max(), None);
986        assert_eq!(map.lower(&42), None);
987        assert_eq!(map.higher(&42), None);
988    }
989
990    #[test]
991    fn test_insert_3_rotate_l_l() {
992        let mut map: TreeMap<i32, i32> = TreeMap::new(next_trie_id());
993        assert_eq!(height(&map), 0);
994
995        map.insert(&3, &3);
996        assert_eq!(height(&map), 1);
997
998        map.insert(&2, &2);
999        assert_eq!(height(&map), 2);
1000
1001        map.insert(&1, &1);
1002        assert_eq!(height(&map), 2);
1003
1004        let root = map.root;
1005        assert_eq!(root, 1);
1006        assert_eq!(map.node(root).map(|n| n.key), Some(2));
1007
1008        map.clear();
1009    }
1010
1011    #[test]
1012    fn test_insert_3_rotate_r_r() {
1013        let mut map: TreeMap<i32, i32> = TreeMap::new(next_trie_id());
1014        assert_eq!(height(&map), 0);
1015
1016        map.insert(&1, &1);
1017        assert_eq!(height(&map), 1);
1018
1019        map.insert(&2, &2);
1020        assert_eq!(height(&map), 2);
1021
1022        map.insert(&3, &3);
1023
1024        let root = map.root;
1025        assert_eq!(root, 1);
1026        assert_eq!(map.node(root).map(|n| n.key), Some(2));
1027        assert_eq!(height(&map), 2);
1028
1029        map.clear();
1030    }
1031
1032    #[test]
1033    fn test_insert_lookup_n_asc() {
1034        let mut map: TreeMap<i32, i32> = TreeMap::new(next_trie_id());
1035
1036        let n: u64 = 30;
1037        let cases = (0..2 * (n as i32)).collect::<Vec<i32>>();
1038
1039        let mut counter = 0;
1040        for k in &cases {
1041            if *k % 2 == 0 {
1042                counter += 1;
1043                map.insert(k, &counter);
1044            }
1045        }
1046
1047        counter = 0;
1048        for k in &cases {
1049            if *k % 2 == 0 {
1050                counter += 1;
1051                assert_eq!(map.get(k), Some(counter));
1052            } else {
1053                assert_eq!(map.get(k), None);
1054            }
1055        }
1056
1057        assert!(height(&map) <= max_tree_height(n));
1058        map.clear();
1059    }
1060
1061    #[test]
1062    pub fn test_insert_one() {
1063        let mut map = TreeMap::new(b"m");
1064        assert_eq!(None, map.insert(&1, &2));
1065        assert_eq!(2, map.insert(&1, &3).unwrap());
1066    }
1067
1068    #[test]
1069    fn test_insert_lookup_n_desc() {
1070        let mut map: TreeMap<i32, i32> = TreeMap::new(next_trie_id());
1071
1072        let n: u64 = 30;
1073        let cases = (0..2 * (n as i32)).rev().collect::<Vec<i32>>();
1074
1075        let mut counter = 0;
1076        for k in &cases {
1077            if *k % 2 == 0 {
1078                counter += 1;
1079                map.insert(k, &counter);
1080            }
1081        }
1082
1083        counter = 0;
1084        for k in &cases {
1085            if *k % 2 == 0 {
1086                counter += 1;
1087                assert_eq!(map.get(k), Some(counter));
1088            } else {
1089                assert_eq!(map.get(k), None);
1090            }
1091        }
1092
1093        assert!(height(&map) <= max_tree_height(n));
1094        map.clear();
1095    }
1096
1097    #[test]
1098    fn insert_n_random() {
1099        test_env::setup_free();
1100
1101        for k in 1..10 {
1102            // tree size is 2^k
1103            let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1104
1105            let n = 1 << k;
1106            let input: Vec<u32> = random(n);
1107
1108            for x in &input {
1109                map.insert(x, &42);
1110            }
1111
1112            for x in &input {
1113                assert_eq!(map.get(x), Some(42));
1114            }
1115
1116            assert!(height(&map) <= max_tree_height(n));
1117            map.clear();
1118        }
1119    }
1120
1121    #[test]
1122    fn test_min() {
1123        let n: u64 = 30;
1124        let vec = random(n);
1125
1126        let mut map: TreeMap<u32, u32> = TreeMap::new(b't');
1127        for x in vec.iter().rev() {
1128            map.insert(x, &1);
1129        }
1130
1131        assert_eq!(map.min().unwrap(), *vec.iter().min().unwrap());
1132        map.clear();
1133    }
1134
1135    #[test]
1136    fn test_max() {
1137        let n: u64 = 30;
1138        let vec = random(n);
1139
1140        let mut map: TreeMap<u32, u32> = TreeMap::new(b't');
1141        for x in vec.iter().rev() {
1142            map.insert(x, &1);
1143        }
1144
1145        assert_eq!(map.max().unwrap(), *vec.iter().max().unwrap());
1146        map.clear();
1147    }
1148
1149    #[test]
1150    fn test_lower() {
1151        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1152        let vec = [10, 20, 30, 40, 50];
1153
1154        for x in vec.iter() {
1155            map.insert(x, &1);
1156        }
1157
1158        assert_eq!(map.lower(&5), None);
1159        assert_eq!(map.lower(&10), None);
1160        assert_eq!(map.lower(&11), Some(10));
1161        assert_eq!(map.lower(&20), Some(10));
1162        assert_eq!(map.lower(&49), Some(40));
1163        assert_eq!(map.lower(&50), Some(40));
1164        assert_eq!(map.lower(&51), Some(50));
1165
1166        map.clear();
1167    }
1168
1169    #[test]
1170    fn test_higher() {
1171        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1172        let vec = [10, 20, 30, 40, 50];
1173
1174        for x in vec.iter() {
1175            map.insert(x, &1);
1176        }
1177
1178        assert_eq!(map.higher(&5), Some(10));
1179        assert_eq!(map.higher(&10), Some(20));
1180        assert_eq!(map.higher(&11), Some(20));
1181        assert_eq!(map.higher(&20), Some(30));
1182        assert_eq!(map.higher(&49), Some(50));
1183        assert_eq!(map.higher(&50), None);
1184        assert_eq!(map.higher(&51), None);
1185
1186        map.clear();
1187    }
1188
1189    #[test]
1190    fn test_floor_key() {
1191        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1192        let vec = [10, 20, 30, 40, 50];
1193
1194        for x in vec.iter() {
1195            map.insert(x, &1);
1196        }
1197
1198        assert_eq!(map.floor_key(&5), None);
1199        assert_eq!(map.floor_key(&10), Some(10));
1200        assert_eq!(map.floor_key(&11), Some(10));
1201        assert_eq!(map.floor_key(&20), Some(20));
1202        assert_eq!(map.floor_key(&49), Some(40));
1203        assert_eq!(map.floor_key(&50), Some(50));
1204        assert_eq!(map.floor_key(&51), Some(50));
1205
1206        map.clear();
1207    }
1208
1209    #[test]
1210    fn test_ceil_key() {
1211        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1212        let vec = [10, 20, 30, 40, 50];
1213
1214        for x in vec.iter() {
1215            map.insert(x, &1);
1216        }
1217
1218        assert_eq!(map.ceil_key(&5), Some(10));
1219        assert_eq!(map.ceil_key(&10), Some(10));
1220        assert_eq!(map.ceil_key(&11), Some(20));
1221        assert_eq!(map.ceil_key(&20), Some(20));
1222        assert_eq!(map.ceil_key(&49), Some(50));
1223        assert_eq!(map.ceil_key(&50), Some(50));
1224        assert_eq!(map.ceil_key(&51), None);
1225
1226        map.clear();
1227    }
1228
1229    #[test]
1230    fn test_remove_1() {
1231        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1232        map.insert(&1, &1);
1233        assert_eq!(map.get(&1), Some(1));
1234        map.remove(&1);
1235        assert_eq!(map.get(&1), None);
1236        assert_eq!(map.tree.len(), 0);
1237        map.clear();
1238    }
1239
1240    #[test]
1241    fn test_remove_3() {
1242        let map: TreeMap<u32, u32> = avl(&[(0, 0)], &[0, 0, 1]);
1243
1244        assert_eq!(map.iter().collect::<Vec<(u32, u32)>>(), vec![]);
1245    }
1246
1247    #[test]
1248    fn test_remove_3_desc() {
1249        let vec = [3, 2, 1];
1250        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1251
1252        for x in &vec {
1253            assert_eq!(map.get(x), None);
1254            map.insert(x, &1);
1255            assert_eq!(map.get(x), Some(1));
1256        }
1257
1258        for x in &vec {
1259            assert_eq!(map.get(x), Some(1));
1260            map.remove(x);
1261            assert_eq!(map.get(x), None);
1262        }
1263        map.clear();
1264    }
1265
1266    #[test]
1267    fn test_remove_3_asc() {
1268        let vec = [1, 2, 3];
1269        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1270
1271        for x in &vec {
1272            assert_eq!(map.get(x), None);
1273            map.insert(x, &1);
1274            assert_eq!(map.get(x), Some(1));
1275        }
1276
1277        for x in &vec {
1278            assert_eq!(map.get(x), Some(1));
1279            map.remove(x);
1280            assert_eq!(map.get(x), None);
1281        }
1282        map.clear();
1283    }
1284
1285    #[test]
1286    fn test_remove_7_regression_1() {
1287        let vec =
1288            [2104297040, 552624607, 4269683389, 3382615941, 155419892, 4102023417, 1795725075];
1289        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1290
1291        for x in &vec {
1292            assert_eq!(map.get(x), None);
1293            map.insert(x, &1);
1294            assert_eq!(map.get(x), Some(1));
1295        }
1296
1297        for x in &vec {
1298            assert_eq!(map.get(x), Some(1));
1299            map.remove(x);
1300            assert_eq!(map.get(x), None);
1301        }
1302        map.clear();
1303    }
1304
1305    #[test]
1306    fn test_remove_7_regression_2() {
1307        let vec = [700623085, 87488544, 1500140781, 1111706290, 3187278102, 4042663151, 3731533080];
1308        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1309
1310        for x in &vec {
1311            assert_eq!(map.get(x), None);
1312            map.insert(x, &1);
1313            assert_eq!(map.get(x), Some(1));
1314        }
1315
1316        for x in &vec {
1317            assert_eq!(map.get(x), Some(1));
1318            map.remove(x);
1319            assert_eq!(map.get(x), None);
1320        }
1321        map.clear();
1322    }
1323
1324    #[test]
1325    fn test_remove_9_regression() {
1326        let vec = [
1327            1186903464, 506371929, 1738679820, 1883936615, 1815331350, 1512669683, 3581743264,
1328            1396738166, 1902061760,
1329        ];
1330        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1331
1332        for x in &vec {
1333            assert_eq!(map.get(x), None);
1334            map.insert(x, &1);
1335            assert_eq!(map.get(x), Some(1));
1336        }
1337
1338        for x in &vec {
1339            assert_eq!(map.get(x), Some(1));
1340            map.remove(x);
1341            assert_eq!(map.get(x), None);
1342        }
1343        map.clear();
1344    }
1345
1346    #[test]
1347    fn test_remove_20_regression_1() {
1348        let vec = [
1349            552517392, 3638992158, 1015727752, 2500937532, 638716734, 586360620, 2476692174,
1350            1425948996, 3608478547, 757735878, 2709959928, 2092169539, 3620770200, 783020918,
1351            1986928932, 200210441, 1972255302, 533239929, 497054557, 2137924638,
1352        ];
1353        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1354
1355        for x in &vec {
1356            assert_eq!(map.get(x), None);
1357            map.insert(x, &1);
1358            assert_eq!(map.get(x), Some(1));
1359        }
1360
1361        for x in &vec {
1362            assert_eq!(map.get(x), Some(1));
1363            map.remove(x);
1364            assert_eq!(map.get(x), None);
1365        }
1366        map.clear();
1367    }
1368
1369    #[test]
1370    fn test_remove_7_regression() {
1371        let vec = [280, 606, 163, 857, 436, 508, 44, 801];
1372
1373        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1374
1375        for x in &vec {
1376            assert_eq!(map.get(x), None);
1377            map.insert(x, &1);
1378            assert_eq!(map.get(x), Some(1));
1379        }
1380
1381        for x in &vec {
1382            assert_eq!(map.get(x), Some(1));
1383            map.remove(x);
1384            assert_eq!(map.get(x), None);
1385        }
1386
1387        assert_eq!(map.len(), 0, "map.len() > 0");
1388        assert_eq!(map.tree.len(), 0, "map.tree is not empty");
1389        map.clear();
1390    }
1391
1392    #[test]
1393    fn test_insert_8_remove_4_regression() {
1394        let insert = [882, 398, 161, 76];
1395        let remove = [242, 687, 860, 811];
1396
1397        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1398
1399        for (i, (k1, k2)) in insert.iter().zip(remove.iter()).enumerate() {
1400            let v = i as u32;
1401            map.insert(k1, &v);
1402            map.insert(k2, &v);
1403        }
1404
1405        for k in remove.iter() {
1406            map.remove(k);
1407        }
1408
1409        assert_eq!(map.len(), insert.len() as u64);
1410
1411        for (i, k) in insert.iter().enumerate() {
1412            assert_eq!(map.get(k), Some(i as u32));
1413        }
1414    }
1415
1416    #[test]
1417    fn test_remove_n() {
1418        let n: u64 = 20;
1419        let vec = random(n);
1420
1421        let mut set: HashSet<u32> = HashSet::new();
1422        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1423        for x in &vec {
1424            map.insert(x, &1);
1425            set.insert(*x);
1426        }
1427
1428        assert_eq!(map.len(), set.len() as u64);
1429
1430        for x in &set {
1431            assert_eq!(map.get(x), Some(1));
1432            map.remove(x);
1433            assert_eq!(map.get(x), None);
1434        }
1435
1436        assert_eq!(map.len(), 0, "map.len() > 0");
1437        assert_eq!(map.tree.len(), 0, "map.tree is not empty");
1438        map.clear();
1439    }
1440
1441    #[test]
1442    fn test_remove_root_3() {
1443        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1444        map.insert(&2, &1);
1445        map.insert(&3, &1);
1446        map.insert(&1, &1);
1447        map.insert(&4, &1);
1448
1449        map.remove(&2);
1450
1451        assert_eq!(map.get(&1), Some(1));
1452        assert_eq!(map.get(&2), None);
1453        assert_eq!(map.get(&3), Some(1));
1454        assert_eq!(map.get(&4), Some(1));
1455        map.clear();
1456    }
1457
1458    #[test]
1459    fn test_insert_2_remove_2_regression() {
1460        let ins = [11760225, 611327897];
1461        let rem = [2982517385, 1833990072];
1462
1463        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1464        map.insert(&ins[0], &1);
1465        map.insert(&ins[1], &1);
1466
1467        map.remove(&rem[0]);
1468        map.remove(&rem[1]);
1469
1470        let h = height(&map);
1471        let h_max = max_tree_height(map.len());
1472        assert!(h <= h_max, "h={} h_max={}", h, h_max);
1473        map.clear();
1474    }
1475
1476    #[test]
1477    fn test_insert_n_duplicates() {
1478        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1479
1480        for x in 0..30 {
1481            map.insert(&x, &x);
1482            map.insert(&42, &x);
1483        }
1484
1485        assert_eq!(map.get(&42), Some(29));
1486        assert_eq!(map.len(), 31);
1487        assert_eq!(map.tree.len(), 31);
1488
1489        map.clear();
1490    }
1491
1492    #[test]
1493    fn test_insert_2n_remove_n_random() {
1494        for k in 1..4 {
1495            let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1496            let mut set: HashSet<u32> = HashSet::new();
1497
1498            let n = 1 << k;
1499            let ins: Vec<u32> = random(n);
1500            let rem: Vec<u32> = random(n);
1501
1502            for x in &ins {
1503                set.insert(*x);
1504                map.insert(x, &42);
1505            }
1506
1507            for x in &rem {
1508                set.insert(*x);
1509                map.insert(x, &42);
1510            }
1511
1512            for x in &rem {
1513                set.remove(x);
1514                map.remove(x);
1515            }
1516
1517            assert_eq!(map.len(), set.len() as u64);
1518
1519            let h = height(&map);
1520            let h_max = max_tree_height(n);
1521            assert!(h <= h_max, "[n={}] tree is too high: {} (max is {}).", n, h, h_max);
1522
1523            map.clear();
1524        }
1525    }
1526
1527    #[test]
1528    fn test_remove_empty() {
1529        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1530        assert_eq!(map.remove(&1), None);
1531    }
1532
1533    #[test]
1534    fn test_to_vec() {
1535        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1536        map.insert(&1, &41);
1537        map.insert(&2, &42);
1538        map.insert(&3, &43);
1539
1540        assert_eq!(map.to_vec(), vec![(1, 41), (2, 42), (3, 43)]);
1541        map.clear();
1542    }
1543
1544    #[test]
1545    fn test_to_vec_empty() {
1546        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1547        assert!(map.to_vec().is_empty());
1548    }
1549
1550    #[test]
1551    fn test_iter() {
1552        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1553        map.insert(&1, &41);
1554        map.insert(&2, &42);
1555        map.insert(&3, &43);
1556
1557        assert_eq!(map.iter().collect::<Vec<(u32, u32)>>(), vec![(1, 41), (2, 42), (3, 43)]);
1558
1559        // Test custom iterator impls
1560        assert_eq!(map.iter().nth(1), Some((2, 42)));
1561        assert_eq!(map.iter().count(), 3);
1562        assert_eq!(map.iter().last(), Some((3, 43)));
1563        map.clear();
1564    }
1565
1566    #[test]
1567    fn test_iter_empty() {
1568        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1569        assert_eq!(map.iter().count(), 0);
1570    }
1571
1572    #[test]
1573    fn test_iter_rev() {
1574        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1575        map.insert(&1, &41);
1576        map.insert(&2, &42);
1577        map.insert(&3, &43);
1578
1579        assert_eq!(map.iter_rev().collect::<Vec<(u32, u32)>>(), vec![(3, 43), (2, 42), (1, 41)]);
1580
1581        // Test custom iterator impls
1582        assert_eq!(map.iter_rev().nth(1), Some((2, 42)));
1583        assert_eq!(map.iter_rev().count(), 3);
1584        assert_eq!(map.iter_rev().last(), Some((1, 41)));
1585        map.clear();
1586    }
1587
1588    #[test]
1589    fn test_iter_rev_empty() {
1590        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1591        assert_eq!(map.iter_rev().count(), 0);
1592    }
1593
1594    #[test]
1595    fn test_iter_from() {
1596        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1597
1598        let one = [10, 20, 30, 40, 50];
1599        let two = [45, 35, 25, 15, 5];
1600
1601        for x in &one {
1602            map.insert(x, &42);
1603        }
1604
1605        for x in &two {
1606            map.insert(x, &42);
1607        }
1608
1609        assert_eq!(
1610            map.iter_from(29).collect::<Vec<(u32, u32)>>(),
1611            vec![(30, 42), (35, 42), (40, 42), (45, 42), (50, 42)]
1612        );
1613
1614        assert_eq!(
1615            map.iter_from(30).collect::<Vec<(u32, u32)>>(),
1616            vec![(35, 42), (40, 42), (45, 42), (50, 42)]
1617        );
1618
1619        assert_eq!(
1620            map.iter_from(31).collect::<Vec<(u32, u32)>>(),
1621            vec![(35, 42), (40, 42), (45, 42), (50, 42)]
1622        );
1623
1624        // Test custom iterator impls
1625        assert_eq!(map.iter_from(31).nth(2), Some((45, 42)));
1626        assert_eq!(map.iter_from(31).count(), 4);
1627        assert_eq!(map.iter_from(31).last(), Some((50, 42)));
1628
1629        map.clear();
1630    }
1631
1632    #[test]
1633    fn test_iter_from_empty() {
1634        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1635        assert_eq!(map.iter_from(42).count(), 0);
1636    }
1637
1638    #[test]
1639    fn test_iter_rev_from() {
1640        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1641
1642        let one = [10, 20, 30, 40, 50];
1643        let two = [45, 35, 25, 15, 5];
1644
1645        for x in &one {
1646            map.insert(x, &42);
1647        }
1648
1649        for x in &two {
1650            map.insert(x, &42);
1651        }
1652
1653        assert_eq!(
1654            map.iter_rev_from(29).collect::<Vec<(u32, u32)>>(),
1655            vec![(25, 42), (20, 42), (15, 42), (10, 42), (5, 42)]
1656        );
1657
1658        assert_eq!(
1659            map.iter_rev_from(30).collect::<Vec<(u32, u32)>>(),
1660            vec![(25, 42), (20, 42), (15, 42), (10, 42), (5, 42)]
1661        );
1662
1663        assert_eq!(
1664            map.iter_rev_from(31).collect::<Vec<(u32, u32)>>(),
1665            vec![(30, 42), (25, 42), (20, 42), (15, 42), (10, 42), (5, 42)]
1666        );
1667
1668        // Test custom iterator impls
1669        assert_eq!(map.iter_rev_from(31).nth(2), Some((20, 42)));
1670        assert_eq!(map.iter_rev_from(31).count(), 6);
1671        assert_eq!(map.iter_rev_from(31).last(), Some((5, 42)));
1672
1673        map.clear();
1674    }
1675
1676    #[test]
1677    fn test_range() {
1678        let mut map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1679
1680        let one = [10, 20, 30, 40, 50];
1681        let two = [45, 35, 25, 15, 5];
1682
1683        for x in &one {
1684            map.insert(x, &42);
1685        }
1686
1687        for x in &two {
1688            map.insert(x, &42);
1689        }
1690
1691        assert_eq!(
1692            map.range((Bound::Included(20), Bound::Excluded(30))).collect::<Vec<(u32, u32)>>(),
1693            vec![(20, 42), (25, 42)]
1694        );
1695
1696        assert_eq!(
1697            map.range((Bound::Excluded(10), Bound::Included(40))).collect::<Vec<(u32, u32)>>(),
1698            vec![(15, 42), (20, 42), (25, 42), (30, 42), (35, 42), (40, 42)]
1699        );
1700
1701        assert_eq!(
1702            map.range((Bound::Included(20), Bound::Included(40))).collect::<Vec<(u32, u32)>>(),
1703            vec![(20, 42), (25, 42), (30, 42), (35, 42), (40, 42)]
1704        );
1705
1706        assert_eq!(
1707            map.range((Bound::Excluded(20), Bound::Excluded(45))).collect::<Vec<(u32, u32)>>(),
1708            vec![(25, 42), (30, 42), (35, 42), (40, 42)]
1709        );
1710
1711        assert_eq!(
1712            map.range((Bound::Excluded(25), Bound::Excluded(30))).collect::<Vec<(u32, u32)>>(),
1713            vec![]
1714        );
1715
1716        assert_eq!(
1717            map.range((Bound::Included(25), Bound::Included(25))).collect::<Vec<(u32, u32)>>(),
1718            vec![(25, 42)]
1719        );
1720
1721        assert_eq!(
1722            map.range((Bound::Excluded(25), Bound::Included(25))).collect::<Vec<(u32, u32)>>(),
1723            vec![]
1724        ); // the range makes no sense, but `BTreeMap` does not panic in this case
1725
1726        // Test custom iterator impls
1727        assert_eq!(map.range((Bound::Excluded(20), Bound::Excluded(45))).nth(2), Some((35, 42)));
1728        assert_eq!(map.range((Bound::Excluded(20), Bound::Excluded(45))).count(), 4);
1729        assert_eq!(map.range((Bound::Excluded(20), Bound::Excluded(45))).last(), Some((40, 42)));
1730
1731        map.clear();
1732    }
1733
1734    #[test]
1735    #[should_panic(expected = "Invalid range.")]
1736    fn test_range_panics_same_excluded() {
1737        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1738        let _ = map.range((Bound::Excluded(1), Bound::Excluded(1)));
1739    }
1740
1741    #[test]
1742    #[should_panic(expected = "Invalid range.")]
1743    fn test_range_panics_non_overlap_incl_exlc() {
1744        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1745        let _ = map.range((Bound::Included(2), Bound::Excluded(1)));
1746    }
1747
1748    #[test]
1749    #[should_panic(expected = "Invalid range.")]
1750    fn test_range_panics_non_overlap_excl_incl() {
1751        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1752        let _ = map.range((Bound::Excluded(2), Bound::Included(1)));
1753    }
1754
1755    #[test]
1756    #[should_panic(expected = "Invalid range.")]
1757    fn test_range_panics_non_overlap_incl_incl() {
1758        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1759        let _ = map.range((Bound::Included(2), Bound::Included(1)));
1760    }
1761
1762    #[test]
1763    #[should_panic(expected = "Invalid range.")]
1764    fn test_range_panics_non_overlap_excl_excl() {
1765        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1766        let _ = map.range((Bound::Excluded(2), Bound::Excluded(1)));
1767    }
1768
1769    #[test]
1770    fn test_iter_rev_from_empty() {
1771        let map: TreeMap<u32, u32> = TreeMap::new(next_trie_id());
1772        assert_eq!(map.iter_rev_from(42).count(), 0);
1773    }
1774
1775    #[test]
1776    fn test_balance_regression_1() {
1777        let insert = [(2, 0), (3, 0), (4, 0)];
1778        let remove = [0, 0, 0, 1];
1779
1780        let map = avl(&insert, &remove);
1781        assert!(is_balanced(&map, map.root));
1782    }
1783
1784    #[test]
1785    fn test_balance_regression_2() {
1786        let insert = [(1, 0), (2, 0), (0, 0), (3, 0), (5, 0), (6, 0)];
1787        let remove = [0, 0, 0, 3, 5, 6, 7, 4];
1788
1789        let map = avl(&insert, &remove);
1790        assert!(is_balanced(&map, map.root));
1791    }
1792
1793    //
1794    // Property-based tests of AVL-based TreeMap against std::collections::BTreeMap
1795    //
1796
1797    fn avl<K, V>(insert: &[(K, V)], remove: &[K]) -> TreeMap<K, V>
1798    where
1799        K: Ord + Clone + BorshSerialize + BorshDeserialize,
1800        V: Default + BorshSerialize + BorshDeserialize,
1801    {
1802        test_env::setup_free();
1803        let mut map: TreeMap<K, V> = TreeMap::new(next_trie_id());
1804        for k in remove {
1805            map.insert(k, &Default::default());
1806        }
1807        let n = insert.len().max(remove.len());
1808        for i in 0..n {
1809            if i < remove.len() {
1810                map.remove(&remove[i]);
1811            }
1812            if i < insert.len() {
1813                let (k, v) = &insert[i];
1814                map.insert(k, v);
1815            }
1816        }
1817        map
1818    }
1819
1820    fn rb<K, V>(insert: &[(K, V)], remove: &[K]) -> BTreeMap<K, V>
1821    where
1822        K: Ord + Clone + BorshSerialize + BorshDeserialize,
1823        V: Clone + Default + BorshSerialize + BorshDeserialize,
1824    {
1825        let mut map: BTreeMap<K, V> = BTreeMap::default();
1826        for k in remove {
1827            map.insert(k.clone(), Default::default());
1828        }
1829        let n = insert.len().max(remove.len());
1830        for i in 0..n {
1831            if i < remove.len() {
1832                map.remove(&remove[i]);
1833            }
1834            if i < insert.len() {
1835                let (k, v) = &insert[i];
1836                map.insert(k.clone(), v.clone());
1837            }
1838        }
1839        map
1840    }
1841
1842    #[test]
1843    fn prop_avl_vs_rb() {
1844        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>) -> bool {
1845            let a = avl(&insert, &remove);
1846            let b = rb(&insert, &remove);
1847            let v1: Vec<(u32, u32)> = a.iter().collect();
1848            let v2: Vec<(u32, u32)> = b.into_iter().collect();
1849            v1 == v2
1850        }
1851
1852        QuickCheck::new()
1853            .tests(300)
1854            .quickcheck(prop as fn(std::vec::Vec<(u32, u32)>, std::vec::Vec<u32>) -> bool);
1855    }
1856
1857    fn is_balanced<K, V>(map: &TreeMap<K, V>, root: u64) -> bool
1858    where
1859        K: std::fmt::Debug + Ord + Clone + BorshSerialize + BorshDeserialize,
1860        V: std::fmt::Debug + BorshSerialize + BorshDeserialize,
1861    {
1862        let node = map.node(root).unwrap();
1863        let balance = map.get_balance(&node);
1864
1865        (-1..=1).contains(&balance)
1866            && node.lft.map(|id| is_balanced(map, id)).unwrap_or(true)
1867            && node.rgt.map(|id| is_balanced(map, id)).unwrap_or(true)
1868    }
1869
1870    #[test]
1871    fn prop_avl_balance() {
1872        test_env::setup_free();
1873
1874        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>) -> bool {
1875            let map = avl(&insert, &remove);
1876            map.is_empty() || is_balanced(&map, map.root)
1877        }
1878
1879        QuickCheck::new()
1880            .tests(300)
1881            .quickcheck(prop as fn(std::vec::Vec<(u32, u32)>, std::vec::Vec<u32>) -> bool);
1882    }
1883
1884    #[test]
1885    fn prop_avl_height() {
1886        test_env::setup_free();
1887
1888        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>) -> bool {
1889            let map = avl(&insert, &remove);
1890            height(&map) <= max_tree_height(map.len())
1891        }
1892
1893        QuickCheck::new()
1894            .tests(300)
1895            .quickcheck(prop as fn(std::vec::Vec<(u32, u32)>, std::vec::Vec<u32>) -> bool);
1896    }
1897
1898    fn range_prop(
1899        insert: Vec<(u32, u32)>,
1900        remove: Vec<u32>,
1901        range: (Bound<u32>, Bound<u32>),
1902    ) -> bool {
1903        let a = avl(&insert, &remove);
1904        let b = rb(&insert, &remove);
1905        let v1: Vec<(u32, u32)> = a.range(range).collect();
1906        let v2: Vec<(u32, u32)> = b.range(range).map(|(k, v)| (*k, *v)).collect();
1907        v1 == v2
1908    }
1909
1910    type Prop = fn(std::vec::Vec<(u32, u32)>, std::vec::Vec<u32>, u32, u32) -> bool;
1911
1912    #[test]
1913    fn prop_avl_vs_rb_range_incl_incl() {
1914        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>, r1: u32, r2: u32) -> bool {
1915            let range = (Bound::Included(r1.min(r2)), Bound::Included(r1.max(r2)));
1916            range_prop(insert, remove, range)
1917        }
1918
1919        QuickCheck::new().tests(300).quickcheck(prop as Prop);
1920    }
1921
1922    #[test]
1923    fn prop_avl_vs_rb_range_incl_excl() {
1924        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>, r1: u32, r2: u32) -> bool {
1925            let range = (Bound::Included(r1.min(r2)), Bound::Excluded(r1.max(r2)));
1926            range_prop(insert, remove, range)
1927        }
1928
1929        QuickCheck::new().tests(300).quickcheck(prop as Prop);
1930    }
1931
1932    #[test]
1933    fn prop_avl_vs_rb_range_excl_incl() {
1934        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>, r1: u32, r2: u32) -> bool {
1935            let range = (Bound::Excluded(r1.min(r2)), Bound::Included(r1.max(r2)));
1936            range_prop(insert, remove, range)
1937        }
1938
1939        QuickCheck::new().tests(300).quickcheck(prop as Prop);
1940    }
1941
1942    #[test]
1943    fn prop_avl_vs_rb_range_excl_excl() {
1944        fn prop(insert: Vec<(u32, u32)>, remove: Vec<u32>, r1: u32, r2: u32) -> bool {
1945            // (Excluded(x), Excluded(x)) is invalid range, checking against it makes no sense
1946            r1 == r2 || {
1947                let range = (Bound::Excluded(r1.min(r2)), Bound::Excluded(r1.max(r2)));
1948                range_prop(insert, remove, range)
1949            }
1950        }
1951
1952        QuickCheck::new().tests(300).quickcheck(prop as Prop);
1953    }
1954
1955    #[test]
1956    fn test_debug() {
1957        let mut map = TreeMap::new(b"m");
1958        map.insert(&1, &100);
1959        map.insert(&3, &300);
1960        map.insert(&2, &200);
1961
1962        if cfg!(feature = "expensive-debug") {
1963            let node1 = "Node { id: 0, key: 1, lft: None, rgt: None, ht: 1 }";
1964            let node2 = "Node { id: 2, key: 2, lft: Some(0), rgt: Some(1), ht: 2 }";
1965            let node3 = "Node { id: 1, key: 3, lft: None, rgt: None, ht: 1 }";
1966            assert_eq!(
1967                format!("{:?}", map),
1968                format!("TreeMap {{ root: 2, tree: [{}, {}, {}] }}", node1, node3, node2)
1969            );
1970        } else {
1971            assert_eq!(
1972                format!("{:?}", map),
1973                "TreeMap { root: 2, tree: Vector { len: 3, prefix: [109, 110] } }"
1974            );
1975        }
1976    }
1977}