1use crate::iterator::KeyValueRef;
63
64use std::iter::FromIterator;
65
66#[cfg(feature = "serde")]
67use serde_crate::{Deserialize, Serialize};
68
/// A trie whose keys are `String`s, stored one `char` per node, with values of type `V`.
pub type TrieString<V> = Trie<String, char, V>;
73
/// A trie whose keys are `Vec<A>`s, stored one atom `A` per node, with values of type `V`.
pub type TrieVec<A, V> = Trie<Vec<A>, A, V>;
76
/// Marker trait for a single element ("atom") of a trie key.
///
/// It is implemented automatically for every `Copy + Default + PartialEq + Ord` type, such as
/// `char`, `u8`, `usize` or `&str`.
pub trait TrieAtom
where
    Self: Copy + Default + PartialEq + Ord,
{
}

impl<T> TrieAtom for T where T: Copy + Default + PartialEq + Ord {}
89
/// Marker trait for whole trie keys built from atoms of type `A`.
///
/// A key must be rebuildable from its atoms (`FromIterator<A>`) and be totally ordered, so that
/// lookups can hand keys back and sorted iteration is possible. It is implemented automatically
/// for every qualifying type, e.g. `String` (over `char`) or `Vec<A>`.
pub trait TrieKey<A>
where
    Self: Clone + Default + Ord + FromIterator<A>,
{
}

impl<Atom, Key> TrieKey<Atom> for Key where Key: Clone + Default + Ord + FromIterator<Atom> {}
102
/// Marker trait for values stored alongside trie keys; any `Default` type qualifies.
pub trait TrieValue
where
    Self: Default,
{
}

