1use super::Trie;
3use crate::inc_search::IncSearch;
4use crate::iter::{PostfixIter, PrefixIter, SearchIter};
5use crate::try_collect::{TryCollect, TryFromIterator};
6use louds_rs::{AncestorNodeIter, ChildNodeIter, LoudsNodeNum};
7use std::iter::FromIterator;
8
9impl<Label: Ord, Value> Trie<Label, Value> {
10 pub fn exact_match(&self, query: impl AsRef<[Label]>) -> Option<&Value> {
12 self.exact_match_node(query)
13 .and_then(move |x| self.value(x))
14 }
15
16 #[inline]
18 fn exact_match_node(&self, query: impl AsRef<[Label]>) -> Option<LoudsNodeNum> {
19 let mut cur_node_num = LoudsNodeNum(1);
20
21 for (i, chr) in query.as_ref().iter().enumerate() {
22 let children_node_nums: Vec<LoudsNodeNum> =
23 self.children_node_nums(cur_node_num).collect();
24 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
25
26 match res {
27 Ok(j) => {
28 let child_node_num = children_node_nums[j];
29 if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
30 return Some(child_node_num);
31 }
32 cur_node_num = child_node_num;
33 }
34 Err(_) => return None,
35 }
36 }
37 None
38 }
39
40 pub fn exact_match_mut(&mut self, query: impl AsRef<[Label]>) -> Option<&mut Value> {
42 self.exact_match_node(query)
43 .and_then(move |x| self.value_mut(x))
44 }
45
46 pub fn inc_search(&self) -> IncSearch<'_, Label, Value> {
49 IncSearch::new(self)
50 }
51
52 pub fn is_prefix(&self, query: impl AsRef<[Label]>) -> bool {
57 let mut cur_node_num = LoudsNodeNum(1);
58
59 for chr in query.as_ref().iter() {
60 let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
61 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
62 match res {
63 Ok(j) => cur_node_num = children_node_nums[j],
64 Err(_) => return false,
65 }
66 }
67 self.has_children_node_nums(cur_node_num)
69 }
70
71 pub fn predictive_search<C, M>(
73 &self,
74 query: impl AsRef<[Label]>,
75 ) -> SearchIter<'_, Label, Value, C, M>
76 where
77 C: TryFromIterator<Label, M> + Clone,
78 Label: Clone,
79 {
80 SearchIter::new(self, query)
81 }
82
83 pub fn postfix_search<C, M>(
85 &self,
86 query: impl AsRef<[Label]>,
87 ) -> PostfixIter<'_, Label, Value, C, M>
88 where
89 C: TryFromIterator<Label, M>,
90 Label: Clone,
91 {
92 let mut cur_node_num = LoudsNodeNum(1);
93
94 for chr in query.as_ref() {
96 let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
97 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
98 match res {
99 Ok(i) => cur_node_num = children_node_nums[i],
100 Err(_) => {
101 return PostfixIter::empty(self);
102 }
103 }
104 }
105
106 PostfixIter::new(self, cur_node_num)
107 }
108
109 pub fn iter<C, M>(&self) -> PostfixIter<'_, Label, Value, C, M>
123 where
124 C: TryFromIterator<Label, M>,
125 Label: Clone,
126 {
127 self.postfix_search([])
128 }
129
130 pub fn common_prefix_search<C, M>(
132 &self,
133 query: impl AsRef<[Label]>,
134 ) -> PrefixIter<'_, Label, Value, C, M>
135 where
136 C: TryFromIterator<Label, M>,
137 Label: Clone,
138 {
139 PrefixIter::new(self, query)
140 }
141
142 pub fn longest_prefix<C, M>(&self, query: impl AsRef<[Label]>) -> Option<C>
144 where
145 C: TryFromIterator<Label, M>,
146 Label: Clone,
147 {
148 let mut cur_node_num = LoudsNodeNum(1);
149 let mut buffer = Vec::new();
150
151 for chr in query.as_ref() {
153 let children_node_nums: Vec<_> = self.children_node_nums(cur_node_num).collect();
154 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
155 match res {
156 Ok(i) => {
157 cur_node_num = children_node_nums[i];
158 buffer.push(cur_node_num);
159 }
160 Err(_) => {
161 return None;
162 }
163 }
164 }
165
166 while !self.is_terminal(cur_node_num) {
168 let mut iter = self.children_node_nums(cur_node_num);
169 let first = iter.next();
170 let second = iter.next();
171 match (first, second) {
172 (Some(child_node_num), None) => {
173 cur_node_num = child_node_num;
174 buffer.push(child_node_num);
175 }
176 _ => break,
177 }
178 }
179 if buffer.is_empty() {
180 None
181 } else {
182 Some(
183 buffer
184 .into_iter()
185 .map(|x| self.label(x).clone())
186 .try_collect()
187 .expect("Could not collect"),
188 )
189 }
190 }
191
192 pub(crate) fn has_children_node_nums(&self, node_num: LoudsNodeNum) -> bool {
193 self.louds
194 .parent_to_children_indices(node_num)
195 .next()
196 .is_some()
197 }
198
199 pub(crate) fn children_node_nums(&self, node_num: LoudsNodeNum) -> ChildNodeIter {
200 self.louds.parent_to_children_nodes(node_num)
201 }
202
203 pub(crate) fn bin_search_by_children_labels(
204 &self,
205 query: &Label,
206 children_node_nums: &[LoudsNodeNum],
207 ) -> Result<usize, usize> {
208 children_node_nums.binary_search_by(|child_node_num| self.label(*child_node_num).cmp(query))
209 }
210
211 pub(crate) fn label(&self, node_num: LoudsNodeNum) -> &Label {
212 &self.trie_labels[(node_num.0 - 2) as usize].label
213 }
214
215 pub(crate) fn is_terminal(&self, node_num: LoudsNodeNum) -> bool {
216 if node_num.0 >= 2 {
217 self.trie_labels[(node_num.0 - 2) as usize].value.is_some()
218 } else {
219 false
220 }
221 }
222
223 pub(crate) fn value(&self, node_num: LoudsNodeNum) -> Option<&Value> {
224 if node_num.0 >= 2 {
225 self.trie_labels[(node_num.0 - 2) as usize].value.as_ref()
226 } else {
227 None
228 }
229 }
230
231 pub(crate) fn value_mut(&mut self, node_num: LoudsNodeNum) -> Option<&mut Value> {
232 self.trie_labels[(node_num.0 - 2) as usize].value.as_mut()
233 }
234
235 pub(crate) fn child_to_ancestors(&self, node_num: LoudsNodeNum) -> AncestorNodeIter {
236 self.louds.child_to_ancestors(node_num)
237 }
238}
239
240impl<Label, Value, C> FromIterator<(C, Value)> for Trie<Label, Value>
241where
242 C: AsRef<[Label]>,
243 Label: Ord + Clone,
244{
245 fn from_iter<T>(iter: T) -> Self
246 where
247 Self: Sized,
248 T: IntoIterator<Item = (C, Value)>,
249 {
250 let mut builder = super::TrieBuilder::new();
251 for (k, v) in iter {
252 builder.push(k, v)
253 }
254 builder.build()
255 }
256}
257
#[cfg(test)]
mod search_tests {
    use crate::map::{Trie, TrieBuilder};
    use std::iter::FromIterator;

