// File: trie_rs/trie/trie_impl.rs

1use crate::inc_search::IncSearch;
2use crate::iter::{Keys, KeysExt, PostfixIter, PrefixIter, SearchIter};
3use crate::map;
4use crate::try_collect::TryFromIterator;
5use std::iter::FromIterator;
6
7#[cfg(feature = "mem_dbg")]
8use mem_dbg::MemDbg;
9
10#[derive(Debug, Clone)]
11#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13/// A trie for sequences of the type `Label`.
14pub struct Trie<Label>(pub map::Trie<Label, ()>);
15
16impl<Label: Ord> Trie<Label> {
17    /// Return true if `query` is an exact match.
18    ///
19    /// # Arguments
20    /// * `query` - The query to search for.
21    ///
22    /// # Examples
23    /// In the following example we illustrate how to query an exact match.
24    ///
25    /// ```rust
26    /// use trie_rs::Trie;
27    ///
28    /// let trie = Trie::from_iter(["a", "app", "apple", "better", "application"]);
29    ///
30    /// assert!(trie.exact_match("application"));
31    /// assert!(trie.exact_match("app"));
32    /// assert!(!trie.exact_match("appla"));
33    ///
34    /// ```
35    pub fn exact_match(&self, query: impl AsRef<[Label]>) -> bool {
36        self.0.exact_match(query).is_some()
37    }
38
39    /// Return the common prefixes of `query`.
40    ///
41    /// # Arguments
42    /// * `query` - The query to search for.
43    ///
44    /// # Examples
45    /// In the following example we illustrate how to query the common prefixes of a query string.
46    ///
47    /// ```rust
48    /// use trie_rs::Trie;
49    ///
50    /// let trie = Trie::from_iter(["a", "app", "apple", "better", "application"]);
51    ///
52    /// let results: Vec<String> = trie.common_prefix_search("application").collect();
53    ///
54    /// assert_eq!(results, vec!["a", "app", "application"]);
55    ///
56    /// ```
57    pub fn common_prefix_search<C, M>(
58        &self,
59        query: impl AsRef<[Label]>,
60    ) -> Keys<PrefixIter<'_, Label, (), C, M>>
61    where
62        C: TryFromIterator<Label, M>,
63        Label: Clone,
64    {
65        // TODO: We could return Keys iterators instead of collecting.
66        self.0.common_prefix_search(query).keys()
67    }
68
69    /// Return all entries that match `query`.
70    pub fn predictive_search<C, M>(
71        &self,
72        query: impl AsRef<[Label]>,
73    ) -> Keys<SearchIter<'_, Label, (), C, M>>
74    where
75        C: TryFromIterator<Label, M> + Clone,
76        Label: Clone,
77    {
78        self.0.predictive_search(query).keys()
79    }
80
81    /// Return the postfixes of all entries that match `query`.
82    ///
83    /// # Arguments
84    /// * `query` - The query to search for.
85    ///
86    /// # Examples
87    /// In the following example we illustrate how to query the postfixes of a query string.
88    ///
89    /// ```rust
90    /// use trie_rs::Trie;
91    ///
92    /// let trie = Trie::from_iter(["a", "app", "apple", "better", "application"]);
93    ///
94    /// let results: Vec<String> = trie.postfix_search("application").collect();
95    ///
96    /// assert!(results.is_empty());
97    ///
98    /// let results: Vec<String> = trie.postfix_search("app").collect();
99    ///
100    /// assert_eq!(results, vec!["le", "lication"]);
101    ///
102    /// ```
103    pub fn postfix_search<C, M>(
104        &self,
105        query: impl AsRef<[Label]>,
106    ) -> Keys<PostfixIter<'_, Label, (), C, M>>
107    where
108        C: TryFromIterator<Label, M>,
109        Label: Clone,
110    {
111        self.0.postfix_search(query).keys()
112    }
113
114    /// Returns an iterator across all keys in the trie.
115    ///
116    /// # Examples
117    /// In the following example we illustrate how to iterate over all keys in the trie.
118    /// Note that the order of the keys is not guaranteed, as they will be returned in
119    /// lexicographical order.
120    ///
121    /// ```rust
122    /// use trie_rs::Trie;
123    ///
124    /// let trie = Trie::from_iter(["a", "app", "apple", "better", "application"]);
125    ///
126    /// let results: Vec<String> = trie.iter().collect();
127    ///
128    /// assert_eq!(results, vec!["a", "app", "apple", "application", "better"]);
129    ///
130    /// ```
131    pub fn iter<C, M>(&self) -> Keys<PostfixIter<'_, Label, (), C, M>>
132    where
133        C: TryFromIterator<Label, M>,
134        Label: Clone,
135    {
136        self.postfix_search([])
137    }
138
139    /// Create an incremental search. Useful for interactive applications. See
140    /// [crate::inc_search] for details.
141    pub fn inc_search(&self) -> IncSearch<'_, Label, ()> {
142        IncSearch::new(&self.0)
143    }
144
145    /// Return true if `query` is a prefix.
146    ///
147    /// Note: A prefix may be an exact match or not, and an exact match may be a
148    /// prefix or not.
149    pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
150        self.0.is_prefix(query)
151    }
152
153    /// Return the longest shared prefix of `query`.
154    pub fn longest_prefix<C, M>(&self, query: impl AsRef<[Label]>) -> Option<C>
155    where
156        C: TryFromIterator<Label, M>,
157        Label: Clone,
158    {
159        self.0.longest_prefix(query)
160    }
161}
162
163impl<Label, C> FromIterator<C> for Trie<Label>
164where
165    C: AsRef<[Label]>,
166    Label: Ord + Clone,
167{
168    fn from_iter<T>(iter: T) -> Self
169    where
170        Self: Sized,
171        T: IntoIterator<Item = C>,
172    {
173        let mut builder = super::TrieBuilder::new();
174        for k in iter {
175            builder.push(k)
176        }
177        builder.build()
178    }
179}
180
#[cfg(test)]
mod search_tests {
    use crate::{Trie, TrieBuilder};
    use std::iter::FromIterator;

