serum_dex/
critbit.rs

1use crate::{
2    error::{DexErrorCode, DexResult},
3    fees::FeeTier,
4};
5use arrayref::{array_refs, mut_array_refs};
6use bytemuck::{cast, cast_mut, cast_ref, cast_slice, cast_slice_mut, Pod, Zeroable};
7
8use num_enum::{IntoPrimitive, TryFromPrimitive};
9use static_assertions::const_assert_eq;
10use std::{
11    convert::{identity, TryFrom},
12    mem::{align_of, size_of},
13    num::NonZeroU64,
14};
15
/// Index of a node slot inside a [`Slab`]'s node array.
pub type NodeHandle = u32;
17
/// Discriminant stored in the leading `tag: u32` of every node slot.
///
/// The explicit values are written into slab memory and read back with
/// `NodeTag::try_from`, so they must never be renumbered.
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(u32)]
enum NodeTag {
    /// Tag value 0; not written by any code in this file (presumably the
    /// state of zero-filled, never-allocated memory).
    Uninitialized = 0,
    /// Slot holds an `InnerNode`.
    InnerNode = 1,
    /// Slot holds a `LeafNode`.
    LeafNode = 2,
    /// Slot is on the free list and its `next` points to another free slot.
    FreeNode = 3,
    /// Slot is the tail of the free list (it was freed while the list was empty).
    LastFreeNode = 4,
}
27
/// Internal crit-bit tree node that routes a lookup to one of two children.
///
/// `prefix_len` is the number of leading key bits shared by everything below
/// this node; the bit immediately after that prefix selects the child
/// (`children[0]` when clear, `children[1]` when set). `key` is a key from
/// the subtree (`insert_leaf` sets it to the key of the leaf being inserted).
#[derive(Copy, Clone)]
#[repr(packed)]
#[allow(dead_code)]
struct InnerNode {
    tag: u32,
    prefix_len: u32,
    key: u128,
    children: [u32; 2],
    // Pads the node to the shared 72-byte node size (enforced by const asserts).
    _padding: [u64; 5],
}
// SAFETY: the struct is `repr(packed)` (no padding bytes) and every field is a
// plain integer or integer array, so any bit pattern is a valid value.
unsafe impl Zeroable for InnerNode {}
unsafe impl Pod for InnerNode {}
40
41impl InnerNode {
42    fn walk_down(&self, search_key: u128) -> (NodeHandle, bool) {
43        let crit_bit_mask = (1u128 << 127) >> self.prefix_len;
44        let crit_bit = (search_key & crit_bit_mask) != 0;
45        (self.children[crit_bit as usize], crit_bit)
46    }
47}
48
/// Tree leaf describing one resting order.
///
/// The whole struct is `repr(packed)` and exactly 72 bytes so that it can be
/// reinterpreted to and from an `AnyNode` slot.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(packed)]
pub struct LeafNode {
    // Always `NodeTag::LeafNode` for leaves built by `LeafNode::new`.
    tag: u32,
    // NOTE(review): presumably the order's slot index in its owner's
    // open-orders account — confirm against callers.
    owner_slot: u8,
    // `FeeTier` stored as its `u8` discriminant.
    fee_tier: u8,
    padding: [u8; 2],
    // Order id and tree key; the upper 64 bits are the price (see `price()`).
    key: u128,
    // 32 opaque owner bytes — presumably an account public key; verify.
    owner: [u64; 4],
    quantity: u64,
    client_order_id: u64,
}
// SAFETY: `repr(packed)` leaves no padding bytes and every field is a plain
// integer or integer array, so any bit pattern is a valid value.
unsafe impl Zeroable for LeafNode {}
unsafe impl Pod for LeafNode {}
63
64impl LeafNode {
65    #[inline]
66    pub fn new(
67        owner_slot: u8,
68        key: u128,
69        owner: [u64; 4],
70        quantity: u64,
71        fee_tier: FeeTier,
72        client_order_id: u64,
73    ) -> Self {
74        LeafNode {
75            tag: NodeTag::LeafNode.into(),
76            owner_slot,
77            fee_tier: fee_tier.into(),
78            padding: [0; 2],
79            key,
80            owner,
81            quantity,
82            client_order_id,
83        }
84    }
85
86    #[inline]
87    pub fn fee_tier(&self) -> FeeTier {
88        FeeTier::try_from_primitive(self.fee_tier).unwrap()
89    }
90
91    #[inline]
92    pub fn price(&self) -> NonZeroU64 {
93        NonZeroU64::new((self.key >> 64) as u64).unwrap()
94    }
95
96    #[inline]
97    pub fn order_id(&self) -> u128 {
98        self.key
99    }
100
101    #[inline]
102    pub fn quantity(&self) -> u64 {
103        self.quantity
104    }
105
106    #[inline]
107    pub fn set_quantity(&mut self, quantity: u64) {
108        self.quantity = quantity;
109    }
110
111    #[inline]
112    pub fn owner(&self) -> [u64; 4] {
113        self.owner
114    }
115
116    #[inline]
117    pub fn owner_slot(&self) -> u8 {
118        self.owner_slot
119    }
120
121    #[inline]
122    pub fn client_order_id(&self) -> u64 {
123        self.client_order_id
124    }
125}
126
/// Node slot currently on the slab's free list.
///
/// `next` is the handle of the next free slot. The list's tail is tagged
/// `NodeTag::LastFreeNode` instead of `NodeTag::FreeNode`.
#[derive(Copy, Clone)]
#[repr(packed)]
#[allow(dead_code)]
struct FreeNode {
    tag: u32,
    next: u32,
    // Pads the node to the shared 72-byte node size (enforced by const asserts).
    _padding: [u64; 8],
}
// SAFETY: `repr(packed)` leaves no padding bytes and every field is a plain
// integer or integer array, so any bit pattern is a valid value.
unsafe impl Zeroable for FreeNode {}
unsafe impl Pod for FreeNode {}
137
/// Compile-time `max(a, b)`.
///
/// Selects with boolean arithmetic instead of `if`, presumably so it stays a
/// valid `const fn` on compilers older than Rust 1.46 (where `if` was not
/// allowed in `const fn`).
const fn _const_max(a: usize, b: usize) -> usize {
    let a_wins = (a > b) as usize;
    let b_wins = 1 - a_wins;
    a_wins * a + b_wins * b
}
142
// Compile-time layout checks. Every concrete node type must be exactly
// `_NODE_SIZE` bytes with alignment `_NODE_ALIGN` (1, since all are
// `repr(packed)`), so that any slot of the node array can be reinterpreted as
// any node type with bytemuck's `cast_ref` / `cast_mut`.
const _INNER_NODE_SIZE: usize = size_of::<InnerNode>();
const _LEAF_NODE_SIZE: usize = size_of::<LeafNode>();
const _FREE_NODE_SIZE: usize = size_of::<FreeNode>();
const _NODE_SIZE: usize = 72;

