trie_hard/
lib.rs

1// Copyright 2024 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
16#![deny(
17    missing_docs,
18    missing_debug_implementations,
19    unreachable_pub,
20    rustdoc::broken_intra_doc_links,
21    unsafe_code
22)]
23#![warn(rust_2018_idioms)]
24
25mod u256;
26
27use std::{
28    collections::{BTreeMap, BTreeSet, VecDeque},
29    ops::RangeFrom,
30};
31
32use u256::U256;
33
/// Lookup table that maps every possible byte value to its mask bit.
///
/// Entries for bytes that were never assigned a bit keep `I::default()`.
#[repr(transparent)]
#[derive(Clone, Debug)]
struct MasksByByteSized<I>([I; 256]);

impl<I> Default for MasksByByteSized<I>
where
    I: Default + Copy,
{
    /// A table in which every one of the 256 entries is `I::default()`.
    fn default() -> Self {
        let filler = I::default();
        MasksByByteSized([filler; 256])
    }
}
46
47#[allow(clippy::large_enum_variant)]
48enum MasksByByte {
49    U8(MasksByByteSized<u8>),
50    U16(MasksByByteSized<u16>),
51    U32(MasksByByteSized<u32>),
52    U64(MasksByByteSized<u64>),
53    U128(MasksByByteSized<u128>),
54    U256(MasksByByteSized<U256>),
55}
56
57impl MasksByByte {
58    fn new(used_bytes: BTreeSet<u8>) -> Self {
59        match used_bytes.len() {
60            ..=8 => MasksByByte::U8(MasksByByteSized::<u8>::new(used_bytes)),
61            9..=16 => {
62                MasksByByte::U16(MasksByByteSized::<u16>::new(used_bytes))
63            }
64            17..=32 => {
65                MasksByByte::U32(MasksByByteSized::<u32>::new(used_bytes))
66            }
67            33..=64 => {
68                MasksByByte::U64(MasksByByteSized::<u64>::new(used_bytes))
69            }
70            65..=128 => {
71                MasksByByte::U128(MasksByByteSized::<u128>::new(used_bytes))
72            }
73            129..=256 => {
74                MasksByByte::U256(MasksByByteSized::<U256>::new(used_bytes))
75            }
76            _ => unreachable!("There are only 256 possible u8s"),
77        }
78    }
79}
80
81/// Inner representation of a trie-hard trie that is generic to a specific size
82/// of integer.
83#[derive(Debug, Clone)]
84pub struct TrieHardSized<'a, T, I> {
85    masks: MasksByByteSized<I>,
86    nodes: Vec<TrieState<'a, T, I>>,
87}
88
89impl<'a, T, I> Default for TrieHardSized<'a, T, I>
90where
91    I: Default + Copy,
92{
93    fn default() -> Self {
94        Self {
95            masks: MasksByByteSized::default(),
96            nodes: Default::default(),
97        }
98    }
99}
100
/// Work item for breadth-first trie construction: the key prefix a node
/// stands for, and the slot that node will occupy in the node table.
#[derive(Eq, PartialEq, Ord, PartialOrd)]
struct StateSpec<'a> {
    prefix: &'a [u8],
    index: usize,
}
106
/// Branching data for an interior trie node.
#[derive(Clone, Debug)]
struct SearchNode<I> {
    /// Union of the mask bits of every byte that has an outgoing edge here.
    mask: I,
    /// Node-table index of the first child; the other children follow
    /// contiguously in edge order.
    edge_start: usize,
}
112
113#[derive(Debug, Clone)]
114enum TrieState<'a, T, I> {
115    Leaf(&'a [u8], T),
116    Search(SearchNode<I>),
117    SearchOrLeaf(&'a [u8], T, SearchNode<I>),
118}
119
120/// Enumeration of all the possible sizes of trie-hard tries. An instance of
121/// this enum can be created from any set of arbitrary string or byte slices.
122/// The variant returned will depend on the number of distinct bytes contained
123/// in the set.
124///
125/// ```
126/// # use trie_hard::TrieHard;
127/// let trie = ["and", "ant", "dad", "do", "dot"]
128///     .into_iter()
129///     .collect::<TrieHard<'_, _>>();
130///
131/// assert!(trie.get("dad").is_some());
132/// assert!(trie.get("do").is_some());
133/// assert!(trie.get("don't").is_none());
134/// ```
135///
136/// _Note_: This enum has a very large variant which dominates the size for
137/// the enum. That means that a small trie using `u8`s for storage will take up
138/// way (32x) more storage than it needs to. If you are concerned about extra
139/// space (and you know ahead of time the trie size needed), you should extract
140/// the inner, `[TrieHardSized]` which will use only the size required.
141#[allow(clippy::large_enum_variant)]
142#[derive(Debug, Clone)]
143pub enum TrieHard<'a, T> {
144    /// Trie-hard using u8s for storage. For sets with 1..=8 unique bytes
145    U8(TrieHardSized<'a, T, u8>),
146    /// Trie-hard using u16s for storage. For sets with 9..=16 unique bytes
147    U16(TrieHardSized<'a, T, u16>),
148    /// Trie-hard using u32s for storage. For sets with 17..=32 unique bytes
149    U32(TrieHardSized<'a, T, u32>),
150    /// Trie-hard using u64s for storage. For sets with 33..=64 unique bytes
151    U64(TrieHardSized<'a, T, u64>),
152    /// Trie-hard using u128s for storage. For sets with 65..=126 unique bytes
153    U128(TrieHardSized<'a, T, u128>),
154    /// Trie-hard using U256s for storage. For sets with 129.. unique bytes
155    U256(TrieHardSized<'a, T, U256>),
156}
157
158impl<'a, T> Default for TrieHard<'a, T> {
159    fn default() -> Self {
160        TrieHard::U8(TrieHardSized::default())
161    }
162}
163
164impl<'a, T> TrieHard<'a, T>
165where
166    T: 'a + Copy,
167{
168    /// Create an instance of a trie-hard trie with the given keys and values.
169    /// The variant returned will be determined based on the number of unique
170    /// bytes in the keys.
171    ///
172    /// ```
173    /// # use trie_hard::TrieHard;
174    /// let trie = TrieHard::new(vec![
175    ///     (b"and", 0),
176    ///     (b"ant", 1),
177    ///     (b"dad", 2),
178    ///     (b"do", 3),
179    ///     (b"dot", 4)
180    /// ]);
181    ///
182    /// // Only 5 unique characters produces a u8 TrieHard
183    /// assert!(matches!(trie, TrieHard::U8(_)));
184    ///
185    /// assert_eq!(trie.get("dad"), Some(2));
186    /// assert_eq!(trie.get("do"), Some(3));
187    /// assert!(trie.get("don't").is_none());
188    /// ```
189    pub fn new(values: Vec<(&'a [u8], T)>) -> Self {
190        if values.is_empty() {
191            return Self::default();
192        }
193
194        let used_bytes = values
195            .iter()
196            .flat_map(|(k, _)| k.iter())
197            .cloned()
198            .collect::<BTreeSet<_>>();
199
200        let masks = MasksByByte::new(used_bytes);
201
202        match masks {
203            MasksByByte::U8(masks) => {
204                TrieHard::U8(TrieHardSized::<'_, _, u8>::new(masks, values))
205            }
206            MasksByByte::U16(masks) => {
207                TrieHard::U16(TrieHardSized::<'_, _, u16>::new(masks, values))
208            }
209            MasksByByte::U32(masks) => {
210                TrieHard::U32(TrieHardSized::<'_, _, u32>::new(masks, values))
211            }
212            MasksByByte::U64(masks) => {
213                TrieHard::U64(TrieHardSized::<'_, _, u64>::new(masks, values))
214            }
215            MasksByByte::U128(masks) => {
216                TrieHard::U128(TrieHardSized::<'_, _, u128>::new(masks, values))
217            }
218            MasksByByte::U256(masks) => {
219                TrieHard::U256(TrieHardSized::<'_, _, U256>::new(masks, values))
220            }
221        }
222    }
223
224    /// Get the value stored for the given key. Any key type can be used here as
225    /// long as the type implements `AsRef<[u8]>`. The byte slice referenced
226    /// will serve as the actual key.
227    /// ```
228    /// # use trie_hard::TrieHard;
229    /// let trie = ["and", "ant", "dad", "do", "dot"]
230    ///     .into_iter()
231    ///     .collect::<TrieHard<'_, _>>();
232    ///
233    /// assert!(trie.get("dad".to_owned()).is_some());
234    /// assert!(trie.get(b"do").is_some());
235    /// assert!(trie.get(b"don't".to_vec()).is_none());
236    /// ```
237    pub fn get<K: AsRef<[u8]>>(&self, raw_key: K) -> Option<T> {
238        match self {
239            TrieHard::U8(trie) => trie.get(raw_key),
240            TrieHard::U16(trie) => trie.get(raw_key),
241            TrieHard::U32(trie) => trie.get(raw_key),
242            TrieHard::U64(trie) => trie.get(raw_key),
243            TrieHard::U128(trie) => trie.get(raw_key),
244            TrieHard::U256(trie) => trie.get(raw_key),
245        }
246    }
247
248    /// Get the value stored for the given byte-slice key
249    /// ```
250    /// # use trie_hard::TrieHard;
251    /// let trie = ["and", "ant", "dad", "do", "dot"]
252    ///     .into_iter()
253    ///     .collect::<TrieHard<'_, _>>();
254    ///
255    /// assert!(trie.get_from_bytes(b"dad").is_some());
256    /// assert!(trie.get_from_bytes(b"do").is_some());
257    /// assert!(trie.get_from_bytes(b"don't").is_none());
258    /// ```
259    pub fn get_from_bytes(&self, key: &[u8]) -> Option<T> {
260        match self {
261            TrieHard::U8(trie) => trie.get_from_bytes(key),
262            TrieHard::U16(trie) => trie.get_from_bytes(key),
263            TrieHard::U32(trie) => trie.get_from_bytes(key),
264            TrieHard::U64(trie) => trie.get_from_bytes(key),
265            TrieHard::U128(trie) => trie.get_from_bytes(key),
266            TrieHard::U256(trie) => trie.get_from_bytes(key),
267        }
268    }
269
270    /// Create an iterator over the entire trie. Emitted items will be ordered
271    /// by their keys
272    ///
273    /// ```
274    /// # use trie_hard::TrieHard;
275    /// let trie = ["dad", "ant", "and", "dot", "do"]
276    ///     .into_iter()
277    ///     .collect::<TrieHard<'_, _>>();
278    ///
279    /// assert_eq!(
280    ///     trie.iter().map(|(_, v)| v).collect::<Vec<_>>(),
281    ///     ["and", "ant", "dad", "do", "dot"]
282    /// );
283    /// ```
284    pub fn iter(&self) -> TrieIter<'_, 'a, T> {
285        match self {
286            TrieHard::U8(trie) => TrieIter::U8(trie.iter()),
287            TrieHard::U16(trie) => TrieIter::U16(trie.iter()),
288            TrieHard::U32(trie) => TrieIter::U32(trie.iter()),
289            TrieHard::U64(trie) => TrieIter::U64(trie.iter()),
290            TrieHard::U128(trie) => TrieIter::U128(trie.iter()),
291            TrieHard::U256(trie) => TrieIter::U256(trie.iter()),
292        }
293    }
294
295    /// Create an iterator over the portion of the trie starting with the given
296    /// prefix
297    ///
298    /// ```
299    /// # use trie_hard::TrieHard;
300    /// let trie = ["dad", "ant", "and", "dot", "do"]
301    ///     .into_iter()
302    ///     .collect::<TrieHard<'_, _>>();
303    ///
304    /// assert_eq!(
305    ///     trie.prefix_search("d").map(|(_, v)| v).collect::<Vec<_>>(),
306    ///     ["dad", "do", "dot"]
307    /// );
308    /// ```
309    pub fn prefix_search<K: AsRef<[u8]>>(
310        &self,
311        prefix: K,
312    ) -> TrieIter<'_, 'a, T> {
313        match self {
314            TrieHard::U8(trie) => TrieIter::U8(trie.prefix_search(prefix)),
315            TrieHard::U16(trie) => TrieIter::U16(trie.prefix_search(prefix)),
316            TrieHard::U32(trie) => TrieIter::U32(trie.prefix_search(prefix)),
317            TrieHard::U64(trie) => TrieIter::U64(trie.prefix_search(prefix)),
318            TrieHard::U128(trie) => TrieIter::U128(trie.prefix_search(prefix)),
319            TrieHard::U256(trie) => TrieIter::U256(trie.prefix_search(prefix)),
320        }
321    }
322}
323
324/// Structure used for iterative over the contents of trie
325#[derive(Debug)]
326pub enum TrieIter<'b, 'a, T> {
327    /// Variant for iterating over trie-hard tries built on u8
328    U8(TrieIterSized<'b, 'a, T, u8>),
329    /// Variant for iterating over trie-hard tries built on u16
330    U16(TrieIterSized<'b, 'a, T, u16>),
331    /// Variant for iterating over trie-hard tries built on u32
332    U32(TrieIterSized<'b, 'a, T, u32>),
333    /// Variant for iterating over trie-hard tries built on u64
334    U64(TrieIterSized<'b, 'a, T, u64>),
335    /// Variant for iterating over trie-hard tries built on u128
336    U128(TrieIterSized<'b, 'a, T, u128>),
337    /// Variant for iterating over trie-hard tries built on u256
338    U256(TrieIterSized<'b, 'a, T, U256>),
339}
340
341#[derive(Debug, Default)]
342struct TrieNodeIter {
343    node_index: usize,
344    stage: TrieNodeIterStage,
345}
346
/// Progress marker for a node on the traversal stack.
#[derive(Default, Debug)]
enum TrieNodeIterStage {
    /// The node has not been visited yet.
    #[default]
    Inner,
    /// `Child(current, total)`: the node's own entry (if any) has been emitted,
    /// and child number `current` of its `total` children is being explored.
    Child(usize, usize),
}
353
354/// Structure for iterating of a trie-hard trie built on specific a specific
355/// integer size
356#[derive(Debug)]
357pub struct TrieIterSized<'b, 'a, T, I> {
358    stack: Vec<TrieNodeIter>,
359    trie: &'b TrieHardSized<'a, T, I>,
360}
361
362impl<'b, 'a, T, I> TrieIterSized<'b, 'a, T, I> {
363    fn empty(trie: &'b TrieHardSized<'a, T, I>) -> Self {
364        Self {
365            stack: Default::default(),
366            trie,
367        }
368    }
369
370    fn new(trie: &'b TrieHardSized<'a, T, I>, node_index: usize) -> Self {
371        Self {
372            stack: vec![TrieNodeIter {
373                node_index,
374                stage: Default::default(),
375            }],
376            trie,
377        }
378    }
379}
380
381impl<'b, 'a, T> Iterator for TrieIter<'b, 'a, T>
382where
383    T: Copy,
384{
385    type Item = (&'a [u8], T);
386
387    fn next(&mut self) -> Option<Self::Item> {
388        match self {
389            TrieIter::U8(iter) => iter.next(),
390            TrieIter::U16(iter) => iter.next(),
391            TrieIter::U32(iter) => iter.next(),
392            TrieIter::U64(iter) => iter.next(),
393            TrieIter::U128(iter) => iter.next(),
394            TrieIter::U256(iter) => iter.next(),
395        }
396    }
397}
398
399impl<'a, T> FromIterator<&'a T> for TrieHard<'a, &'a T>
400where
401    T: 'a + AsRef<[u8]> + ?Sized,
402{
403    fn from_iter<I: IntoIterator<Item = &'a T>>(values: I) -> Self {
404        let values = values
405            .into_iter()
406            .map(|v| (v.as_ref(), v))
407            .collect::<Vec<_>>();
408
409        Self::new(values)
410    }
411}
412
413macro_rules! trie_impls {
414    ($($int_type:ty),+) => {
415        $(
416            trie_impls!(_impl $int_type);
417        )+
418    };
419
420    (_impl $int_type:ty) => {
421
422        impl SearchNode<$int_type> {
423            fn evaluate<T>(&self, c: u8, trie: &TrieHardSized<'_, T, $int_type>) -> Option<usize> {
424                let c_mask = trie.masks.0[c as usize];
425                let mask_res = self.mask & c_mask;
426                (mask_res > 0).then(|| {
427                    let smaller_bits = mask_res - 1;
428                    let smaller_bits_mask = smaller_bits & self.mask;
429                    let index_offset = smaller_bits_mask.count_ones() as usize;
430                    self.edge_start + index_offset
431                })
432            }
433        }
434
435        impl<'a, T> TrieHardSized<'a, T, $int_type>
436        where
437            T: Copy
438        {
439
440            /// Get the value stored for the given key. Any key type can be used
441            /// here as long as the type implements `AsRef<[u8]>`. The byte slice
442            /// referenced will serve as the actual key.
443            /// ```
444            /// # use trie_hard::TrieHard;
445            /// let trie = ["and", "ant", "dad", "do", "dot"]
446            ///     .into_iter()
447            ///     .collect::<TrieHard<'_, _>>();
448            ///
449            /// let TrieHard::U8(sized_trie) = trie else {
450            ///     unreachable!()
451            /// };
452            ///
453            /// assert!(sized_trie.get("dad".to_owned()).is_some());
454            /// assert!(sized_trie.get(b"do").is_some());
455            /// assert!(sized_trie.get(b"don't".to_vec()).is_none());
456            /// ```
457            pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Option<T> {
458                self.get_from_bytes(key.as_ref())
459            }
460
461            /// Get the value stored for the given byte-slice key.
462            /// ```
463            /// # use trie_hard::TrieHard;
464            /// let trie = ["and", "ant", "dad", "do", "dot"]
465            ///     .into_iter()
466            ///     .collect::<TrieHard<'_, _>>();
467            ///
468            /// let TrieHard::U8(sized_trie) = trie else {
469            ///     unreachable!()
470            /// };
471            ///
472            /// assert!(sized_trie.get_from_bytes(b"dad").is_some());
473            /// assert!(sized_trie.get_from_bytes(b"do").is_some());
474            /// assert!(sized_trie.get_from_bytes(b"don't").is_none());
475            /// ```
476            pub fn get_from_bytes(&self, key: &[u8]) -> Option<T> {
477                let mut state = self.nodes.get(0)?;
478
479                for (i, c) in key.iter().enumerate() {
480
481                    let next_state_opt = match state {
482                        TrieState::Leaf(k, value) => {
483                            return (
484                                k.len() == key.len()
485                                && k[i..] == key[i..]
486                            ).then_some(*value)
487                        }
488                        TrieState::Search(search)
489                        | TrieState::SearchOrLeaf(_, _, search) => {
490                            search.evaluate(*c, self)
491                        }
492                    };
493
494                    if let Some(next_state_index) = next_state_opt {
495                        state = &self.nodes[next_state_index];
496                    } else {
497                        return None;
498                    }
499                }
500
501                if let TrieState::Leaf(k, value)
502                    | TrieState::SearchOrLeaf(k, value, _) = state
503                {
504                    (k.len() == key.len()).then_some(*value)
505                } else {
506                    None
507                }
508            }
509
510            /// Create an iterator over the entire trie. Emitted items will be
511            /// ordered by their keys
512            ///
513            /// ```
514            /// # use trie_hard::TrieHard;
515            /// let trie = ["dad", "ant", "and", "dot", "do"]
516            ///     .into_iter()
517            ///     .collect::<TrieHard<'_, _>>();
518            ///
519            /// let TrieHard::U8(sized_trie) = trie else {
520            ///     unreachable!()
521            /// };
522            ///
523            /// assert_eq!(
524            ///     sized_trie.iter().map(|(_, v)| v).collect::<Vec<_>>(),
525            ///     ["and", "ant", "dad", "do", "dot"]
526            /// );
527            /// ```
528            pub fn iter(&self) -> TrieIterSized<'_, 'a, T, $int_type> {
529                TrieIterSized {
530                    stack: vec![TrieNodeIter::default()],
531                    trie: self
532                }
533            }
534
535
536            /// Create an iterator over the portion of the trie starting with the given
537            /// prefix
538            ///
539            /// ```
540            /// # use trie_hard::TrieHard;
541            /// let trie = ["dad", "ant", "and", "dot", "do"]
542            ///     .into_iter()
543            ///     .collect::<TrieHard<'_, _>>();
544            ///
545            /// let TrieHard::U8(sized_trie) = trie else {
546            ///     unreachable!()
547            /// };
548            ///
549            /// assert_eq!(
550            ///     sized_trie.prefix_search("d").map(|(_, v)| v).collect::<Vec<_>>(),
551            ///     ["dad", "do", "dot"]
552            /// );
553            /// ```
554            pub fn prefix_search<K: AsRef<[u8]>>(&self, prefix: K) -> TrieIterSized<'_, 'a, T, $int_type> {
555                let key = prefix.as_ref();
556                let mut node_index = 0;
557                let Some(mut state) = self.nodes.get(node_index) else {
558                    return TrieIterSized::empty(self);
559                };
560
561                for (i, c) in key.iter().enumerate() {
562                    let next_state_opt = match state {
563                        TrieState::Leaf(k, _) => {
564                            if k.len() == key.len() && k[i..] == key[i..] {
565                                return TrieIterSized::new(self, node_index);
566                            } else {
567                                return TrieIterSized::empty(self);
568                            }
569                        }
570                        TrieState::Search(search)
571                        | TrieState::SearchOrLeaf(_, _, search) => {
572                            search.evaluate(*c, self)
573                        }
574                    };
575
576                    if let Some(next_state_index) = next_state_opt {
577                        node_index = next_state_index;
578                        state = &self.nodes[next_state_index];
579                    } else {
580                        return TrieIterSized::empty(self);
581                    }
582                }
583
584                TrieIterSized::new(self, node_index)
585            }
586        }
587
588        impl<'a, T> TrieHardSized<'a, T, $int_type> where T: 'a + Copy {
589            fn new(masks: MasksByByteSized<$int_type>, values: Vec<(&'a [u8], T)>) -> Self {
590                let values = values.into_iter().collect::<Vec<_>>();
591                let sorted = values
592                    .iter()
593                    .map(|(k, v)| (*k, *v))
594                    .collect::<BTreeMap<_, _>>();
595
596                let mut nodes = Vec::new();
597                let mut next_index = 1;
598
599                let root_state_spec = StateSpec {
600                    prefix: &[],
601                    index: 0,
602                };
603
604                let mut spec_queue = VecDeque::new();
605                spec_queue.push_back(root_state_spec);
606
607                while let Some(spec) = spec_queue.pop_front() {
608                    debug_assert_eq!(spec.index, nodes.len());
609                    let (state, next_specs) = TrieState::<'_, _, $int_type>::new(
610                        spec,
611                        next_index,
612                        &masks.0,
613                        &sorted,
614                    );
615
616                    next_index += next_specs.len();
617                    spec_queue.extend(next_specs);
618                    nodes.push(state);
619                }
620
621                TrieHardSized {
622                    nodes,
623                    masks,
624                }
625            }
626        }
627
628
629        impl <'a, T> TrieState<'a, T, $int_type> where T: 'a + Copy {
630            fn new(
631                spec: StateSpec<'a>,
632                edge_start: usize,
633                byte_masks: &[$int_type; 256],
634                sorted: &BTreeMap<&'a [u8], T>,
635            ) -> (Self, Vec<StateSpec<'a>>) {
636                let StateSpec { prefix, .. } = spec;
637
638                let prefix_len = prefix.len();
639                let next_prefix_len = prefix_len + 1;
640
641                let mut prefix_match = None;
642                let mut children_seen = 0;
643                let mut last_seen = None;
644
645                let next_states_paired = sorted
646                    .range(RangeFrom { start: prefix })
647                    .take_while(|(key, _)| key.starts_with(prefix))
648                    .filter_map(|(key, val)| {
649                        children_seen += 1;
650                        last_seen = Some((key, *val));
651
652                        if *key == prefix {
653                            prefix_match = Some((key, *val));
654                            None
655                        } else {
656                            // Safety: The byte at prefix_len must exist otherwise we
657                            // would have ended up in the other branch of this statement
658                            let next_c = key.get(prefix_len).unwrap();
659                            let next_prefix = &key[..next_prefix_len];
660
661                            Some((
662                                *next_c,
663                                StateSpec {
664                                    prefix: next_prefix,
665                                    index: 0,
666                                },
667                            ))
668                        }
669                    })
670                    .collect::<BTreeMap<_, _>>()
671                    .into_iter()
672                    .collect::<Vec<_>>();
673
674                // Safety: last_seen will be present because we saw at least one
675                //         entry must be present for this function to be called
676                let (last_k, last_v) = last_seen.unwrap();
677
678                if children_seen == 1 {
679                    return (TrieState::Leaf(last_k, last_v), vec![]);
680                }
681
682                // No next_states means we hit a leaf node
683                if next_states_paired.is_empty() {
684                    return (TrieState::Leaf(last_k, last_v), vec![], );
685                }
686
687                let mut mask = Default::default();
688
689                // Update the index for the next state now that we have ordered by
690                let next_state_specs = next_states_paired
691                    .into_iter()
692                    .enumerate()
693                    .map(|(i, (c, mut next_state))| {
694                        let next_node = edge_start + i;
695                        next_state.index = next_node;
696                        mask |= byte_masks[c as usize];
697                        next_state
698                    })
699                    .collect();
700
701                let search_node = SearchNode { mask, edge_start };
702                let state = match prefix_match {
703                    Some((key, value)) => {
704                        TrieState::SearchOrLeaf(key, value, search_node)
705                    }
706                    _ => TrieState::Search(search_node),
707                };
708
709                (state, next_state_specs)
710            }
711        }
712
713        impl MasksByByteSized<$int_type> {
714            fn new(used_bytes: BTreeSet<u8>) -> Self {
715                let mut mask = Default::default();
716                mask += 1;
717
718                let mut byte_masks = [Default::default(); 256];
719
720                for c in used_bytes.into_iter() {
721                    byte_masks[c as usize] = mask;
722                    mask <<= 1;
723
724                }
725
726                Self(byte_masks)
727            }
728        }
729
730        impl <'b, 'a, T> Iterator for TrieIterSized<'b, 'a, T, $int_type>
731        where
732            T: Copy
733        {
734            type Item = (&'a [u8], T);
735
736            fn next(&mut self) -> Option<Self::Item> {
737
738                use TrieState as T;
739                use TrieNodeIterStage as S;
740
741                while let Some((node, node_index, stage)) = self.stack.pop()
742                    .and_then(|TrieNodeIter { node_index, stage }| {
743                        self.trie.nodes.get(node_index).map(|node| (node, node_index, stage))
744                    })
745                {
746                    match (node, stage) {
747                        (T::Leaf(key, value), S::Inner) => return Some((*key, *value)),
748                        (T::SearchOrLeaf(key, value, search), S::Inner) => {
749                            self.stack.push(TrieNodeIter {
750                                node_index,
751                                stage: TrieNodeIterStage::Child(0, search.mask.count_ones() as usize)
752                            });
753                            self.stack.push(TrieNodeIter {
754                                node_index: search.edge_start,
755                                stage: Default::default()
756                            });
757                            return Some((*key, *value));
758                        }
759                        (T::Search(search), S::Inner) => {
760                            self.stack.push(TrieNodeIter {
761                                node_index,
762                                stage: TrieNodeIterStage::Child(0, search.mask.count_ones() as usize)
763                            });
764                            self.stack.push(TrieNodeIter {
765                                node_index: search.edge_start,
766                                stage: Default::default()
767                            });
768                        }
769                        (
770                            T::SearchOrLeaf(_, _, search) | T::Search(search),
771                            S::Child(mut child, child_count)
772                        ) => {
773                            child += 1;
774                            if child < child_count {
775                                self.stack.push(TrieNodeIter {
776                                    node_index,
777                                    stage: TrieNodeIterStage::Child(child, child_count)
778                                });
779                                self.stack.push(TrieNodeIter {
780                                    node_index: search.edge_start + child,
781                                    stage: Default::default()
782                                });
783                            }
784                        }
785                        _ => unreachable!()
786                    }
787                }
788
789                None
790            }
791        }
792    }
793}
794
795trie_impls! {u8, u16, u32, u64, u128, U256}
796
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[test]
    fn test_trivial() {
        let empty: Vec<&str> = vec![];
        let empty_trie = empty.iter().collect::<TrieHard<'_, _>>();

        assert_eq!(None, empty_trie.get("anything"));
    }

    #[rstest]
    #[case("", Some(""))]
    #[case("a", Some("a"))]
    #[case("ab", Some("ab"))]
    #[case("abc", None)]
    #[case("aac", Some("aac"))]
    #[case("aa", None)]
    #[case("aab", None)]
    #[case("adddd", Some("adddd"))]
    fn test_small_get(#[case] key: &str, #[case] expected: Option<&str>) {
        let trie = ["", "a", "ab", "aac", "adddd", "addde"]
            .into_iter()
            .collect::<TrieHard<'_, _>>();
        assert_eq!(expected, trie.get(key));
    }

    #[test]
    fn test_skip_to_leaf() {
        let trie = ["a", "aa", "aaa"].into_iter().collect::<TrieHard<'_, _>>();

        assert_eq!(trie.get("aa"), Some("aa"));
    }

    #[rstest]
    #[case(8)]
    #[case(16)]
    #[case(32)]
    #[case(64)]
    #[case(128)]
    #[case(256)]
    fn test_sizes(#[case] bits: usize) {
        let range = 0..bits;
        let bytes = range.map(|b| [b as u8]).collect::<Vec<_>>();
        let trie = bytes.iter().collect::<TrieHard<'_, _>>();

        use TrieHard as T;

        match (bits, trie) {
            (8, T::U8(_)) => (),
            (16, T::U16(_)) => (),
            (32, T::U32(_)) => (),
            (64, T::U64(_)) => (),
            (128, T::U128(_)) => (),
            (256, T::U256(_)) => (),
            _ => panic!("Mismatched trie sizes"),
        }
    }

    #[rstest]
    #[case(include_str!("../data/1984.txt"))]
    #[case(include_str!("../data/sun-rising.txt"))]
    fn test_full_text(#[case] text: &str) {
        let words: Vec<&str> =
            text.split(|c: char| c.is_whitespace()).collect();
        let trie: TrieHard<'_, _> = words.iter().copied().collect();

        let unique_words = words
            .into_iter()
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect::<Vec<_>>();

        for word in &unique_words {
            assert!(trie.get(word).is_some())
        }

        assert_eq!(
            unique_words,
            trie.iter().map(|(_, v)| v).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_unicode() {
        let trie: TrieHard<'_, _> = ["bär", "bären"].into_iter().collect();

        assert_eq!(trie.get("bär"), Some("bär"));
        assert_eq!(trie.get("bä"), None);
        assert_eq!(trie.get("bären"), Some("bären"));
        assert_eq!(trie.get("bärën"), None);
    }

    #[rstest]
    #[case(&[], &[])]
    #[case(&[""], &[""])]
    #[case(&["aaa", "a", ""], &["", "a", "aaa"])]
    #[case(&["a", "", "aaa"], &["", "a", "aaa"])]
    #[case(&["", "a", "ab", "aac", "adddd", "addde"], &["", "a", "aac", "ab", "adddd", "addde"])]
    fn test_iter(#[case] input: &[&str], #[case] output: &[&str]) {
        let trie = input.iter().copied().collect::<TrieHard<'_, _>>();
        let emitted = trie.iter().map(|(_, v)| v).collect::<Vec<_>>();
        assert_eq!(emitted, output);
    }

    #[rstest]
    #[case(&[], "", &[])]
    #[case(&[""], "", &[""])]
    #[case(&["aaa", "a", ""], "", &["", "a", "aaa"])]
    #[case(&["aaa", "a", ""], "a", &["a", "aaa"])]
    #[case(&["aaa", "a", ""], "aa", &["aaa"])]
    #[case(&["aaa", "a", ""], "aab", &[])]
    #[case(&["aaa", "a", ""], "aaa", &["aaa"])]
    #[case(&["aaa", "a", ""], "b", &[])]
    #[case(&["dad", "ant", "and", "dot", "do"], "d", &["dad", "do", "dot"])]
    fn test_prefix_search(
        #[case] input: &[&str],
        #[case] prefix: &str,
        #[case] output: &[&str],
    ) {
        let trie = input.iter().copied().collect::<TrieHard<'_, _>>();
        let emitted = trie
            .prefix_search(prefix)
            .map(|(_, v)| v)
            .collect::<Vec<_>>();
        assert_eq!(emitted, output);
    }
}