    /// Fixture shared by the parameterized suites below: five ASCII keys plus
    /// one multi-byte UTF-8 key, all stored as `u8` labels.
    fn build_trie() -> Trie<u8> {
        let mut builder = TrieBuilder::new();
        for word in ["a", "app", "apple", "better", "application", "アップル🍎"] {
            builder.push(word);
        }
        builder.build()
    }

    #[test]
    fn trie_from_iter() {
        let words = ["a", "app", "apple", "better", "application"];
        let trie = Trie::<u8>::from_iter(words);
        assert!(trie.exact_match("application"));
    }

    #[test]
    fn collect_a_trie() {
        let words = ["a", "app", "apple", "better", "application"];
        let trie: Trie<u8> = IntoIterator::into_iter(words).collect();
        assert!(trie.exact_match("application"));
    }

    #[test]
    fn clone() {
        let original = build_trie();
        let _copy: Trie<u8> = original.clone();
    }

    // Pins the exact internal Debug layout of a one-key trie; any change to
    // the LOUDS representation will (intentionally) break this test.
    #[rustfmt::skip]
    #[test]
    fn print_debug() {
        let trie: Trie<u8> = ["a"].into_iter().collect();
        let expected = "Trie(Trie { louds: Louds { lbs: Fid { byte_vec: [160], bit_len: 5, chunks: Chunks { chunks: [Chunk { value: 2, blocks: Blocks { blocks: [Block { value: 1, length: 1 }, Block { value: 1, length: 1 }, Block { value: 2, length: 1 }, Block { value: 2, length: 1 }], blocks_cnt: 4 } }, Chunk { value: 2, blocks: Blocks { blocks: [Block { value: 0, length: 1 }], blocks_cnt: 1 } }], chunks_cnt: 2 }, table: PopcountTable { bit_length: 1, table: [0, 1] } } }, trie_labels: [TrieLabel { label: 97, value: Some(()) }] })";
        assert_eq!(format!("{:?}", trie), expected);
    }

    // Pins the exact Debug layout of the naive (pre-build) trie for "a"/"app".
    #[rustfmt::skip]
    #[test]
    fn print_debug_builder() {
        let mut builder = TrieBuilder::new();
        builder.push("a");
        builder.push("app");
        let expected = "TrieBuilder(TrieBuilder { naive_trie: Root(NaiveTrieRoot { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [], label: 112, value: Some(()) })], label: 112, value: None })], label: 97, value: Some(()) })] }) })";
        assert_eq!(format!("{:?}", builder), expected);
    }

    #[test]
    fn use_empty_queries() {
        let trie = build_trie();
        assert!(!trie.exact_match(""));
        // Smoke test only: the empty query must not panic; results are discarded.
        let _ = trie.predictive_search::<String, _>("").next();
        let _ = trie.postfix_search::<String, _>("").next();
        let _ = trie.common_prefix_search::<String, _>("").next();
    }

