1use crate::inc_search::IncSearch;
2use crate::iter::{Keys, KeysExt, PostfixIter, PrefixIter, SearchIter};
3use crate::map;
4use crate::try_collect::TryFromIterator;
5use std::iter::FromIterator;
6
7#[cfg(feature = "mem_dbg")]
8use mem_dbg::MemDbg;
9
/// A set of label sequences (keys) stored in a trie.
///
/// This is a newtype over [`map::Trie`] whose value type is `()`: every query
/// only reports *which* keys are present, never an associated value. The
/// inner map trie is public so callers can reach its full API when needed.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "mem_dbg", derive(mem_dbg::MemDbg, mem_dbg::MemSize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Trie<Label>(pub map::Trie<Label, ()>);
15
16impl<Label: Ord> Trie<Label> {
17 pub fn exact_match(&self, query: impl AsRef<[Label]>) -> bool {
36 self.0.exact_match(query).is_some()
37 }
38
39 pub fn common_prefix_search<C, M>(
58 &self,
59 query: impl AsRef<[Label]>,
60 ) -> Keys<PrefixIter<'_, Label, (), C, M>>
61 where
62 C: TryFromIterator<Label, M>,
63 Label: Clone,
64 {
65 self.0.common_prefix_search(query).keys()
67 }
68
69 pub fn predictive_search<C, M>(
71 &self,
72 query: impl AsRef<[Label]>,
73 ) -> Keys<SearchIter<'_, Label, (), C, M>>
74 where
75 C: TryFromIterator<Label, M> + Clone,
76 Label: Clone,
77 {
78 self.0.predictive_search(query).keys()
79 }
80
81 pub fn postfix_search<C, M>(
104 &self,
105 query: impl AsRef<[Label]>,
106 ) -> Keys<PostfixIter<'_, Label, (), C, M>>
107 where
108 C: TryFromIterator<Label, M>,
109 Label: Clone,
110 {
111 self.0.postfix_search(query).keys()
112 }
113
114 pub fn iter<C, M>(&self) -> Keys<PostfixIter<'_, Label, (), C, M>>
132 where
133 C: TryFromIterator<Label, M>,
134 Label: Clone,
135 {
136 self.postfix_search([])
137 }
138
139 pub fn inc_search(&self) -> IncSearch<'_, Label, ()> {
142 IncSearch::new(&self.0)
143 }
144
145 pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
150 self.0.is_prefix(query)
151 }
152
153 pub fn longest_prefix<C, M>(&self, query: impl AsRef<[Label]>) -> Option<C>
155 where
156 C: TryFromIterator<Label, M>,
157 Label: Clone,
158 {
159 self.0.longest_prefix(query)
160 }
161}
162
163impl<Label, C> FromIterator<C> for Trie<Label>
164where
165 C: AsRef<[Label]>,
166 Label: Ord + Clone,
167{
168 fn from_iter<T>(iter: T) -> Self
169 where
170 Self: Sized,
171 T: IntoIterator<Item = C>,
172 {
173 let mut builder = super::TrieBuilder::new();
174 for k in iter {
175 builder.push(k)
176 }
177 builder.build()
178 }
179}
180
// Unit tests for the set-style `Trie`, using `u8` labels built from UTF-8
// string keys.
#[cfg(test)]
mod search_tests {
    use crate::{Trie, TrieBuilder};
    use std::iter::FromIterator;

    // Shared fixture: six keys with overlapping prefixes ("a", "app",
    // "apple", "application"), an unrelated key ("better"), and a multi-byte
    // UTF-8 key ("アップル🍎").
    fn build_trie() -> Trie<u8> {
        let mut builder = TrieBuilder::new();
        builder.push("a");
        builder.push("app");
        builder.push("apple");
        builder.push("better");
        builder.push("application");
        builder.push("アップル🍎");
        builder.build()
    }

    // `Trie::from_iter` accepts an array of `&str` keys directly.
    #[test]
    fn trie_from_iter() {
        let trie = Trie::<u8>::from_iter(["a", "app", "apple", "better", "application"]);
        assert!(trie.exact_match("application"));
    }

    // `Iterator::collect` into a `Trie` goes through the same `FromIterator` impl.
    #[test]
    fn collect_a_trie() {
        let trie: Trie<u8> =
            IntoIterator::into_iter(["a", "app", "apple", "better", "application"]).collect();
        assert!(trie.exact_match("application"));
    }

    // Compile-level check that `Trie` implements `Clone`.
    #[test]
    fn clone() {
        let trie = build_trie();
        let _c: Trie<u8> = trie.clone();
    }

    // Pins the exact derived `Debug` output of a one-key trie; any change to the
    // internal layout or derives will fail this test. `rustfmt::skip` keeps the
    // long expected string unwrapped at column 0.
    #[rustfmt::skip]
    #[test]
    fn print_debug() {
        let trie: Trie<u8> = ["a"].into_iter().collect();
        assert_eq!(format!("{:?}", trie),
"Trie(Trie { louds: Louds { lbs: Fid { byte_vec: [160], bit_len: 5, chunks: Chunks { chunks: [Chunk { value: 2, blocks: Blocks { blocks: [Block { value: 1, length: 1 }, Block { value: 1, length: 1 }, Block { value: 2, length: 1 }, Block { value: 2, length: 1 }], blocks_cnt: 4 } }, Chunk { value: 2, blocks: Blocks { blocks: [Block { value: 0, length: 1 }], blocks_cnt: 1 } }], chunks_cnt: 2 }, table: PopcountTable { bit_length: 1, table: [0, 1] } } }, trie_labels: [TrieLabel { label: 97, value: Some(()) }] })"
        );
    }

