// trie_rs/map/trie.rs

1//! A trie map stores a value with each word or key.
2use super::Trie;
3use crate::inc_search::IncSearch;
4use crate::iter::{PostfixIter, PrefixIter, SearchIter};
5use crate::try_collect::{TryCollect, TryFromIterator};
6use louds_rs::{AncestorNodeIter, ChildNodeIter, LoudsNodeNum};
7use std::iter::FromIterator;
8
9impl<Label: Ord, Value> Trie<Label, Value> {
10    /// Return `Some(&Value)` if query is an exact match.
11    pub fn exact_match(&self, query: impl AsRef<[Label]>) -> Option<&Value> {
12        self.exact_match_node(query)
13            .and_then(move |x| self.value(x))
14    }
15
16    /// Return `Node` if query is an exact match.
17    #[inline]
18    fn exact_match_node(&self, query: impl AsRef<[Label]>) -> Option<LoudsNodeNum> {
19        let mut cur_node_num = LoudsNodeNum(1);
20
21        for (i, chr) in query.as_ref().iter().enumerate() {
22            let children_node_nums: Vec<LoudsNodeNum> =
23                self.children_node_nums(cur_node_num).collect();
24            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
25
26            match res {
27                Ok(j) => {
28                    let child_node_num = children_node_nums[j];
29                    if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
30                        return Some(child_node_num);
31                    }
32                    cur_node_num = child_node_num;
33                }
34                Err(_) => return None,
35            }
36        }
37        None
38    }
39
40    /// Return `Some(&mut value)` if query is an exact match.
41    pub fn exact_match_mut(&mut self, query: impl AsRef<[Label]>) -> Option<&mut Value> {
42        self.exact_match_node(query)
43            .and_then(move |x| self.value_mut(x))
44    }
45
46    /// Create an incremental search. Useful for interactive applications. See
47    /// [crate::inc_search] for details.
48    pub fn inc_search(&self) -> IncSearch<'_, Label, Value> {
49        IncSearch::new(self)
50    }
51
52    /// Return true if `query` is a prefix.
53    ///
54    /// Note: A prefix may be an exact match or not, and an exact match may be a
55    /// prefix or not.
56    pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
57        let mut cur_node_num = LoudsNodeNum(1);
58
59        for chr in query.as_ref().iter() {
60            let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
61            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
62            match res {
63                Ok(j) => cur_node_num = children_node_nums[j],
64                Err(_) => return false,
65            }
66        }
67        // Are there more nodes after our query?
68        self.has_children_node_nums(cur_node_num)
69    }
70
71    /// Return all entries and their values that match `query`.
72    pub fn predictive_search<C, M>(
73        &self,
74        query: impl AsRef<[Label]>,
75    ) -> SearchIter<'_, Label, Value, C, M>
76    where
77        C: TryFromIterator<Label, M> + Clone,
78        Label: Clone,
79    {
80        SearchIter::new(self, query)
81    }
82
83    /// Return the postfixes and values of all entries that match `query`.
84    pub fn postfix_search<C, M>(
85        &self,
86        query: impl AsRef<[Label]>,
87    ) -> PostfixIter<'_, Label, Value, C, M>
88    where
89        C: TryFromIterator<Label, M>,
90        Label: Clone,
91    {
92        let mut cur_node_num = LoudsNodeNum(1);
93
94        // Consumes query (prefix)
95        for chr in query.as_ref() {
96            let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
97            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
98            match res {
99                Ok(i) => cur_node_num = children_node_nums[i],
100                Err(_) => {
101                    return PostfixIter::empty(self);
102                }
103            }
104        }
105
106        PostfixIter::new(self, cur_node_num)
107    }
108
109    /// Returns an iterator across all keys in the trie.
110    ///
111    /// # Examples
112    /// In the following example we illustrate how to iterate over all keys in the trie.
113    /// Note that the order of the keys is not guaranteed, as they will be returned in
114    /// lexicographical order.
115    ///
116    /// ```rust
117    /// use trie_rs::map::Trie;
118    /// let trie = Trie::from_iter([("a", 0), ("app", 1), ("apple", 2), ("better", 3), ("application", 4)]);
119    /// let results: Vec<(String, &u8)> = trie.iter().collect();
120    /// assert_eq!(results, [("a".to_string(), &0u8), ("app".to_string(), &1u8), ("apple".to_string(), &2u8), ("application".to_string(), &4u8), ("better".to_string(), &3u8)]);
121    /// ```
122    pub fn iter<C, M>(&self) -> PostfixIter<'_, Label, Value, C, M>
123    where
124        C: TryFromIterator<Label, M>,
125        Label: Clone,
126    {
127        self.postfix_search([])
128    }
129
130    /// Return the common prefixes of `query`.
131    pub fn common_prefix_search<C, M>(
132        &self,
133        query: impl AsRef<[Label]>,
134    ) -> PrefixIter<'_, Label, Value, C, M>
135    where
136        C: TryFromIterator<Label, M>,
137        Label: Clone,
138    {
139        PrefixIter::new(self, query)
140    }
141
142    /// Return the longest shared prefix or terminal of `query`.
143    pub fn longest_prefix<C, M>(&self, query: impl AsRef<[Label]>) -> Option<C>
144    where
145        C: TryFromIterator<Label, M>,
146        Label: Clone,
147    {
148        let mut cur_node_num = LoudsNodeNum(1);
149        let mut buffer = Vec::new();
150
151        // Consumes query (prefix)
152        for chr in query.as_ref() {
153            let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
154            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
155            match res {
156                Ok(i) => {
157                    cur_node_num = children_node_nums[i];
158                    buffer.push(cur_node_num);
159                }
160                Err(_) => {
161                    return None;
162                }
163            }
164        }
165
166        // Walk the trie as long as there is only one path and it isn't a terminal value.
167        while !self.is_terminal(cur_node_num) {
168            let mut iter = self.children_node_nums(cur_node_num);
169            let first = iter.next();
170            let second = iter.next();
171            match (first, second) {
172                (Some(child_node_num), None) => {
173                    cur_node_num = child_node_num;
174                    buffer.push(child_node_num);
175                }
176                _ => break,
177            }
178        }
179        if buffer.is_empty() {
180            None
181        } else {
182            Some(
183                buffer
184                    .into_iter()
185                    .map(|x| self.label(x).clone())
186                    .try_collect()
187                    .expect("Could not collect"),
188            )
189        }
190    }
191
192    pub(crate) fn has_children_node_nums(&self, node_num: LoudsNodeNum) -> bool {
193        self.louds
194            .parent_to_children_indices(node_num)
195            .next()
196            .is_some()
197    }
198
199    pub(crate) fn children_node_nums(&self, node_num: LoudsNodeNum) -> ChildNodeIter {
200        self.louds.parent_to_children_nodes(node_num)
201    }
202
203    pub(crate) fn bin_search_by_children_labels(
204        &self,
205        query: &Label,
206        children_node_nums: &[LoudsNodeNum],
207    ) -> Result<usize, usize> {
208        children_node_nums.binary_search_by(|child_node_num| self.label(*child_node_num).cmp(query))
209    }
210
211    pub(crate) fn label(&self, node_num: LoudsNodeNum) -> &Label {
212        &self.trie_labels[(node_num.0 - 2) as usize].label
213    }
214
215    pub(crate) fn is_terminal(&self, node_num: LoudsNodeNum) -> bool {
216        if node_num.0 >= 2 {
217            self.trie_labels[(node_num.0 - 2) as usize].value.is_some()
218        } else {
219            false
220        }
221    }
222
223    pub(crate) fn value(&self, node_num: LoudsNodeNum) -> Option<&Value> {
224        if node_num.0 >= 2 {
225            self.trie_labels[(node_num.0 - 2) as usize].value.as_ref()
226        } else {
227            None
228        }
229    }
230
231    pub(crate) fn value_mut(&mut self, node_num: LoudsNodeNum) -> Option<&mut Value> {
232        self.trie_labels[(node_num.0 - 2) as usize].value.as_mut()
233    }
234
235    pub(crate) fn child_to_ancestors(&self, node_num: LoudsNodeNum) -> AncestorNodeIter {
236        self.louds.child_to_ancestors(node_num)
237    }
238}
239
240impl<Label, Value, C> FromIterator<(C, Value)> for Trie<Label, Value>
241where
242    C: AsRef<[Label]>,
243    Label: Ord + Clone,
244{
245    fn from_iter<T>(iter: T) -> Self
246    where
247        Self: Sized,
248        T: IntoIterator<Item = (C, Value)>,
249    {
250        let mut builder = super::TrieBuilder::new();
251        for (k, v) in iter {
252            builder.push(k, v)
253        }
254        builder.build()
255    }
256}
257
#[cfg(test)]
mod search_tests {
    use crate::map::{Trie, TrieBuilder};
    use std::iter::FromIterator;