    /// Compares the in-memory size of a small trie against the same keys
    /// stored as a `Vec<String>`.
    ///
    /// ```sh
    /// cargo test --features mem_dbg memsize -- --nocapture
    /// ```
    #[cfg(feature = "mem_dbg")]
    #[test]
    fn memsize() {
        use mem_dbg::*;
        use std::{
            env,
            fs::File,
            io::{BufRead, BufReader},
        };

        const COUNT: usize = 100;

        let manifest_dir = env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR environment variable must be set.");
        let dict_path = format!("{}/benches/edict.furigana", manifest_dir);
        println!("Reading dictionary file from: {}", dict_path);

        let reader = BufReader::new(File::open(dict_path).unwrap());
        let mut builder = TrieBuilder::new();
        let mut word_count = 0;
        let mut byte_count = 0;
        for line in reader.lines().take(COUNT) {
            let word = line.unwrap();
            byte_count += word.len();
            builder.push(word);
            word_count += 1;
        }
        println!("Read {} words, {} bytes.", word_count, byte_count);

        let trie = builder.build();
        let trie_size = trie.mem_size(SizeFlags::default());
        eprintln!("Trie size {trie_size}");
        let uncompressed: Vec<String> = trie.iter().collect();
        let uncompressed_size = uncompressed.mem_size(SizeFlags::default());
        eprintln!("Uncompressed size {}", uncompressed_size);
        // NOTE(review): this asserts the trie is *larger* than the raw bytes
        // read, which the original author flagged as suspicious; kept as-is
        // pending confirmation of the intended bound.
        assert!(byte_count < trie_size);
        assert!(trie_size < uncompressed_size);
    }

    mod exact_match_tests {
        macro_rules! parameterized_tests {
            ($($case_name:ident: $case:expr,)*) => {
                $(
                    #[test]
                    fn $case_name() {
                        let (input, expected) = $case;
                        let actual = super::build_trie().exact_match(input);
                        assert_eq!(actual, expected);
                    }
                )*
            };
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", true),
            t4: ("application", true),
            t5: ("better", true),
            t6: ("アップル🍎", true),
            t7: ("appl", false),
            t8: ("appler", false),
        }
    }

    mod is_prefix_tests {
        macro_rules! parameterized_tests {
            ($($case_name:ident: $case:expr,)*) => {
                $(
                    #[test]
                    fn $case_name() {
                        let (input, expected) = $case;
                        let actual = super::build_trie().is_prefix(input);
                        assert_eq!(actual, expected);
                    }
                )*
            };
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", false),
            t4: ("application", false),
            t5: ("better", false),
            t6: ("アップル🍎", false),
            t7: ("appl", true),
            t8: ("appler", false),
            t9: ("アップル", true),
            t10: ("ed", false),
            t11: ("e", false),
            t12: ("", true),
        }
    }

    mod predictive_search_tests {
        macro_rules! parameterized_tests {
            ($($case_name:ident: $case:expr,)*) => {
                $(
                    #[test]
                    fn $case_name() {
                        let (input, expected) = $case;
                        let fixture = super::build_trie();
                        let found: Vec<String> = fixture.predictive_search(input).collect();
                        assert_eq!(found, expected);
                    }
                )*
            };
        }

        parameterized_tests! {
            t1: ("a", vec!["a", "app", "apple", "application"]),
            t2: ("app", vec!["app", "apple", "application"]),
            t3: ("appl", vec!["apple", "application"]),
            t4: ("apple", vec!["apple"]),
            t5: ("b", vec!["better"]),
            t6: ("c", Vec::<&str>::new()),
            t7: ("アップ", vec!["アップル🍎"]),
        }
    }

    mod common_prefix_search_tests {
        macro_rules! parameterized_tests {
            ($($case_name:ident: $case:expr,)*) => {
                $(
                    #[test]
                    fn $case_name() {
                        let (input, expected) = $case;
                        let fixture = super::build_trie();
                        let found: Vec<String> = fixture.common_prefix_search(input).collect();
                        assert_eq!(found, expected);
                    }
                )*
            };
        }

        parameterized_tests! {
            t1: ("a", vec!["a"]),
            t2: ("ap", vec!["a"]),
            t3: ("appl", vec!["a", "app"]),
            t4: ("appler", vec!["a", "app", "apple"]),
            t5: ("bette", Vec::<&str>::new()),
            t6: ("betterment", vec!["better"]),
            t7: ("c", Vec::<&str>::new()),
            t8: ("アップル🍎🍏", vec!["アップル🍎"]),
        }
    }
}