    // Pins the exact derived `Debug` output of a builder holding "a" and "app".
    #[rustfmt::skip]
    #[test]
    fn print_debug_builder() {

        let mut builder = TrieBuilder::new();
        builder.push("a");
        builder.push("app");
        assert_eq!(format!("{:?}", builder),
"TrieBuilder(TrieBuilder { naive_trie: Root(NaiveTrieRoot { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [IntermOrLeaf(NaiveTrieIntermOrLeaf { children: [], label: 112, value: Some(()) })], label: 112, value: None })], label: 97, value: Some(()) })] }) })"
        );
    }

    // Empty queries must not panic. The empty string is not itself a stored
    // key, so `exact_match("")` is false.
    #[test]
    fn use_empty_queries() {
        let trie = build_trie();
        assert!(!trie.exact_match(""));
        let _ = trie.predictive_search::<String, _>("").next();
        let _ = trie.postfix_search::<String, _>("").next();
        let _ = trie.common_prefix_search::<String, _>("").next();
    }

    // Compares the trie's measured memory size against the raw input byte count
    // and against a `Vec<String>` holding the same keys. Only compiled with the
    // `mem_dbg` feature, e.g.
    // `cargo test --features mem_dbg memsize -- --nocapture` to see the output.
    // Reads the first COUNT lines of `benches/edict.furigana` under
    // CARGO_MANIFEST_DIR.
    #[cfg(feature = "mem_dbg")]
    #[test]
    fn memsize() {
        use mem_dbg::*;
        use std::{
            env,
            fs::File,
            io::{BufRead, BufReader},
        };

        const COUNT: usize = 100;
        let mut builder = TrieBuilder::new();

        let repo_root = env::var("CARGO_MANIFEST_DIR")
            .expect("CARGO_MANIFEST_DIR environment variable must be set.");
        let edict2_path = format!("{}/benches/edict.furigana", repo_root);
        println!("Reading dictionary file from: {}", edict2_path);

        let mut n_words = 0;
        // Total number of bytes across all keys read.
        let mut accum = 0;
        for result in BufReader::new(File::open(edict2_path).unwrap())
            .lines()
            .take(COUNT)
        {
            let l = result.unwrap();
            accum += l.len();
            builder.push(l);
            n_words += 1;
        }
        println!("Read {} words, {} bytes.", n_words, accum);

        let trie = builder.build();
        let trie_size = trie.mem_size(SizeFlags::default());
        eprintln!("Trie size {trie_size}");
        let uncompressed: Vec<String> = trie.iter().collect();
        let uncompressed_size = uncompressed.mem_size(SizeFlags::default());
        eprintln!("Uncompressed size {}", uncompressed_size);
        // NOTE(review): the first assertion requires the trie to be *larger*
        // than the raw key bytes, which is counter-intuitive. It holds for this
        // small sample, but confirm it is the intended invariant.
        assert!(accum < trie_size); assert!(trie_size < uncompressed_size);
    }

    // `exact_match`: true only for complete stored keys, not for prefixes or
    // extensions of them.
    mod exact_match_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_match) = $value;
                    let trie = super::build_trie();
                    let result = trie.exact_match(query);
                    assert_eq!(result, expected_match);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", true),
            t4: ("application", true),
            t5: ("better", true),
            t6: ("アップル🍎", true),
            t7: ("appl", false),
            t8: ("appler", false),
        }
    }

    // `is_prefix`: true only when some stored key strictly extends the query.
    // Full keys with no longer extension ("apple", "better") are false; the
    // empty query is true.
    mod is_prefix_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_match) = $value;
                    let trie = super::build_trie();
                    let result = trie.is_prefix(query);
                    assert_eq!(result, expected_match);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", false),
            t4: ("application", false),
            t5: ("better", false),
            t6: ("アップル🍎", false),
            t7: ("appl", true),
            t8: ("appler", false),
            t9: ("アップル", true),
            t10: ("ed", false),
            t11: ("e", false),
            t12: ("", true),
        }
    }

    // `predictive_search`: every stored key that starts with the query, in the
    // order listed here.
    mod predictive_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie();
                    let results: Vec<String> = trie.predictive_search(query).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec!["a", "app", "apple", "application"]),
            t2: ("app", vec!["app", "apple", "application"]),
            t3: ("appl", vec!["apple", "application"]),
            t4: ("apple", vec!["apple"]),
            t5: ("b", vec!["better"]),
            t6: ("c", Vec::<&str>::new()),
            t7: ("アップ", vec!["アップル🍎"]),
        }
    }

    // `common_prefix_search`: every stored key that is a prefix of the query,
    // shortest first.
    mod common_prefix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected_results) = $value;
                    let trie = super::build_trie();
                    let results: Vec<String> = trie.common_prefix_search(query).collect();
                    assert_eq!(results, expected_results);
                }
            )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec!["a"]),
            t2: ("ap", vec!["a"]),
            t3: ("appl", vec!["a", "app"]),
            t4: ("appler", vec!["a", "app", "apple"]),
            t5: ("bette", Vec::<&str>::new()),
            t6: ("betterment", vec!["better"]),
            t7: ("c", Vec::<&str>::new()),
            t8: ("アップル🍎🍏", vec!["アップル🍎"]),
        }
    }
}