    /// Byte-labelled trie shared by most tests. It includes a multi-byte UTF-8
    /// key, so queries are exercised at byte rather than char granularity.
    fn build_trie() -> Trie<u8, u8> {
        let mut builder = TrieBuilder::new();
        builder.push("a", 0);
        builder.push("app", 1);
        builder.push("apple", 2);
        builder.push("better", 3);
        builder.push("application", 4);
        builder.push("アップル🍎", 5);
        builder.build()
    }

    /// Same keys as `build_trie`, but labelled by `char`.
    fn build_trie2() -> Trie<char, u8> {
        let mut builder: TrieBuilder<char, u8> = TrieBuilder::new();
        builder.insert("a".chars(), 0);
        builder.insert("app".chars(), 1);
        builder.insert("apple".chars(), 2);
        builder.insert("better".chars(), 3);
        builder.insert("application".chars(), 4);
        builder.insert("アップル🍎".chars(), 5);
        builder.build()
    }

    #[test]
    fn sanity_check() {
        let trie = build_trie();
        let v: Vec<(String, &u8)> = trie.predictive_search("apple").collect();
        assert_eq!(v, vec![("apple".to_string(), &2)]);
    }

    #[test]
    fn clone() {
        let trie = build_trie();
        let _c: Trie<u8, u8> = trie.clone();
    }