impl<T> TrieValue for T where T: Default {}
115
/// The atom a trie node stands for, plus the value of the key that ends at that node.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub(crate) struct AtomValue<A, V> {
    /// The key element this node represents.
    pub(crate) atom: A,
    /// Value supplied via `Trie::insert_with_value` for the key ending here; `None` when no
    /// value was supplied.
    pub(crate) value: Option<V>,
}
126
/// One node of the trie: its atom/value pair plus the nodes for the following atoms.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub(crate) struct Node<A, V> {
    /// Child nodes, appended in insertion order and searched linearly by atom.
    pub(crate) children: Vec<Node<A, V>>,
    /// The atom this node represents and the value of the key ending here, if any.
    pub(crate) pair: AtomValue<A, V>,
    /// `true` when a stored key ends at this node.
    pub(crate) terminated: bool,
}
138
/// A prefix tree storing keys of type `K`, each a sequence of atoms `A`, with optional values
/// of type `V`.
///
/// Keys are never stored whole; they are rebuilt from the atoms along a path when returned,
/// which is why `K` only appears as `PhantomData`.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct Trie<K, A, V> {
    /// Root node; its own atom is never matched, keys start at its children.
    pub(crate) head: Node<A, V>,
    /// Number of keys currently stored.
    count: usize,
    /// Marker for the key type `K`.
    phantom: std::marker::PhantomData<K>,
    /// Number of atom nodes currently allocated below `head` (reset by `clear`).
    atoms: usize,
}
152
153impl<A: TrieAtom, V: TrieValue> Node<A, V> {
154 fn new(pair: AtomValue<A, V>) -> Self {
155 Self {
156 pair,
157 ..Default::default()
158 }
159 }
160
161 fn terminated(pair: AtomValue<A, V>) -> Self {
162 Self {
163 pair,
164 terminated: true,
165 ..Default::default()
166 }
167 }
168}
169
170impl<K: TrieKey<A>, A: TrieAtom, V: TrieValue> Trie<K, A, V> {
171 pub fn new() -> Self {
173 Self {
174 head: Node::default(),
175 ..Default::default()
176 }
177 }
178
179 pub fn clear(&mut self) {
181 self.head = Node::default();
182 self.count = 0;
183 self.atoms = 0;
184 }
185
186 pub fn contains<I: IntoIterator<Item = A>>(&self, key: I) -> bool {
188 self.contains_internal(key, |n: &Node<A, V>| (n.terminated, None))
189 .0
190 }
191
192 pub fn contains_prefix<P: IntoIterator<Item = A>>(&self, prefix: P) -> bool {
194 self.contains_internal(prefix, |_| (true, None)).0
195 }
196
197 #[inline(always)]
199 pub fn count(&self) -> usize {
200 self.count
201 }
202
203 #[inline(always)]
205 pub fn atoms(&self) -> usize {
206 self.atoms
207 }
208
209 pub fn get<I: IntoIterator<Item = A>>(&self, key: I) -> Option<&V> {
211 self.contains_internal(key, |n: &Node<A, V>| (n.terminated, n.pair.value.as_ref()))
212 .1
213 }
214
215 pub fn get_alternatives<I: Clone + IntoIterator<Item = A>>(
217 &self,
218 key: I,
219 limit: usize,
220 ) -> Vec<K> {
221 if self
223 .contains_internal(key.clone(), |n: &Node<A, V>| (n.terminated, None))
224 .0
225 {
226 vec![K::from_iter(key)]
227 } else {
228 let mut new_key: Vec<A> = vec![];
229
230 let mut atoms = key.into_iter().peekable();
231 while let Some(atom) = atoms.next() {
232 let last_idx = atoms.peek().is_none();
233 if last_idx {
234 break;
235 } else {
236 new_key.push(atom);
237 }
238 }
239 let mut base = vec![];
240
241 let mut node = &self.head;
242
243 for atom in new_key.into_iter() {
244 match node.children.iter().find(|x| x.pair.atom == atom) {
245 Some(n) => {
246 base.push(n.pair.atom);
247 node = n;
248 }
249 None => {
250 break;
251 }
252 }
253 }
254
255 let mut alternatives = vec![];
257
258 'outer: loop {
260 for mut child in node.children.iter().take(limit) {
261 let mut alternative = base.clone();
263 while !child.terminated && !child.children.is_empty() {
264 alternative.push(child.pair.atom);
265 child = &child.children[0];
266 }
267 alternative.push(child.pair.atom);
268
269 let candidate = K::from_iter(alternative);
271 if !alternatives.contains(&candidate) {
272 alternatives.push(candidate);
273 if alternatives.len() == limit {
275 break 'outer;
276 }
277 }
278 }
279 if node.children.is_empty() {
282 break;
283 } else {
284 node = &node.children[0];
285 base.push(node.pair.atom);
286 }
287 }
288 alternatives
289 }
290 }
291
292 pub fn get_lcps(&self) -> Vec<K> {
297 let mut result = vec![];
299 for node in self.head.children.iter() {
300 let mut lcp: Vec<A> = vec![];
301 let mut current_node = node;
302 while current_node.children.len() == 1 && !current_node.terminated {
303 lcp.push(current_node.pair.atom);
304 current_node = current_node.children.get(0).unwrap();
305 }
306 lcp.push(current_node.pair.atom);
307 result.push(lcp.into_iter().collect());
308 }
309 result
310 }
311
312 pub fn get_sup<I: IntoIterator<Item = A>>(&self, key: I) -> Option<K> {
318 let mut node = &self.head;
319
320 let mut nodes = vec![];
321 let mut master = vec![];
322
323 for atom in key {
324 match node.children.iter().find(|x| x.pair.atom == atom) {
325 Some(n) => {
326 nodes.push((n, node.children.len()));
327 master.push(n.pair.atom);
328 node = n;
329 }
330 None => {
331 return None;
333 }
334 }
335 }
336
337 let mut remove = 0;
341 nodes.reverse();
342
343 for node in nodes {
344 if node.1 == 1 {
345 remove += 1;
346 } else {
347 master.truncate(master.len() - remove);
348 return Some(master.into_iter().collect());
349 }
350 }
351 None
352 }
353
354 pub fn insert<I: IntoIterator<Item = A>>(&mut self, key: I) -> Option<V> {
358 self.insert_with_value(key, None)
359 }
360
361 pub fn insert_with_value<I: IntoIterator<Item = A>>(
365 &mut self,
366 key: I,
367 value: Option<V>,
368 ) -> Option<V> {
369 let mut node = &mut self.head;
370 let mut atoms = key.into_iter().peekable();
371 let mut result = None;
372
373 while let Some(atom) = atoms.next() {
374 let last_idx = atoms.peek().is_none();
375
376 let node_index = match node
377 .children
378 .iter_mut()
379 .enumerate()
380 .find(|(_i, x)| x.pair.atom == atom)
381 {
382 Some((i, mut n)) => {
383 if last_idx {
384 if !n.terminated {
385 self.count += 1;
386 }
387 result = n.pair.value.take();
388 n.pair.value = value;
389 n.terminated = true;
390 break;
391 }
392 i
393 }
394 None => {
395 if last_idx {
396 self.count += 1;
397 let new_node = Node::terminated(AtomValue { atom, value });
398 self.atoms += 1;
399 node.children.push(new_node);
400 break;
401 } else {
402 let new_node = Node::new(AtomValue { atom, value: None });
403 self.atoms += 1;
404 node.children.push(new_node);
405 };
406 node.children.len() - 1
407 }
408 };
409 node = node.children.get_mut(node_index).unwrap();
411 }
412 result
413 }
414
415 pub fn is_empty(&self) -> bool {
417 self.head.children.is_empty()
418 }
419
420 pub fn iter(&self) -> impl Iterator<Item = KeyValueRef<'_, K, A, V>> {
422 self.into_iter()
423 }
424
425 pub fn iter_sorted(&self) -> impl Iterator<Item = KeyValueRef<'_, K, A, V>> {
427 let mut v = self.into_iter().collect::<Vec<KeyValueRef<'_, K, A, V>>>();
428 v.sort_by_key(|x| x.key.clone());
429 v.into_iter()
430 }
431
432 pub fn remove<I: IntoIterator<Item = A>>(&mut self, key: I) -> Option<V> {
436 let closure = |mut n: &mut Node<A, V>| {
437 let present = n.terminated;
438 n.terminated = false;
439 (present, n.pair.value.take())
440 };
441 let result = self.contains_internal_mut(key, closure);
442 if result.0 {
443 self.count -= 1;
444 }
445 result.1
446 }
447
448 fn contains_internal<F: Fn(&Node<A, V>) -> (bool, Option<&V>), I: IntoIterator<Item = A>>(
449 &self,
450 key: I,
451 f: F,
452 ) -> (bool, Option<&V>) {
453 let mut node = &self.head;
454 let mut atoms = key.into_iter().peekable();
455 while let Some(atom) = atoms.next() {
456 let last_idx = atoms.peek().is_none();
457
458 match node.children.iter().find(|x| x.pair.atom == atom) {
459 Some(n) => {
460 if last_idx {
461 return f(n);
462 }
463 node = n;
464 }
465 None => {
466 break;
467 }
468 }
469 }
470 (false, None)
471 }
472
473 fn contains_internal_mut<
474 F: Fn(&mut Node<A, V>) -> (bool, Option<V>),
475 I: IntoIterator<Item = A>,
476 >(
477 &mut self,
478 key: I,
479 f: F,
480 ) -> (bool, Option<V>) {
481 let mut node = &mut self.head;
482 let mut atoms = key.into_iter().peekable();
483 while let Some(atom) = atoms.next() {
484 let last_idx = atoms.peek().is_none();
485
486 match node.children.iter_mut().find(|x| x.pair.atom == atom) {
487 Some(n) => {
488 if last_idx {
489 return f(n);
490 }
491 node = n;
492 }
493 None => {
494 break;
495 }
496 }
497 }
498 (false, None)
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use unicode_segmentation::UnicodeSegmentation;

    #[test]
    fn it_inserts_new_key() {
        let mut trie = TrieString::<usize>::new();
        // A fresh key has no previous value, counts once and allocates one node per char.
        assert_eq!(trie.insert("abcdef".chars()), None);
        assert_eq!(trie.count(), 1);
        assert_eq!(trie.atoms(), 6);
        assert!(trie.contains("abcdef".chars()));
    }

    #[test]
    fn it_finds_exact_key() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert(input.clone());
        assert!(trie.contains(input));
    }

    #[test]
    fn it_cannot_find_longer_key() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        let long_input = "abcdefg".chars();
        trie.insert(input);
        assert!(!trie.contains(long_input));
    }

    #[test]
    fn it_cannot_find_shorter_key() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        let short_input = "abcde".chars();
        trie.insert(input);
        assert!(!trie.contains(short_input));
    }

    #[test]
    fn it_can_find_multiple_overlapping_keys() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert(input.clone());
        let short_input = "abc".chars();
        trie.insert(short_input.clone());
        assert!(trie.contains(short_input));
        assert!(trie.contains(input));
    }

    #[test]
    fn it_can_find_prefix_keys() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        let short_input = "abc".chars();
        trie.insert(input);
        assert!(trie.contains_prefix(short_input));
    }

    #[test]
    fn it_can_remove_a_present_key() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert(input.clone());
        assert!(trie.contains(input.clone()));
        assert!(trie.remove(input.clone()).is_none());
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_can_remove_a_missing_key() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        assert!(trie.remove(input.clone()).is_none());
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_can_return_previously_inserted_value() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert_with_value(input.clone(), Some(666));
        assert_eq!(trie.insert_with_value(input.clone(), Some(667)), Some(666));
        assert_eq!(trie.remove(input.clone()), Some(667));
        assert_eq!(trie.remove(input.clone()), None);
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_can_create_an_empty_trie() {
        let trie = TrieString::<usize>::new();
        assert!(trie.is_empty());
    }

    #[test]
    fn it_can_clear_a_trie() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert(input.clone());
        trie.clear();
        assert!(trie.is_empty());
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_can_count_entries() {
        let mut trie = TrieString::<usize>::new();
        let input = "abcdef".chars();
        trie.insert(input.clone());
        assert_eq!(1, trie.count());
        // Re-inserting an existing key must not change the count.
        trie.insert(input.clone());
        trie.insert(input.clone());
        assert_eq!(1, trie.count());
        trie.remove(input.clone());
        assert_eq!(0, trie.count());
        trie.clear();
        assert_eq!(0, trie.count());
        assert!(trie.is_empty());
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_inserts_new_usize_key() {
        let mut trie = TrieVec::<usize, usize>::new();
        let input: Vec<usize> = vec![0, 1, 2, 3, 4, 5, 6];
        assert!(trie.insert(input.clone()).is_none());
        assert_eq!(trie.count(), 1);
        assert_eq!(trie.atoms(), 7);
        assert!(trie.contains(input));
    }

    #[test]
    fn it_finds_exact_usize_key() {
        let mut trie = TrieVec::<usize, usize>::new();
        let input = [0, 1, 2, 3, 4, 5, 6];
        trie.insert(input);
        assert!(trie.contains(input));
    }

    #[test]
    fn it_cannot_find_short_usize_key() {
        let mut trie = TrieVec::<usize, usize>::new();
        let input = [0, 1, 2, 3, 4, 5, 6];
        let input_short = [0, 1, 2, 3, 4, 5];
        trie.insert(input);
        assert!(!trie.contains(input_short));
    }

    #[test]
    fn it_can_process_grapheme_clusters() {
        let mut trie = TrieVec::<&str, bool>::new();
        let s = "a̐éö̲\r\n";
        let input = s.graphemes(true);
        trie.insert(input.clone());
        assert!(trie.contains(input.clone()));
        assert!(trie.remove(input.clone()).is_none());
        assert!(!trie.contains(input));
    }

    #[test]
    fn it_can_process_str_clusters() {
        let mut trie = TrieVec::<&str, usize>::new();
        let input = "the quick brown fox".split_whitespace();
        trie.insert_with_value(input.clone(), Some(5));
        assert_eq!(trie.get(input.clone()), Some(&5));
        assert!(trie.contains(input.clone()));
        assert!(trie.remove(input.clone()).is_some());
        assert!(!trie.contains(input));
    }

    // The Serialize/Deserialize impls only exist with the `serde` feature enabled.
    #[cfg(feature = "serde")]
    #[test]
    fn it_serializes_trie_to_json() {
        let mut t1 = TrieVec::<usize, usize>::new();
        let input = [0, 1, 2, 3, 4, 5, 6];
        t1.insert(input);
        let t_str = serde_json::to_string(&t1).expect("serializing");
        let t2: TrieVec<usize, usize> = serde_json::from_str(&t_str).expect("deserializing");
        assert_eq!(t1, t2);
    }

    #[test]
    fn it_can_find_lcp() {
        let input = vec![
            "code", "coder", "coding", "codable", "codec", "codecs", "coded", "codeless",
            "codependence", "codependency", "codependent", "codependents", "codes", "a",
            "codesign", "codesigned", "codeveloped", "codeveloper", "abc", "codex", "codify",
            "codiscovered", "codrive", "abz",
        ];
        let mut trie = TrieString::<()>::new();
        for entry in input {
            trie.insert(entry.chars());
        }
        assert_eq!(vec!["cod", "a"], trie.get_lcps());
    }

    #[test]
    fn it_can_find_lcp_usize() {
        let input = vec![
            vec![1, 11, 111, 1111],
            vec![1, 11, 111, 1111, 11112],
            vec![1, 11, 111, 1111, 11113],
        ];
        let mut trie = TrieVec::<usize, ()>::new();
        for entry in input {
            trie.insert(entry);
        }
        assert_eq!(vec![vec![1, 11, 111, 1111]], trie.get_lcps());
    }

    #[test]
    fn it_can_find_sups_that_exist() {
        let input = vec!["AND", "BONFIRE", "BOOL", "CASE", "CATCH", "CHAR"];
        let output = vec!["A", "BON", "BOO", "CAS", "CAT", "CH"];
        let mut trie = TrieString::<()>::new();

        for entry in input.clone() {
            trie.insert(entry.chars());
        }

        for (inn, out) in input.into_iter().zip(output.into_iter()) {
            assert_eq!(trie.get_sup(inn.to_string().chars()), Some(out.to_string()));
        }
    }

    #[test]
    fn it_cannot_find_sups_that_have_prefixes() {
        let base = vec!["AND", "BONFIRE", "BOOL", "CASE", "CATCH", "CHAR"];
        let input = vec!["ANDY", "BONFIREY", "BOOLY", "CASEY", "CATCHY", "CHARY"];
        let output = vec![None; 6];
        let mut trie = TrieString::<()>::new();

        for entry in base {
            trie.insert(entry.chars());
        }

        for (inn, out) in input.into_iter().zip(output.into_iter()) {
            assert_eq!(trie.get_sup(inn.to_string().chars()), out);
        }
    }

    #[test]
    fn it_cannot_find_sups_that_are_just_wrong() {
        let base = vec!["AND", "BONFIRE", "BOOL", "CASE", "CATCH", "CHAR"];
        let input = vec!["WHAT", "IS", "THIS", "TEST", "ALL", "ABOUT"];
        let output = vec![None; 6];
        let mut trie = TrieString::<()>::new();

        for entry in base {
            trie.insert(entry.chars());
        }

        for (inn, out) in input.into_iter().zip(output.into_iter()) {
            assert_eq!(trie.get_sup(inn.to_string().chars()), out);
        }
    }

    #[test]
    fn it_can_iter_sorted() {
        let mut input = vec![
            "lexicographic", "sorting", "of", "a", "set", "of", "keys", "can", "be",
            "accomplished", "with", "a", "simple", "trie", "based", "algorithm", "we", "insert",
            "all", "keys", "in", "a", "trie", "output", "all", "keys", "in", "the", "trie", "by",
            "means", "of", "preorder", "traversal", "which", "results", "in", "output", "that",
            "is", "in", "lexicographically", "increasing", "order", "preorder", "traversal", "is",
            "a", "kind", "of", "depth", "first", "traversal",
        ];
        let mut trie = TrieString::<()>::new();
        for entry in &input {
            trie.insert(entry.chars());
        }
        let sorted_words: Vec<String> = trie.iter_sorted().map(|x| x.key).collect();
        // The trie stores each word once, so compare against the sorted, de-duplicated input.
        input.sort();
        input.dedup();
        assert_eq!(input, sorted_words);
    }

    #[test]
    fn it_can_find_maximum_occurring_entry() {
        let input = vec![
            "code", "coder", "coding", "codable", "codec", "codecs", "coded", "codeless", "codec",
            "codecs", "codependence", "codex", "codify", "codependents", "codes", "code", "coder",
            "codesign", "codec", "codeveloper", "codrive", "codec", "codecs", "codiscovered",
        ];
        let mut trie = TrieString::<usize>::new();
        for entry in input {
            let ch = entry.chars();
            let value = match trie.get(ch.clone()) {
                Some(v) => v + 1,
                None => 1,
            };
            trie.insert_with_value(ch, Some(value));
        }
        let mut answer = None;
        let mut highest = 0;
        for entry in trie.iter() {
            if let Some(&v) = entry.value {
                if v > highest {
                    highest = v;
                    answer = Some(entry.key.clone());
                }
            }
        }
        assert_eq!(highest, 4);
        assert_eq!(answer, Some("codec".to_string()));
    }

    #[test]
    fn it_can_find_alternatives() {
        let input = vec![
            "code", "coder", "coding", "codable", "codec", "codecs", "coded", "codeless", "codec",
            "codecs", "codependence", "codex", "codify", "codependents", "codes", "code", "coder",
            "codesign", "codec", "codeveloper", "codrive", "codec", "codecs", "codiscovered",
        ];
        let mut trie = TrieString::<()>::new();
        for entry in input {
            let ch = entry.chars();
            trie.insert(ch);
        }
        assert_eq!(
            trie.get_alternatives("codg".chars(), 5),
            ["code", "coding", "codable", "codrive", "coder"]
        )
    }
}