const _INNER_NODE_ALIGN: usize = align_of::<InnerNode>();
const _LEAF_NODE_ALIGN: usize = align_of::<LeafNode>();
const _FREE_NODE_ALIGN: usize = align_of::<FreeNode>();
const _NODE_ALIGN: usize = 1;

const_assert_eq!(_NODE_SIZE, _INNER_NODE_SIZE);
const_assert_eq!(_NODE_SIZE, _LEAF_NODE_SIZE);
const_assert_eq!(_NODE_SIZE, _FREE_NODE_SIZE);

const_assert_eq!(_NODE_ALIGN, _INNER_NODE_ALIGN);
const_assert_eq!(_NODE_ALIGN, _LEAF_NODE_ALIGN);
const_assert_eq!(_NODE_ALIGN, _FREE_NODE_ALIGN);
160
/// Untyped 72-byte node slot: a `tag` followed by 68 bytes of payload.
///
/// The `tag` (a `NodeTag` value) says which concrete node type the slot
/// holds. Check it (via `case` / `case_mut`) before reinterpreting the slot.
#[derive(Copy, Clone)]
#[repr(packed)]
#[allow(dead_code)]
pub struct AnyNode {
    tag: u32,
    data: [u32; 17],
}
// SAFETY: `repr(packed)` leaves no padding bytes and both fields are plain
// integers, so any bit pattern is a valid value.
unsafe impl Zeroable for AnyNode {}
unsafe impl Pod for AnyNode {}
170
/// Shared borrow of a live slot, typed according to its tag.
enum NodeRef<'a> {
    Inner(&'a InnerNode),
    Leaf(&'a LeafNode),
}

/// Mutable borrow of a live slot, typed according to its tag.
enum NodeRefMut<'a> {
    Inner(&'a mut InnerNode),
    Leaf(&'a mut LeafNode),
}
180
181impl AnyNode {
182    fn key(&self) -> Option<u128> {
183        match self.case()? {
184            NodeRef::Inner(inner) => Some(inner.key),
185            NodeRef::Leaf(leaf) => Some(leaf.key),
186        }
187    }
188
189    #[cfg(test)]
190    fn prefix_len(&self) -> u32 {
191        match self.case().unwrap() {
192            NodeRef::Inner(&InnerNode { prefix_len, .. }) => prefix_len,
193            NodeRef::Leaf(_) => 128,
194        }
195    }
196
197    fn children(&self) -> Option<[u32; 2]> {
198        match self.case().unwrap() {
199            NodeRef::Inner(&InnerNode { children, .. }) => Some(children),
200            NodeRef::Leaf(_) => None,
201        }
202    }
203
204    fn case(&self) -> Option<NodeRef> {
205        match NodeTag::try_from(self.tag) {
206            Ok(NodeTag::InnerNode) => Some(NodeRef::Inner(cast_ref(self))),
207            Ok(NodeTag::LeafNode) => Some(NodeRef::Leaf(cast_ref(self))),
208            _ => None,
209        }
210    }
211
212    fn case_mut(&mut self) -> Option<NodeRefMut> {
213        match NodeTag::try_from(self.tag) {
214            Ok(NodeTag::InnerNode) => Some(NodeRefMut::Inner(cast_mut(self))),
215            Ok(NodeTag::LeafNode) => Some(NodeRefMut::Leaf(cast_mut(self))),
216            _ => None,
217        }
218    }
219
220    #[inline]
221    pub fn as_leaf(&self) -> Option<&LeafNode> {
222        match self.case() {
223            Some(NodeRef::Leaf(leaf_ref)) => Some(leaf_ref),
224            _ => None,
225        }
226    }
227
228    #[inline]
229    pub fn as_leaf_mut(&mut self) -> Option<&mut LeafNode> {
230        match self.case_mut() {
231            Some(NodeRefMut::Leaf(leaf_ref)) => Some(leaf_ref),
232            _ => None,
233        }
234    }
235}
236
237impl AsRef<AnyNode> for InnerNode {
238    fn as_ref(&self) -> &AnyNode {
239        cast_ref(self)
240    }
241}
242
243impl AsRef<AnyNode> for LeafNode {
244    #[inline]
245    fn as_ref(&self) -> &AnyNode {
246        cast_ref(self)
247    }
248}
249
// `AnyNode` must share the concrete node types' size and alignment so that
// the `cast_ref` / `cast_mut` conversions between them are valid.
const_assert_eq!(_NODE_SIZE, size_of::<AnyNode>());
const_assert_eq!(_NODE_ALIGN, align_of::<AnyNode>());
252
/// Bookkeeping stored at the start of a slab's byte buffer, immediately
/// followed by the node array.
#[derive(Copy, Clone)]
#[repr(packed)]
struct SlabHeader {
    // Number of slots handed out from the front of the node array so far;
    // slots at or beyond this index have never been allocated.
    bump_index: u64,
    // Number of slots currently on the free list.
    free_list_len: u64,
    // Handle of the most recently freed slot (meaningful only when
    // `free_list_len > 0`).
    free_list_head: u32,

    // Handle of the tree root (meaningful only when `leaf_count > 0`).
    root_node: u32,
    // Number of leaves currently in the tree.
    leaf_count: u64,
}
// SAFETY: `repr(packed)` leaves no padding bytes and every field is a plain
// integer, so any bit pattern is a valid value.
unsafe impl Zeroable for SlabHeader {}
unsafe impl Pod for SlabHeader {}
265
266const SLAB_HEADER_LEN: usize = size_of::<SlabHeader>();
267
/// Debug build: panics when `violated` is `true`, so a broken layout
/// assumption is caught loudly.
///
/// # Safety
/// Only pass `true` when an invariant really is broken. The release build
/// treats that as undefined behaviour.
#[cfg(debug_assertions)]
unsafe fn invariant(violated: bool) {
    if !violated {
        return;
    }
    unreachable!();
}

/// Release build: promises the optimizer that `violated` is always `false`.
///
/// # Safety
/// Passing `true` is undefined behaviour (`unreachable_unchecked`).
#[cfg(not(debug_assertions))]
#[inline(always)]
unsafe fn invariant(violated: bool) {
    if !violated {
        return;
    }
    std::hint::unreachable_unchecked();
}
282
/// A byte buffer viewed as a `SlabHeader` followed by an array of `AnyNode`
/// slots that hold a crit-bit tree of `LeafNode`s.
///
/// `#[repr(transparent)]` is what lets `Slab::new` reinterpret a `&mut [u8]`
/// directly as a `&mut Slab`.
#[repr(transparent)]
pub struct Slab([u8]);
285
286impl Slab {
287    /// Creates a slab that holds and references the bytes
288    ///
289    /// ```compile_fail
290    /// let slab = {
291    ///     let mut bytes = [10; 100];
292    ///     serum_dex::critbit::Slab::new(&mut bytes)
293    /// };
294    /// ```
295    #[inline]
296    pub fn new(bytes: &mut [u8]) -> &mut Self {
297        let len_without_header = bytes.len().checked_sub(SLAB_HEADER_LEN).unwrap();
298        let slop = len_without_header % size_of::<AnyNode>();
299        let truncated_len = bytes.len() - slop;
300        let bytes = &mut bytes[..truncated_len];
301        let slab: &mut Self = unsafe { &mut *(bytes as *mut [u8] as *mut Slab) };
302        slab.check_size_align(); // check alignment
303        slab
304    }
305
306    #[inline]
307    pub fn assert_minimum_capacity(&self, capacity: u32) -> DexResult {
308        if self.nodes().len() <= (capacity as usize) * 2 {
309            Err(DexErrorCode::SlabTooSmall)?
310        }
311        Ok(())
312    }
313
314    fn check_size_align(&self) {
315        let (header_bytes, nodes_bytes) = array_refs![&self.0, SLAB_HEADER_LEN; .. ;];
316        let _header: &SlabHeader = cast_ref(header_bytes);
317        let _nodes: &[AnyNode] = cast_slice(nodes_bytes);
318    }
319
320    fn parts(&self) -> (&SlabHeader, &[AnyNode]) {
321        unsafe {
322            invariant(self.0.len() < size_of::<SlabHeader>());
323            invariant((self.0.as_ptr() as usize) % align_of::<SlabHeader>() != 0);
324            invariant(
325                ((self.0.as_ptr() as usize) + size_of::<SlabHeader>()) % align_of::<AnyNode>() != 0,
326            );
327        }
328
329        let (header_bytes, nodes_bytes) = array_refs![&self.0, SLAB_HEADER_LEN; .. ;];
330        let header = cast_ref(header_bytes);
331        let nodes = cast_slice(nodes_bytes);
332        (header, nodes)
333    }
334
335    fn parts_mut(&mut self) -> (&mut SlabHeader, &mut [AnyNode]) {
336        unsafe {
337            invariant(self.0.len() < size_of::<SlabHeader>());
338            invariant((self.0.as_ptr() as usize) % align_of::<SlabHeader>() != 0);
339            invariant(
340                ((self.0.as_ptr() as usize) + size_of::<SlabHeader>()) % align_of::<AnyNode>() != 0,
341            );
342        }
343
344        let (header_bytes, nodes_bytes) = mut_array_refs![&mut self.0, SLAB_HEADER_LEN; .. ;];
345        let header = cast_mut(header_bytes);
346        let nodes = cast_slice_mut(nodes_bytes);
347        (header, nodes)
348    }
349
350    fn header(&self) -> &SlabHeader {
351        self.parts().0
352    }
353
354    fn header_mut(&mut self) -> &mut SlabHeader {
355        self.parts_mut().0
356    }
357
358    fn nodes(&self) -> &[AnyNode] {
359        self.parts().1
360    }
361
362    fn nodes_mut(&mut self) -> &mut [AnyNode] {
363        self.parts_mut().1
364    }
365}
366
/// Node-slot allocator interface: store, look up and free individual nodes by
/// handle, independently of any tree built on top of them.
pub trait SlabView<T> {
    /// Total number of node slots, used or not.
    fn capacity(&self) -> u64;
    /// Discards all nodes so that every slot becomes available again.
    fn clear(&mut self);
    /// `true` when no slot is currently allocated.
    fn is_empty(&self) -> bool;
    /// The live node at `h`, or `None` if `h` is out of range or not live.
    fn get(&self, h: NodeHandle) -> Option<&T>;
    /// Mutable counterpart of `get`.
    fn get_mut(&mut self, h: NodeHandle) -> Option<&mut T>;
    /// Copies `val` into a free slot and returns its handle, or `Err(())`
    /// when no slot is available.
    fn insert(&mut self, val: &T) -> Result<u32, ()>;
    /// Frees the slot at `h` and returns what it held, or `None` if `h` was
    /// not live.
    fn remove(&mut self, h: NodeHandle) -> Option<T>;
    /// Whether `h` refers to a live node.
    fn contains(&self, h: NodeHandle) -> bool;
}
377
378impl SlabView<AnyNode> for Slab {
379    fn capacity(&self) -> u64 {
380        self.nodes().len() as u64
381    }
382
383    fn clear(&mut self) {
384        let (header, _nodes) = self.parts_mut();
385        *header = SlabHeader {
386            bump_index: 0,
387            free_list_len: 0,
388            free_list_head: 0,
389
390            root_node: 0,
391            leaf_count: 0,
392        }
393    }
394
395    fn is_empty(&self) -> bool {
396        let SlabHeader {
397            bump_index,
398            free_list_len,
399            ..
400        } = *self.header();
401        bump_index == free_list_len
402    }
403
404    fn get(&self, key: u32) -> Option<&AnyNode> {
405        let node = self.nodes().get(key as usize)?;
406        let tag = NodeTag::try_from(node.tag);
407        match tag {
408            Ok(NodeTag::InnerNode) | Ok(NodeTag::LeafNode) => Some(node),
409            _ => None,
410        }
411    }
412
413    fn get_mut(&mut self, key: u32) -> Option<&mut AnyNode> {
414        let node = self.nodes_mut().get_mut(key as usize)?;
415        let tag = NodeTag::try_from(node.tag);
416        match tag {
417            Ok(NodeTag::InnerNode) | Ok(NodeTag::LeafNode) => Some(node),
418            _ => None,
419        }
420    }
421
422    fn insert(&mut self, val: &AnyNode) -> Result<u32, ()> {
423        match NodeTag::try_from(identity(val.tag)) {
424            Ok(NodeTag::InnerNode) | Ok(NodeTag::LeafNode) => (),
425            _ => unreachable!(),
426        };
427
428        let (header, nodes) = self.parts_mut();
429
430        if header.free_list_len == 0 {
431            if header.bump_index as usize == nodes.len() {
432                return Err(());
433            }
434
435            if header.bump_index == std::u32::MAX as u64 {
436                return Err(());
437            }
438            let key = header.bump_index as u32;
439            header.bump_index += 1;
440
441            nodes[key as usize] = *val;
442            return Ok(key);
443        }
444
445        let key = header.free_list_head;
446        let node = &mut nodes[key as usize];
447
448        match NodeTag::try_from(node.tag) {
449            Ok(NodeTag::FreeNode) => assert!(header.free_list_len > 1),
450            Ok(NodeTag::LastFreeNode) => assert_eq!(identity(header.free_list_len), 1),
451            _ => unreachable!(),
452        };
453
454        let next_free_list_head: u32;
455        {
456            let free_list_item: &FreeNode = cast_ref(node);
457            next_free_list_head = free_list_item.next;
458        }
459        header.free_list_head = next_free_list_head;
460        header.free_list_len -= 1;
461        *node = *val;
462        Ok(key)
463    }
464
465    fn remove(&mut self, key: u32) -> Option<AnyNode> {
466        let val = *self.get(key)?;
467        let (header, nodes) = self.parts_mut();
468        let any_node_ref = &mut nodes[key as usize];
469        let free_node_ref: &mut FreeNode = cast_mut(any_node_ref);
470        *free_node_ref = FreeNode {
471            tag: if header.free_list_len == 0 {
472                NodeTag::LastFreeNode.into()
473            } else {
474                NodeTag::FreeNode.into()
475            },
476            next: header.free_list_head,
477            _padding: Zeroable::zeroed(),
478        };
479        header.free_list_len += 1;
480        header.free_list_head = key;
481        Some(val)
482    }
483
484    fn contains(&self, key: u32) -> bool {
485        self.get(key).is_some()
486    }
487}
488
/// Errors returned by the tree-level operations on [`Slab`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlabTreeError {
    /// No node slot was available for the insertion.
    OutOfSpace,
}
493
494impl Slab {
495    fn root(&self) -> Option<NodeHandle> {
496        if self.header().leaf_count == 0 {
497            return None;
498        }
499
500        Some(self.header().root_node)
501    }
502
503    fn find_min_max(&self, find_max: bool) -> Option<NodeHandle> {
504        let mut root: NodeHandle = self.root()?;
505        loop {
506            let root_contents = self.get(root).unwrap();
507            match root_contents.case().unwrap() {
508                NodeRef::Inner(&InnerNode { children, .. }) => {
509                    root = children[if find_max { 1 } else { 0 }];
510                    continue;
511                }
512                _ => return Some(root),
513            }
514        }
515    }
516
517    #[inline]
518    pub fn find_min(&self) -> Option<NodeHandle> {
519        self.find_min_max(false)
520    }
521
522    #[inline]
523    pub fn find_max(&self) -> Option<NodeHandle> {
524        self.find_min_max(true)
525    }
526
527    #[inline]
528    pub fn insert_leaf(
529        &mut self,
530        new_leaf: &LeafNode,
531    ) -> Result<(NodeHandle, Option<LeafNode>), SlabTreeError> {
532        let mut root: NodeHandle = match self.root() {
533            Some(h) => h,
534            None => {
535                // create a new root if none exists
536                match self.insert(new_leaf.as_ref()) {
537                    Ok(handle) => {
538                        self.header_mut().root_node = handle;
539                        self.header_mut().leaf_count = 1;
540                        return Ok((handle, None));
541                    }
542                    Err(()) => return Err(SlabTreeError::OutOfSpace),
543                }
544            }
545        };
546        loop {
547            // check if the new node will be a child of the root
548            let root_contents = *self.get(root).unwrap();
549            let root_key = root_contents.key().unwrap();
550            if root_key == new_leaf.key {
551                if let Some(NodeRef::Leaf(&old_root_as_leaf)) = root_contents.case() {
552                    // clobber the existing leaf
553                    *self.get_mut(root).unwrap() = *new_leaf.as_ref();
554                    return Ok((root, Some(old_root_as_leaf)));
555                }
556            }
557            let shared_prefix_len: u32 = (root_key ^ new_leaf.key).leading_zeros();
558            match root_contents.case() {
559                None => unreachable!(),
560                Some(NodeRef::Inner(inner)) => {
561                    let keep_old_root = shared_prefix_len >= inner.prefix_len;
562                    if keep_old_root {
563                        root = inner.walk_down(new_leaf.key).0;
564                        continue;
565                    };
566                }
567                _ => (),
568            };
569
570            // change the root in place to represent the LCA of [new_leaf] and [root]
571            let crit_bit_mask: u128 = (1u128 << 127) >> shared_prefix_len;
572            let new_leaf_crit_bit = (crit_bit_mask & new_leaf.key) != 0;
573            let old_root_crit_bit = !new_leaf_crit_bit;
574
575            let new_leaf_handle = self
576                .insert(new_leaf.as_ref())
577                .map_err(|()| SlabTreeError::OutOfSpace)?;
578            let moved_root_handle = match self.insert(&root_contents) {
579                Ok(h) => h,
580                Err(()) => {
581                    self.remove(new_leaf_handle).unwrap();
582                    return Err(SlabTreeError::OutOfSpace);
583                }
584            };
585
586            let new_root: &mut InnerNode = cast_mut(self.get_mut(root).unwrap());
587            *new_root = InnerNode {
588                tag: NodeTag::InnerNode.into(),
589                prefix_len: shared_prefix_len,
590                key: new_leaf.key,
591                children: [0; 2],
592                _padding: Zeroable::zeroed(),
593            };
594
595            new_root.children[new_leaf_crit_bit as usize] = new_leaf_handle;
596            new_root.children[old_root_crit_bit as usize] = moved_root_handle;
597            self.header_mut().leaf_count += 1;
598            return Ok((new_leaf_handle, None));
599        }
600    }
601
602    #[cfg(test)]
603    fn find_by_key(&self, search_key: u128) -> Option<NodeHandle> {
604        let mut node_handle: NodeHandle = self.root()?;
605        loop {
606            let node_ref = self.get(node_handle).unwrap();
607            let node_prefix_len = node_ref.prefix_len();
608            let node_key = node_ref.key().unwrap();
609            let common_prefix_len = (search_key ^ node_key).leading_zeros();
610            if common_prefix_len < node_prefix_len {
611                return None;
612            }
613            match node_ref.case().unwrap() {
614                NodeRef::Leaf(_) => break Some(node_handle),
615                NodeRef::Inner(inner) => {
616                    let crit_bit_mask = (1u128 << 127) >> node_prefix_len;
617                    let _search_key_crit_bit = (search_key & crit_bit_mask) != 0;
618                    node_handle = inner.walk_down(search_key).0;
619                    continue;
620                }
621            }
622        }
623    }
624
625    pub(crate) fn find_by<F: Fn(&LeafNode) -> bool>(
626        &self,
627        limit: &mut u16,
628        predicate: F,
629    ) -> Vec<u128> {
630        let mut found = Vec::new();
631        let mut nodes_to_search: Vec<NodeHandle> = Vec::new();
632        let mut current_node: Option<&AnyNode>;
633
634        let top_node = self.root();
635
636        // No found nodes.
637        if top_node.is_none() {
638            return found;
639        }
640
641        nodes_to_search.push(top_node.unwrap());
642
643        // Search through the tree.
644        while !nodes_to_search.is_empty() && *limit > 0 {
645            *limit -= 1;
646
647            current_node = self.get(nodes_to_search.pop().unwrap());
648
649            // Node not found.
650            if current_node.is_none() {
651                break;
652            }
653
654            match current_node.unwrap().case().unwrap() {
655                NodeRef::Leaf(leaf) if predicate(leaf) => {
656                    // Found a matching leaf.
657                    found.push(leaf.key)
658                }
659                NodeRef::Inner(inner) => {
660                    // Search the children.
661                    nodes_to_search.push(inner.children[0]);
662                    nodes_to_search.push(inner.children[1]);
663                }
664                _ => (),
665            }
666        }
667
668        found
669    }
670
671    #[inline]
672    pub fn remove_by_key(&mut self, search_key: u128) -> Option<LeafNode> {
673        let mut parent_h = self.root()?;
674        let mut child_h;
675        let mut crit_bit;
676        match self.get(parent_h).unwrap().case().unwrap() {
677            NodeRef::Leaf(&leaf) if leaf.key == search_key => {
678                let header = self.header_mut();
679                assert_eq!(identity(header.leaf_count), 1);
680                header.root_node = 0;
681                header.leaf_count = 0;
682                let _old_root = self.remove(parent_h).unwrap();
683                return Some(leaf);
684            }
685            NodeRef::Leaf(_) => return None,
686            NodeRef::Inner(inner) => {
687                let (ch, cb) = inner.walk_down(search_key);
688                child_h = ch;
689                crit_bit = cb;
690            }
691        }
692        loop {
693            match self.get(child_h).unwrap().case().unwrap() {
694                NodeRef::Inner(inner) => {
695                    let (grandchild_h, grandchild_crit_bit) = inner.walk_down(search_key);
696                    parent_h = child_h;
697                    child_h = grandchild_h;
698                    crit_bit = grandchild_crit_bit;
699                    continue;
700                }
701                NodeRef::Leaf(&leaf) => {
702                    if leaf.key != search_key {
703                        return None;
704                    }
705
706                    break;
707                }
708            }
709        }
710        // replace parent with its remaining child node
711        // free child_h, replace *parent_h with *other_child_h, free other_child_h
712        let other_child_h = self.get(parent_h).unwrap().children().unwrap()[!crit_bit as usize];
713        let other_child_node_contents = self.remove(other_child_h).unwrap();
714        *self.get_mut(parent_h).unwrap() = other_child_node_contents;
715        self.header_mut().leaf_count -= 1;
716        Some(cast(self.remove(child_h).unwrap()))
717    }
718
719    #[inline]
720    pub fn remove_min(&mut self) -> Option<LeafNode> {
721        self.remove_by_key(self.get(self.find_min()?)?.key()?)
722    }
723
724    #[inline]
725    pub fn remove_max(&mut self) -> Option<LeafNode> {
726        self.remove_by_key(self.get(self.find_max()?)?.key()?)
727    }
728
729    #[cfg(test)]
730    fn traverse(&self) -> Vec<&LeafNode> {
731        fn walk_rec<'a>(slab: &'a Slab, sub_root: NodeHandle, buf: &mut Vec<&'a LeafNode>) {
732            match slab.get(sub_root).unwrap().case().unwrap() {
733                NodeRef::Leaf(leaf) => {
734                    buf.push(leaf);
735                }
736                NodeRef::Inner(inner) => {
737                    walk_rec(slab, inner.children[0], buf);
738                    walk_rec(slab, inner.children[1], buf);
739                }
740            }
741        }
742
743        let mut buf = Vec::with_capacity(self.header().leaf_count as usize);
744        if let Some(r) = self.root() {
745            walk_rec(self, r, &mut buf);
746        }
747        if buf.len() != buf.capacity() {
748            self.hexdump();
749        }
750        assert_eq!(buf.len(), buf.capacity());
751        buf
752    }
753
754    #[cfg(test)]
755    fn hexdump(&self) {
756        println!("Header:");
757        hexdump::hexdump(bytemuck::bytes_of(self.header()));
758        println!("Data:");
759        hexdump::hexdump(cast_slice(self.nodes()));
760    }
761
762    #[cfg(test)]
763    fn check_invariants(&self) {
764        // first check the live tree contents
765        let mut count = 0;
766        fn check_rec(
767            slab: &Slab,
768            key: NodeHandle,
769            last_prefix_len: u32,
770            last_prefix: u128,
771            last_crit_bit: bool,
772            count: &mut u64,
773        ) {
774            *count += 1;
775            let node = slab.get(key).unwrap();
776            assert!(node.prefix_len() > last_prefix_len);
777            let node_key = node.key().unwrap();
778            assert_eq!(
779                last_crit_bit,
780                (node_key & ((1u128 << 127) >> last_prefix_len)) != 0
781            );
782            let prefix_mask = (((((1u128) << 127) as i128) >> last_prefix_len) as u128) << 1;
783            assert_eq!(last_prefix & prefix_mask, node.key().unwrap() & prefix_mask);
784            if let Some([c0, c1]) = node.children() {
785                check_rec(slab, c0, node.prefix_len(), node_key, false, count);
786                check_rec(slab, c1, node.prefix_len(), node_key, true, count);
787            }
788        }
789        if let Some(root) = self.root() {
790            count += 1;
791            let node = self.get(root).unwrap();
792            let node_key = node.key().unwrap();
793            if let Some([c0, c1]) = node.children() {
794                check_rec(self, c0, node.prefix_len(), node_key, false, &mut count);
795                check_rec(self, c1, node.prefix_len(), node_key, true, &mut count);
796            }
797        }
798        assert_eq!(
799            count + self.header().free_list_len as u64,
800            identity(self.header().bump_index)
801        );
802
803        let mut free_nodes_remaining = self.header().free_list_len;
804        let mut next_free_node = self.header().free_list_head;
805        loop {
806            let contents;
807            match free_nodes_remaining {
808                0 => break,
809                1 => {
810                    contents = &self.nodes()[next_free_node as usize];
811                    assert_eq!(identity(contents.tag), u32::from(NodeTag::LastFreeNode));
812                }
813                _ => {
814                    contents = &self.nodes()[next_free_node as usize];
815                    assert_eq!(identity(contents.tag), u32::from(NodeTag::FreeNode));
816                }
817            };
818            let typed_ref: &FreeNode = cast_ref(contents);
819            next_free_node = typed_ref.next;
820            free_nodes_remaining -= 1;
821        }
822    }
823}
824
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::bytes_of;
    use rand::prelude::*;

    /// Inserts random leaves. After each insertion, cross-checks
    /// `find_by_key`, `find_min` and `find_max` against a `BTreeMap` model.
    #[test]
    fn simulate_find_min() {
        use std::collections::BTreeMap;

        for trial in 0..10u64 {
            let mut aligned_buf = vec![0u64; 10_000];
            let bytes: &mut [u8] = cast_slice_mut(aligned_buf.as_mut_slice());

            let slab: &mut Slab = Slab::new(bytes);
            let mut model: BTreeMap<u128, LeafNode> = BTreeMap::new();

            let mut all_keys = vec![];

            let mut rng = StdRng::seed_from_u64(trial);

            assert_eq!(slab.find_min(), None);
            assert_eq!(slab.find_max(), None);

            for i in 0..100 {
                let offset = rng.gen();
                let key = rng.gen();
                let owner = rng.gen();
                let qty = rng.gen();
                let leaf = LeafNode::new(offset, key, owner, qty, FeeTier::Base, 0);

                println!("{:x}", key);
                println!("{}", i);

                slab.insert_leaf(&leaf).unwrap();
                // Random 128-bit keys are expected to be fresh.
                assert!(model.insert(key, leaf).is_none());
                all_keys.push(key);

                // test find_by_key
                let valid_search_key = *all_keys.choose(&mut rng).unwrap();
                let invalid_search_key = rng.gen();

                for &search_key in &[valid_search_key, invalid_search_key] {
                    let slab_value = slab
                        .find_by_key(search_key)
                        .and_then(|x| slab.get(x))
                        .map(bytes_of);
                    let model_value = model.get(&search_key).map(bytes_of);
                    assert_eq!(slab_value, model_value);
                }

                // test find_min
                let slab_min = slab.get(slab.find_min().unwrap()).unwrap();
                let model_min = model.iter().next().unwrap().1;
                assert_eq!(bytes_of(slab_min), bytes_of(model_min));

                // test find_max
                let slab_max = slab.get(slab.find_max().unwrap()).unwrap();
                let model_max = model.iter().next_back().unwrap().1;
                assert_eq!(bytes_of(slab_max), bytes_of(model_max));
            }
        }
    }

    /// Randomized differential test of inserts, duplicate inserts, deletes
    /// and min/max queries against a `BTreeMap` model. Invariants and full
    /// traversal are checked before every operation.
    #[test]
    fn simulate_operations() {
        use rand::distributions::WeightedIndex;
        use std::collections::BTreeMap;

        let mut aligned_buf = vec![0u64; 1_250_000];
        let bytes: &mut [u8] = cast_slice_mut(aligned_buf.as_mut_slice());
        let slab: &mut Slab = Slab::new(bytes);
        let mut model: BTreeMap<u128, LeafNode> = BTreeMap::new();

        let mut all_keys = vec![];
        let mut rng = StdRng::seed_from_u64(0);

        #[derive(Copy, Clone)]
        enum Op {
            InsertNew,
            InsertDup,
            Delete,
            Min,
            Max,
            End,
        }

        // First phase grows the tree; the second is delete-heavy and drains it.
        for weights in &[
            [
                (Op::InsertNew, 2000),
                (Op::InsertDup, 200),
                (Op::Delete, 2210),
                (Op::Min, 500),
                (Op::Max, 500),
                (Op::End, 1),
            ],
            [
                (Op::InsertNew, 10),
                (Op::InsertDup, 200),
                (Op::Delete, 5210),
                (Op::Min, 500),
                (Op::Max, 500),
                (Op::End, 1),
            ],
        ] {
            let dist = WeightedIndex::new(weights.iter().map(|(_op, wt)| wt)).unwrap();

            for i in 0..100_000 {
                slab.check_invariants();
                let model_state = model.values().collect::<Vec<_>>();
                let slab_state = slab.traverse();
                assert_eq!(model_state, slab_state);

                match weights[dist.sample(&mut rng)].0 {
                    op @ Op::InsertNew | op @ Op::InsertDup => {
                        let offset = rng.gen();
                        let key = match op {
                            Op::InsertNew => rng.gen(),
                            Op::InsertDup => *all_keys.choose(&mut rng).unwrap(),
                            _ => unreachable!(),
                        };
                        let owner = rng.gen();
                        let qty = rng.gen();
                        let leaf = LeafNode::new(offset, key, owner, qty, FeeTier::SRM5, 5);

                        println!("Insert {:x}", key);

                        all_keys.push(key);
                        let slab_value = slab.insert_leaf(&leaf).unwrap().1;
                        let model_value = model.insert(key, leaf);
                        if slab_value != model_value {
                            slab.hexdump();
                        }
                        assert_eq!(slab_value, model_value);
                    }
                    Op::Delete => {
                        let key = all_keys
                            .choose(&mut rng)
                            .copied()
                            .unwrap_or_else(|| rng.gen());

                        println!("Remove {:x}", key);

                        let slab_value = slab.remove_by_key(key);
                        let model_value = model.remove(&key);
                        assert_eq!(slab_value, model_value);
                    }
                    Op::Min => {
                        if model.is_empty() {
                            assert_eq!(identity(slab.header().leaf_count), 0);
                        } else {
                            let slab_min = slab.get(slab.find_min().unwrap()).unwrap();
                            let model_min = model.iter().next().unwrap().1;
                            assert_eq!(bytes_of(slab_min), bytes_of(model_min));
                        }
                    }
                    Op::Max => {
                        if model.is_empty() {
                            assert_eq!(identity(slab.header().leaf_count), 0);
                        } else {
                            let slab_max = slab.get(slab.find_max().unwrap()).unwrap();
                            let model_max = model.iter().next_back().unwrap().1;
                            assert_eq!(bytes_of(slab_max), bytes_of(model_max));
                        }
                    }
                    Op::End => {
                        if i > 10_000 {
                            break;
                        }
                    }
                }
            }
        }
    }
}