    #[test]
    fn value_mut() {
        let mut trie = build_trie();
        assert_eq!(trie.exact_match("apple"), Some(&2));
        let v = trie.exact_match_mut("apple").unwrap();
        *v = 10;
        assert_eq!(trie.exact_match("apple"), Some(&10));
        // Non-keys (a bare prefix, the empty query) have no mutable value.
        assert!(trie.exact_match_mut("appl").is_none());
        assert!(trie.exact_match_mut("").is_none());
    }

    #[test]
    fn trie_from_iter() {
        let trie = Trie::<u8, u8>::from_iter([
            ("a", 0),
            ("app", 1),
            ("apple", 2),
            ("better", 3),
            ("application", 4),
        ]);
        assert_eq!(trie.exact_match("application"), Some(&4));
    }

    #[test]
    fn collect_a_trie() {
        // Does not work with arrays in rust 2018 because into_iter() returns references instead of owned types.
        // let trie: Trie<u8, u8> = [("a", 0), ("app", 1), ("apple", 2), ("better", 3), ("application", 4)].into_iter().collect();
        let trie: Trie<u8, u8> = vec![
            ("a", 0),
            ("app", 1),
            ("apple", 2),
            ("better", 3),
            ("application", 4),
        ]
        .into_iter()
        .collect();
        assert_eq!(trie.exact_match("application"), Some(&4));
    }

    #[test]
    fn use_empty_queries() {
        let trie = build_trie();
        assert!(trie.exact_match("").is_none());
        let _ = trie.predictive_search::<String, _>("").next();
        let _ = trie.postfix_search::<String, _>("").next();
        let _ = trie.common_prefix_search::<String, _>("").next();
    }

    /// Iteration order depends only on the keys, not on the order they were inserted.
    #[test]
    fn insert_order_independent() {
        let trie = Trie::from_iter([("a", 0), ("app", 1), ("apple", 2)]);
        let results: Vec<(String, &u8)> = trie.iter().collect();
        assert_eq!(
            results,
            [
                ("a".to_string(), &0u8),
                ("app".to_string(), &1u8),
                ("apple".to_string(), &2u8)
            ]
        );

        let trie = Trie::from_iter([("a", 0), ("apple", 2), ("app", 1)]);
        let results: Vec<(String, &u8)> = trie.iter().collect();
        assert_eq!(
            results,
            [
                ("a".to_string(), &0u8),
                ("app".to_string(), &1u8),
                ("apple".to_string(), &2u8)
            ]
        );
    }