    /// Byte-labelled fixture shared by most tests.
    ///
    /// "アップル🍎" exercises multi-byte UTF-8 keys.
    fn build_trie() -> Trie<u8, u8> {
        let mut builder = TrieBuilder::new();
        builder.push("a", 0);
        builder.push("app", 1);
        builder.push("apple", 2);
        builder.push("better", 3);
        builder.push("application", 4);
        builder.push("アップル🍎", 5);
        builder.build()
    }

    /// Same keys and values as `build_trie`, but labelled by `char` instead of by byte.
    fn build_trie2() -> Trie<char, u8> {
        let mut builder: TrieBuilder<char, u8> = TrieBuilder::new();
        builder.insert("a".chars(), 0);
        builder.insert("app".chars(), 1);
        builder.insert("apple".chars(), 2);
        builder.insert("better".chars(), 3);
        builder.insert("application".chars(), 4);
        builder.insert("アップル🍎".chars(), 5);
        builder.build()
    }

    #[test]
    fn sanity_check() {
        let trie = build_trie();
        let v: Vec<(String, &u8)> = trie.predictive_search("apple").collect();
        assert_eq!(v, vec![("apple".to_string(), &2)]);
    }

    #[test]
    fn clone() {
        let trie = build_trie();
        let _c: Trie<u8, u8> = trie.clone();
    }