    mod exact_match_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_match) = $value;
                    let trie = super::build_trie();
                    let result = trie.exact_match(query);
                    assert_eq!(result, expected_match);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", Some(&0)),
            t2: ("app", Some(&1)),
            t3: ("apple", Some(&2)),
            t4: ("application", Some(&4)),
            t5: ("better", Some(&3)),
            t6: ("アップル🍎", Some(&5)),
            t7: ("appl", None),
            t8: ("appler", None),
            t9: ("b", None),
            t10: ("", None),
        }
    }

    mod is_prefix_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_match) = $value;
                    let trie = super::build_trie();
                    let result = trie.is_prefix(query);
                    assert_eq!(result, expected_match);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", false),
            t4: ("application", false),
            t5: ("better", false),
            t6: ("アップル🍎", false),
            t7: ("appl", true),
            t8: ("appler", false),
            t9: ("アップル", true),
            t10: ("", true),
            t11: ("z", false),
        }
    }

    mod longest_prefix_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_match) = $value;
                    let trie = super::build_trie();
                    let result: Option<String> = trie.longest_prefix(query);
                    let expected_match = expected_match.map(str::to_string);
                    assert_eq!(result, expected_match);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", Some("a")),
            t2: ("ap", Some("app")),
            t3: ("app", Some("app")),
            t4: ("appl", Some("appl")),
            t5: ("appli", Some("application")),
            t6: ("b", Some("better")),
            t7: ("アップル🍎", Some("アップル🍎")),
            t8: ("appler", None),
            t9: ("アップル", Some("アップル🍎")),
            t10: ("z", None),
            t11: ("applesDONTEXIST", None),
            t12: ("", None),
        }
    }

    mod predictive_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie();
                    let results: Vec<(String, &u8)> = trie.predictive_search(query).collect();
                    let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("a", 0), ("app", 1), ("apple", 2), ("application", 4)]),
            t2: ("app", vec![("app", 1), ("apple", 2), ("application", 4)]),
            t3: ("appl", vec![("apple", 2), ("application", 4)]),
            t4: ("apple", vec![("apple", 2)]),
            t5: ("b", vec![("better", 3)]),
            t6: ("c", Vec::<(&str, u8)>::new()),
            t7: ("アップ", vec![("アップル🍎", 5)]),
        }
    }

    mod common_prefix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie();
                    let results: Vec<(String, &u8)> = trie.common_prefix_search(query).collect();
                    let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("a", 0)]),
            t2: ("ap", vec![("a", 0)]),
            t3: ("appl", vec![("a", 0), ("app", 1)]),
            t4: ("appler", vec![("a", 0), ("app", 1), ("apple", 2)]),
            t5: ("bette", Vec::<(&str, u8)>::new()),
            t6: ("betterment", vec![("better", 3)]),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", vec![("アップル🍎", 5)]),
        }
    }

    mod postfix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie();
                    let results: Vec<(String, &u8)> = trie.postfix_search(query).collect();
                    let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("pp", 1), ("pple", 2), ("pplication", 4)]),
            t2: ("ap", vec![("p", 1), ("ple", 2), ("plication", 4)]),
            t3: ("appl", vec![("e", 2), ("ication", 4)]),
            t4: ("appler", Vec::<(&str, u8)>::new()),
            t5: ("bette", vec![("r", 3)]),
            t6: ("betterment", Vec::<(&str, u8)>::new()),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", Vec::<(&str, u8)>::new()),
        }
    }

    mod postfix_search_char_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie2();
                    let chars: Vec<char> = query.chars().collect();
                    let results: Vec<(String, &u8)> = trie.postfix_search(chars).collect();
                    let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("pp", 1), ("pple", 2), ("pplication", 4)]),
            t2: ("ap", vec![("p", 1), ("ple", 2), ("plication", 4)]),
            t3: ("appl", vec![("e", 2), ("ication", 4)]),
            t4: ("appler", Vec::<(&str, u8)>::new()),
            t5: ("bette", vec![("r", 3)]),
            t6: ("betterment", Vec::<(&str, u8)>::new()),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", Vec::<(&str, u8)>::new()),
        }
    }
}