    #[test]
    fn value_mut() {
        let mut trie = build_trie();
        assert_eq!(trie.exact_match("apple"), Some(&2));
        let v = trie.exact_match_mut("apple").unwrap();
        *v = 10;
        assert_eq!(trie.exact_match("apple"), Some(&10));
    }

    #[test]
    fn exact_match_mut_missing_key() {
        let mut trie = build_trie();
        // "appl" only reaches an interior node; "" stops at the root; "zzz" leaves the trie.
        assert!(trie.exact_match_mut("appl").is_none());
        assert!(trie.exact_match_mut("").is_none());
        assert!(trie.exact_match_mut("zzz").is_none());
    }

    #[test]
    fn exact_match_chars() {
        let trie = build_trie2();
        let key: Vec<char> = "アップル🍎".chars().collect();
        assert_eq!(trie.exact_match(&key), Some(&5));
        let interior: Vec<char> = "アップル".chars().collect();
        assert_eq!(trie.exact_match(&interior), None);
    }

    #[test]
    fn trie_from_iter() {
        let trie = Trie::<u8, u8>::from_iter([
            ("a", 0),
            ("app", 1),
            ("apple", 2),
            ("better", 3),
            ("application", 4),
        ]);
        assert_eq!(trie.exact_match("application"), Some(&4));
    }

    #[test]
    fn collect_a_trie() {
        let trie: Trie<u8, u8> = vec![
            ("a", 0),
            ("app", 1),
            ("apple", 2),
            ("better", 3),
            ("application", 4),
        ]
        .into_iter()
        .collect();
        assert_eq!(trie.exact_match("application"), Some(&4));
    }

    #[test]
    fn use_empty_queries() {
        let trie = build_trie();
        assert!(trie.exact_match("").is_none());
        // Smoke checks: none of the searches may panic on an empty query.
        let _ = trie.predictive_search::<String, _>("").next();
        let _ = trie.postfix_search::<String, _>("").next();
        let _ = trie.common_prefix_search::<String, _>("").next();
    }

    #[test]
    fn iter_equals_empty_postfix_search() {
        let trie = build_trie();
        let all: Vec<(String, &u8)> = trie.iter().collect();
        let postfix: Vec<(String, &u8)> = trie.postfix_search("").collect();
        assert_eq!(all, postfix);
        assert_eq!(all.len(), 6);
    }

    #[test]
    fn insert_order_dependent() {
        let trie = Trie::from_iter([("a", 0), ("app", 1), ("apple", 2)]);
        let results: Vec<(String, &u8)> = trie.iter().collect();
        assert_eq!(
            results,
            [
                ("a".to_string(), &0u8),
                ("app".to_string(), &1u8),
                ("apple".to_string(), &2u8)
            ]
        );

        let trie = Trie::from_iter([("a", 0), ("apple", 2), ("app", 1)]);
        let results: Vec<(String, &u8)> = trie.iter().collect();
        assert_eq!(
            results,
            [
                ("a".to_string(), &0u8),
                ("app".to_string(), &1u8),
                ("apple".to_string(), &2u8)
            ]
        );
    }

    mod exact_match_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_match) = $value;
                        let trie = super::build_trie();
                        let result = trie.exact_match(query);
                        assert_eq!(result, expected_match);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", Some(&0)),
            t2: ("app", Some(&1)),
            t3: ("apple", Some(&2)),
            t4: ("application", Some(&4)),
            t5: ("better", Some(&3)),
            t6: ("アップル🍎", Some(&5)),
            t7: ("appl", None),
            t8: ("appler", None),
        }
    }

    mod is_prefix_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_match) = $value;
                        let trie = super::build_trie();
                        let result = trie.is_prefix(query);
                        assert_eq!(result, expected_match);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", false),
            t4: ("application", false),
            t5: ("better", false),
            t6: ("アップル🍎", false),
            t7: ("appl", true),
            t8: ("appler", false),
            t9: ("アップル", true),
            // The empty query stops at the root, which has children in a non-empty trie.
            t10: ("", true),
        }
    }

    mod longest_prefix_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_match) = $value;
                        let trie = super::build_trie();
                        let result: Option<String> = trie.longest_prefix(query);
                        let expected_match = expected_match.map(str::to_string);
                        assert_eq!(result, expected_match);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", Some("a")),
            t2: ("ap", Some("app")),
            t3: ("app", Some("app")),
            t4: ("appl", Some("appl")),
            t5: ("appli", Some("application")),
            t6: ("b", Some("better")),
            t7: ("アップル🍎", Some("アップル🍎")),
            t8: ("appler", None),
            t9: ("アップル", Some("アップル🍎")),
            t10: ("z", None),
            t11: ("applesDONTEXIST", None),
            t12: ("", None),
        }
    }

    mod predictive_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie();
                        let results: Vec<(String, &u8)> = trie.predictive_search(query).collect();
                        let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("a", 0), ("app", 1), ("apple", 2), ("application", 4)]),
            t2: ("app", vec![("app", 1), ("apple", 2), ("application", 4)]),
            t3: ("appl", vec![("apple", 2), ("application", 4)]),
            t4: ("apple", vec![("apple", 2)]),
            t5: ("b", vec![("better", 3)]),
            t6: ("c", Vec::<(&str, u8)>::new()),
            t7: ("アップ", vec![("アップル🍎", 5)]),
        }
    }

    mod common_prefix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie();
                        let results: Vec<(String, &u8)> = trie.common_prefix_search(query).collect();
                        let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("a", 0)]),
            t2: ("ap", vec![("a", 0)]),
            t3: ("appl", vec![("a", 0), ("app", 1)]),
            t4: ("appler", vec![("a", 0), ("app", 1), ("apple", 2)]),
            t5: ("bette", Vec::<(&str, u8)>::new()),
            t6: ("betterment", vec![("better", 3)]),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", vec![("アップル🍎", 5)]),
        }
    }

    mod postfix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie();
                        let results: Vec<(String, &u8)> = trie.postfix_search(query).collect();
                        let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("pp", 1), ("pple", 2), ("pplication", 4)]),
            t2: ("ap", vec![("p", 1), ("ple", 2), ("plication", 4)]),
            t3: ("appl", vec![("e", 2), ("ication", 4)]),
            t4: ("appler", Vec::<(&str, u8)>::new()),
            t5: ("bette", vec![("r", 3)]),
            t6: ("betterment", Vec::<(&str, u8)>::new()),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", Vec::<(&str, u8)>::new()),
        }
    }

    mod postfix_search_char_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie2();
                        let chars: Vec<char> = query.chars().collect();
                        let results: Vec<(String, &u8)> = trie.postfix_search(chars).collect();
                        let expected_results: Vec<(String, &u8)> = expected_results.iter().map(|s| (s.0.to_string(), &s.1)).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec![("pp", 1), ("pple", 2), ("pplication", 4)]),
            t2: ("ap", vec![("p", 1), ("ple", 2), ("plication", 4)]),
            t3: ("appl", vec![("e", 2), ("ication", 4)]),
            t4: ("appler", Vec::<(&str, u8)>::new()),
            t5: ("bette", vec![("r", 3)]),
            t6: ("betterment", Vec::<(&str, u8)>::new()),
            t7: ("c", Vec::<(&str, u8)>::new()),
            t8: ("アップル🍎🍏", Vec::<(&str, u8)>::new()),
        }
    }
}