1use std::{collections::HashMap, fmt::Debug, iter, str};
2
3use regex::bytes::Regex;
4
/// Raw key bytes exactly as given to [`TrieNode::insert`], e.g. `b"*.example.com"`.
pub type Key = Vec<u8>;
/// A stored entry: the full original key paired with its value.
pub type KeyValue<K, V> = (K, V);
7
/// Outcome of [`TrieNode::insert`].
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
    /// The value was stored under a key that was not present before.
    Ok,
    /// An entry already exists for this key; the trie is left unchanged.
    Existing,
    /// The key was rejected (for instance the empty key, `.`, or a malformed regex segment).
    Failed,
}
14
/// Outcome of [`TrieNode::remove`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
    /// A stored value was removed.
    Ok,
    /// No stored value matched the key; nothing was removed.
    NotFound,
}
20
/// Index of the right-most `.` byte in `input`, or `None` when `input` has no dot.
///
/// The trie consumes host names label by label from the right, so this is the split
/// point between the remaining prefix and the label handled at the current level.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
25
/// Index of the right-most `/` byte in `input`, or `None` when `input` has no slash.
///
/// Used to locate the opening delimiter of a `/<pattern>/` regex segment.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
30
/// One level of the domain trie.
///
/// Keys are split on `.` and consumed from the right, one label per level.
#[derive(Debug, Default)]
pub struct TrieNode<V> {
    /// Entry whose key ends exactly at this node, as `(full key, value)`.
    key_value: Option<KeyValue<Key, V>>,
    /// Entry registered with a `*` label at this level, as `(full key, value)`.
    wildcard: Option<KeyValue<Key, V>>,
    /// Sub-tries indexed by label. Inner labels keep their leading `.` (e.g. `.com`),
    /// while the left-most label of a key is stored without it (e.g. `www`).
    children: HashMap<Key, TrieNode<V>>,
    /// Regex segments (`/<pattern>/` labels) with their sub-tries, tried in insertion
    /// order; each pattern is stored anchored as `\A<pattern>\z`.
    regexps: Vec<(Regex, TrieNode<V>)>,
}
44
/// Records how one label was matched during [`TrieNode::lookup_with_path`].
#[derive(Debug)]
pub enum TrieSubMatch<'a, 'b> {
    /// The label (without its leading `.`) was matched by a `*` wildcard.
    Wildcard(&'a [u8]),
    /// The label (without its leading `.`) was matched by the given regex segment.
    Regexp(&'a [u8], &'b Regex),
}
56
/// Labels matched by wildcards or regex segments during a lookup, right-most label first.
pub type TrieMatches<'a, 'b> = Vec<TrieSubMatch<'a, 'b>>;
62
63impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
64 fn eq(&self, other: &Self) -> bool {
65 self.key_value == other.key_value
66 && self.wildcard == other.wildcard
67 && self.children == other.children
68 && self.regexps.len() == other.regexps.len()
69 && self
70 .regexps
71 .iter()
72 .zip(other.regexps.iter())
73 .fold(true, |b, (left, right)| {
74 b && left.0.as_str() == right.0.as_str() && left.1 == right.1
75 })
76 }
77}
78
79impl<V: Debug + Clone> TrieNode<V> {
80 pub fn new(key: Key, value: V) -> TrieNode<V> {
81 TrieNode {
82 key_value: Some((key, value)),
83 wildcard: None,
84 children: HashMap::new(),
85 regexps: Vec::new(),
86 }
87 }
88
89 pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
90 TrieNode {
91 key_value: None,
92 wildcard: Some((key, value)),
93 children: HashMap::new(),
94 regexps: Vec::new(),
95 }
96 }
97
98 pub fn root() -> TrieNode<V> {
99 TrieNode {
100 key_value: None,
101 wildcard: None,
102 children: HashMap::new(),
103 regexps: Vec::new(),
104 }
105 }
106
107 pub fn is_empty(&self) -> bool {
108 self.key_value.is_none()
109 && self.wildcard.is_none()
110 && self.regexps.is_empty()
111 && self.children.is_empty()
112 }
113
114 pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
115 if key.is_empty() {
117 return InsertResult::Failed;
118 }
119 if key[..] == b"."[..] {
120 return InsertResult::Failed;
121 }
122
123 #[cfg(debug_assertions)]
124 let before = self.count_values();
125
126 let insert_result = self.insert_recursive(&key, &key, value);
127 assert_ne!(insert_result, InsertResult::Failed);
128
129 #[cfg(debug_assertions)]
133 {
134 let after = self.count_values();
135 match insert_result {
136 InsertResult::Ok => debug_assert_eq!(
137 after,
138 before + 1,
139 "a fresh insert must add exactly one value to the trie",
140 ),
141 InsertResult::Existing => debug_assert_eq!(
142 after, before,
143 "an Existing insert must not change the trie value count",
144 ),
145 InsertResult::Failed => {
146 unreachable!("insert_recursive returned Failed after key validation")
147 }
148 }
149 self.check_invariants();
150 }
151
152 insert_result
153 }
154
155 pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
156 assert_ne!(partial_key, &b""[..]);
158 debug_assert!(
162 partial_key.len() <= key.len(),
163 "insert recursion must consume the key, never grow past it",
164 );
165
166 if partial_key[partial_key.len() - 1] == b'/' {
167 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
168
169 if let Some(pos) = pos {
170 if pos > 0 && partial_key[pos - 1] != b'.' {
171 return InsertResult::Failed;
172 }
173
174 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
175 let anchored_s = format!("\\A{s}\\z");
176 debug_assert!(
177 anchored_s.starts_with("\\A") && anchored_s.ends_with("\\z"),
178 "segment regex must be fully anchored so it matches the whole segment only",
179 );
180 for t in self.regexps.iter_mut() {
181 if t.0.as_str() == anchored_s {
182 if pos > 0 {
194 return t.1.insert_recursive(&partial_key[..pos - 1], key, value);
195 } else {
196 return InsertResult::Existing;
197 }
198 }
199 }
200
201 let anchored = format!("\\A{s}\\z");
206 if let Ok(r) = Regex::new(&anchored) {
207 if pos > 0 {
208 let mut node = TrieNode::root();
209 let pos = pos - 1;
210
211 let res = node.insert_recursive(&partial_key[..pos], key, value);
212
213 if res == InsertResult::Ok {
214 self.regexps.push((r, node));
215 }
216
217 return res;
218 } else {
219 let node = TrieNode::new(key.to_vec(), value);
220 self.regexps.push((r, node));
221 return InsertResult::Ok;
222 }
223 }
224 }
225 }
226
227 return InsertResult::Failed;
228 }
229
230 let pos = find_last_dot(partial_key);
231 match pos {
232 None => {
233 if self.children.contains_key(partial_key) {
234 InsertResult::Existing
235 } else if partial_key == &b"*"[..] {
236 if self.wildcard.is_some() {
237 InsertResult::Existing
238 } else {
239 self.wildcard = Some((key.to_vec(), value));
240 InsertResult::Ok
241 }
242 } else {
243 let node = TrieNode::new(key.to_vec(), value);
244 self.children.insert(partial_key.to_vec(), node);
245 InsertResult::Ok
246 }
247 }
248 Some(pos) => {
249 debug_assert_eq!(
253 partial_key[..pos].len() + partial_key[pos..].len(),
254 partial_key.len(),
255 "dot-split must partition partial_key without losing bytes",
256 );
257 debug_assert_eq!(
258 partial_key[pos], b'.',
259 "find_last_dot must point at a '.' byte",
260 );
261 if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
262 return child.insert_recursive(&partial_key[..pos], key, value);
263 }
264
265 let mut node = TrieNode::root();
266 let res = node.insert_recursive(&partial_key[..pos], key, value);
267
268 if res == InsertResult::Ok {
269 self.children.insert(partial_key[pos..].to_vec(), node);
270 }
271
272 res
273 }
274 }
275 }
276
277 pub fn remove(&mut self, key: &Key) -> RemoveResult {
278 #[cfg(debug_assertions)]
279 let before = self.count_values();
280
281 let remove_result = self.remove_recursive(key);
282
283 #[cfg(debug_assertions)]
287 {
288 let after = self.count_values();
289 match remove_result {
290 RemoveResult::Ok => debug_assert_eq!(
291 after + 1,
292 before,
293 "a successful remove must drop exactly one value from the trie",
294 ),
295 RemoveResult::NotFound => debug_assert_eq!(
296 after, before,
297 "a NotFound remove must not change the trie value count",
298 ),
299 }
300 self.check_invariants();
301 }
302
303 remove_result
304 }
305
306 pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
307 if partial_key.is_empty() {
310 if self.key_value.is_some() {
311 self.key_value = None;
312 return RemoveResult::Ok;
313 } else {
314 return RemoveResult::NotFound;
315 }
316 }
317
318 if partial_key == &b"*"[..] {
319 if self.wildcard.is_some() {
320 self.wildcard = None;
321 return RemoveResult::Ok;
322 } else {
323 return RemoveResult::NotFound;
324 }
325 }
326
327 if partial_key[partial_key.len() - 1] == b'/' {
328 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
329
330 if let Some(pos) = pos {
331 if pos > 0 && partial_key[pos - 1] != b'.' {
332 return RemoveResult::NotFound;
333 }
334
335 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
336 let anchored_s = format!("\\A{s}\\z");
337 if pos > 0 {
338 let mut remove_result = RemoveResult::NotFound;
339 for t in self.regexps.iter_mut() {
340 if t.0.as_str() == anchored_s
341 && t.1.remove_recursive(&partial_key[..pos - 1]) == RemoveResult::Ok
342 {
343 remove_result = RemoveResult::Ok;
344 }
345 }
346 return remove_result;
347 } else {
348 let len = self.regexps.len();
349 self.regexps.retain(|(r, _)| r.as_str() != anchored_s);
350 if len > self.regexps.len() {
351 return RemoveResult::Ok;
352 }
353 }
354 }
355 }
356
357 return RemoveResult::NotFound;
358 }
359
360 let pos = find_last_dot(partial_key);
361 let (prefix, suffix) = match pos {
362 None => (&b""[..], partial_key),
363 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
364 };
365 debug_assert_eq!(
367 prefix.len() + suffix.len(),
368 partial_key.len(),
369 "dot-split must partition the key without losing or duplicating bytes",
370 );
371
372 match self.children.get_mut(suffix) {
373 Some(child) => match child.remove_recursive(prefix) {
374 RemoveResult::NotFound => RemoveResult::NotFound,
375 RemoveResult::Ok => {
376 if child.is_empty() {
380 self.children.remove(suffix);
381 debug_assert!(
382 !self.children.contains_key(suffix),
383 "an emptied child subtree must be removed from the parent",
384 );
385 } else {
386 #[cfg(debug_assertions)]
390 debug_assert!(
391 child.count_values() > 0,
392 "a retained child subtree must still hold at least one value",
393 );
394 }
395 RemoveResult::Ok
396 }
397 },
398 None => RemoveResult::NotFound,
399 }
400 }
401
402 pub fn lookup_with_path<'a, 'b>(
411 &'b self,
412 partial_key: &'a [u8],
413 accept_wildcard: bool,
414 mut trace: TrieMatches<'a, 'b>,
415 ) -> Option<(&'b KeyValue<Key, V>, TrieMatches<'a, 'b>)> {
416 if partial_key.is_empty() {
417 return self.key_value.as_ref().map(|kv| (kv, trace));
418 }
419
420 let pos = find_last_dot(partial_key);
421 let (prefix, suffix) = match pos {
422 None => (&b""[..], partial_key),
423 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
424 };
425 debug_assert_eq!(
429 prefix.len() + suffix.len(),
430 partial_key.len(),
431 "dot-split must partition the key without losing or duplicating bytes",
432 );
433 debug_assert!(
434 pos.is_none() || suffix.first() == Some(&b'.'),
435 "a dotted split must place the separator at the head of the suffix",
436 );
437
438 match self.children.get(suffix) {
439 Some(child) => child.lookup_with_path(prefix, accept_wildcard, trace),
440 None => {
441 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
442 let segment = if !suffix.is_empty() && suffix[0] == b'.' {
443 &suffix[1..]
444 } else {
445 suffix
446 };
447 trace.push(TrieSubMatch::Wildcard(segment));
448 self.wildcard.as_ref().map(|kv| (kv, trace))
449 } else {
450 for (regexp, child) in self.regexps.iter() {
451 let segment = if !suffix.is_empty() && suffix[0] == b'.' {
452 &suffix[1..]
453 } else {
454 suffix
455 };
456 if regexp.is_match(segment) {
457 let mut next = trace;
458 next.push(TrieSubMatch::Regexp(segment, regexp));
459 return child.lookup_with_path(prefix, accept_wildcard, next);
460 }
461 }
462 None
463 }
464 }
465 }
466 }
467
468 pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
469 if partial_key.is_empty() {
472 return self.key_value.as_ref();
473 }
474
475 let pos = find_last_dot(partial_key);
476 let (prefix, suffix) = match pos {
477 None => (&b""[..], partial_key),
478 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
479 };
480 debug_assert_eq!(
482 prefix.len() + suffix.len(),
483 partial_key.len(),
484 "dot-split must partition the key without losing or duplicating bytes",
485 );
486 debug_assert!(
487 !suffix.is_empty(),
488 "the suffix the trie matches children against must be non-empty",
489 );
490
491 match self.children.get(suffix) {
492 Some(child) => child.lookup(prefix, accept_wildcard),
493 None => {
494 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
497 self.wildcard.as_ref()
499 } else {
500 for (regexp, child) in self.regexps.iter() {
503 let suffix = if suffix[0] == b'.' {
504 &suffix[1..]
505 } else {
506 suffix
507 };
508 if regexp.is_match(suffix) {
511 return child.lookup(prefix, accept_wildcard);
513 }
514 }
515
516 None
517 }
518 }
519 }
520 }
521
522 pub fn lookup_mut(
523 &mut self,
524 partial_key: &[u8],
525 accept_wildcard: bool,
526 ) -> Option<&mut KeyValue<Key, V>> {
527 if partial_key.is_empty() {
530 return self.key_value.as_mut();
531 }
532
533 if partial_key == &b"*"[..] {
534 return self.wildcard.as_mut();
535 }
536
537 if partial_key[partial_key.len() - 1] == b'/' {
538 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
539
540 if let Some(pos) = pos {
541 if pos > 0 && partial_key[pos - 1] != b'.' {
542 return None;
543 }
544
545 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
546 let anchored_s = format!("\\A{s}\\z");
547 for t in self.regexps.iter_mut() {
548 if t.0.as_str() == anchored_s {
549 let rest = if pos > 0 {
558 &partial_key[..pos - 1]
559 } else {
560 &partial_key[..0]
561 };
562 return t.1.lookup_mut(rest, accept_wildcard);
563 }
564 }
565 }
566 }
567
568 return None;
569 }
570
571 let pos = find_last_dot(partial_key);
572 let (prefix, suffix) = match pos {
573 None => (&b""[..], partial_key),
574 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
575 };
576 debug_assert_eq!(
578 prefix.len() + suffix.len(),
579 partial_key.len(),
580 "dot-split must partition the key without losing or duplicating bytes",
581 );
582 debug_assert!(
583 !suffix.is_empty(),
584 "the suffix the trie matches children against must be non-empty",
585 );
586
587 match self.children.get_mut(suffix) {
588 Some(child) => child.lookup_mut(prefix, accept_wildcard),
589 None => {
590 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
593 self.wildcard.as_mut()
595 } else {
596 for &mut (ref regexp, ref mut child) in self.regexps.iter_mut() {
599 let suffix = if suffix[0] == b'.' {
600 &suffix[1..]
601 } else {
602 suffix
603 };
604 if regexp.is_match(suffix) {
607 return child.lookup_mut(prefix, accept_wildcard);
609 }
610 }
611
612 None
613 }
614 }
615 }
616 }
617
618 pub fn print(&self) {
619 self.print_recursive(b"", 0)
620 }
621
622 pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
623 let raw_prefix: Vec<u8> = iter::repeat_n(b' ', 2 * indent as usize).collect();
624 let prefix = str::from_utf8(&raw_prefix).unwrap();
625
626 print!("{}{}: ", prefix, str::from_utf8(partial_key).unwrap());
627 if let Some((ref key, ref value)) = self.key_value {
628 print!("({}, {:?}) | ", str::from_utf8(key).unwrap(), value);
629 } else {
630 print!("None | ");
631 }
632
633 if let Some((key, value)) = &self.wildcard {
634 println!("({}, {:?})", str::from_utf8(key).unwrap(), value);
635 } else {
636 println!("None");
637 }
638
639 for (child_key, child) in self.children.iter() {
640 child.print_recursive(child_key, indent + 1);
641 }
642
643 for (regexp, child) in self.regexps.iter() {
644 child.print_recursive(regexp.as_str().as_bytes(), indent + 1);
646 }
647 }
648
649 pub fn for_each_value_mut<F: FnMut(&mut V)>(&mut self, f: &mut F) {
655 if let Some((_, ref mut value)) = self.key_value {
656 f(value);
657 }
658 if let Some((_, ref mut value)) = self.wildcard {
659 f(value);
660 }
661 for child in self.children.values_mut() {
662 child.for_each_value_mut(f);
663 }
664 for (_, child) in self.regexps.iter_mut() {
665 child.for_each_value_mut(f);
666 }
667 }
668
669 #[cfg(debug_assertions)]
675 fn count_values(&self) -> usize {
676 let local = self.key_value.is_some() as usize + self.wildcard.is_some() as usize;
677 let in_children: usize = self.children.values().map(TrieNode::count_values).sum();
678 let in_regexps: usize = self.regexps.iter().map(|(_, c)| c.count_values()).sum();
679 local + in_children + in_regexps
680 }
681
682 #[cfg(debug_assertions)]
699 fn check_invariants(&self) {
700 for i in 0..self.regexps.len() {
702 for j in (i + 1)..self.regexps.len() {
703 debug_assert_ne!(
704 self.regexps[i].0.as_str(),
705 self.regexps[j].0.as_str(),
706 "trie node must not hold two subtrees for the same regex segment",
707 );
708 }
709 }
710
711 for (child_key, child) in self.children.iter() {
712 debug_assert!(
713 !child_key.is_empty(),
714 "trie child must not be keyed by the empty segment",
715 );
716 debug_assert!(
720 !child.is_empty(),
721 "trie must not strand an empty child subtree (remove must prune)",
722 );
723 debug_assert!(
724 child.count_values() > 0,
725 "trie child subtree must lead to at least one value",
726 );
727 child.check_invariants();
728 }
729
730 for (_, child) in self.regexps.iter() {
731 debug_assert!(
732 !child.is_empty(),
733 "trie must not strand an empty regex subtree (remove must prune)",
734 );
735 debug_assert!(
736 child.count_values() > 0,
737 "trie regex subtree must lead to at least one value",
738 );
739 child.check_invariants();
740 }
741 }
742
743 pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
744 self.insert(key, value)
745 }
746
747 pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
748 self.remove(key)
749 }
750
751 pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
752 self.lookup(key, accept_wildcard)
753 }
754
755 pub fn domain_lookup_mut(
756 &mut self,
757 key: &[u8],
758 accept_wildcard: bool,
759 ) -> Option<&mut KeyValue<Key, V>> {
760 self.lookup_mut(key, accept_wildcard)
761 }
762
763 pub fn size(&self) -> usize {
764 ::std::mem::size_of::<TrieNode<V>>()
765 + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2
766 + self
767 .children
768 .iter()
769 .fold(0, |acc, c| acc + c.0.len() + c.1.size())
770 }
771
772 pub fn to_hashmap(&self) -> HashMap<Key, V> {
773 let mut h = HashMap::new();
774
775 self.to_hashmap_recursive(&mut h);
776
777 h
778 }
779
780 pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
781 if let Some((key, value)) = &self.key_value {
782 h.insert(key.clone(), value.clone());
783 }
784
785 if let Some((key, value)) = &self.wildcard {
786 h.insert(key.clone(), value.clone());
787 }
788
789 for child in self.children.values() {
790 child.to_hashmap_recursive(h);
791 }
792 }
793}
794
// Unit tests for the domain trie: plain labels, wildcards, regex segments, removal and
// structural equality after removal.
#[cfg(test)]
mod tests {
    use super::*;

    // Dot-less keys are stored as direct children of the root and can be looked up again.
    #[test]
    fn insert() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        assert_eq!(
            root.domain_insert(Vec::from(&b"abcd"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"abce"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"abgh"[..]), 3),
            InsertResult::Ok
        );
        root.print();

        assert_eq!(
            root.domain_lookup(&b"abce"[..], true),
            Some(&(b"abce"[..].to_vec(), 2))
        );
    }

    // Removing keys yields a trie equal to one built without them.
    #[test]
    fn remove() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (abcd, 1)");
        assert_eq!(root.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root.print();
        println!("adding (abce, 2)");
        assert_eq!(root.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
        root.print();
        println!("adding (abgh, 3)");
        assert_eq!(root.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
        root.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);

        println!("before remove");
        root.print();
        assert_eq!(root.remove(&Vec::from(&b"abce"[..])), RemoveResult::Ok);
        println!("after remove");
        root.print();

        println!("expected");
        root2.print();
        assert_eq!(root, root2);

        assert_eq!(root.remove(&Vec::from(&b"abgh"[..])), RemoveResult::Ok);
        println!("after remove");
        root.print();
        println!("expected");
        let mut root3: TrieNode<u8> = TrieNode::root();
        assert_eq!(root3.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root3.print();
        assert_eq!(root, root3);
    }

    // Two keys share the `/.*/` regex segment; removing one keeps the other reachable.
    #[test]
    fn insert_remove_through_regex() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (www./.*/.com, 1)");
        assert_eq!(
            root.insert(Vec::from(&b"www./.*/.com"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        println!("adding (www.doc./.*/.com, 2)");
        assert_eq!(
            root.insert(Vec::from(&b"www.doc./.*/.com"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_lookup(b"www.sozu.com".as_ref(), false),
            Some(&(b"www./.*/.com".to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );

        assert_eq!(
            root.domain_remove(&b"www./.*/.com".to_vec()),
            RemoveResult::Ok
        );
        root.print();
        assert_eq!(root.domain_lookup(b"www.sozu.com".as_ref(), false), None);
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );
    }

    // Regex segments are anchored: they must match the whole label, not a substring of it.
    #[test]
    fn segment_regex_rejects_partial_matches() {
        let mut root: TrieNode<u8> = TrieNode::root();
        assert_eq!(
            root.insert(Vec::from(&b"/cdn[0-9]+/.example.com"[..]), 7),
            InsertResult::Ok
        );

        assert_eq!(
            root.domain_lookup(b"cdn1.example.com".as_ref(), false),
            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(b"cdn123.example.com".as_ref(), false),
            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
        );

        // Trailing junk after the match.
        assert_eq!(
            root.domain_lookup(b"cdn1xxx.example.com".as_ref(), false),
            None
        );
        // Leading junk before the match.
        assert_eq!(
            root.domain_lookup(b"xxxcdn1.example.com".as_ref(), false),
            None
        );
        // Label does not match the pattern at all.
        assert_eq!(
            root.domain_lookup(b"cdnabc.example.com".as_ref(), false),
            None
        );
    }

    // Insertion order does not matter: removing `abc` from a trie built with it first
    // yields the same trie as never inserting it.
    #[test]
    fn add_child_to_leaf() {
        let mut root1: TrieNode<u8> = TrieNode::root();

        println!("creating root1:");
        root1.print();
        println!("adding (abcd, 1)");
        assert_eq!(root1.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root1.print();
        println!("adding (abce, 2)");
        assert_eq!(root1.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
        root1.print();
        println!("adding (abc, 3)");
        assert_eq!(root1.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);

        println!("root1:");
        root1.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);

        println!("root2:");
        root2.print();
        assert_eq!(root2.remove(&Vec::from(&b"abc"[..])), RemoveResult::Ok);

        println!("root2 after,remove:");
        root2.print();
        let mut expected: TrieNode<u8> = TrieNode::root();

        assert_eq!(
            expected.insert(Vec::from(&b"abcd"[..]), 1),
            InsertResult::Ok
        );
        assert_eq!(
            expected.insert(Vec::from(&b"abce"[..]), 2),
            InsertResult::Ok
        );

        println!("root2 after insert");
        root2.print();
        println!("expected");
        expected.print();
        assert_eq!(root2, expected);
    }

    // Mixed host names: exact entries, wildcards, regex segments, and removal of an exact
    // entry that sits next to a wildcard.
    #[test]
    fn domains() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        assert_eq!(
            root.domain_insert(Vec::from(&b"www.example.com"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"test.example.com"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"*.alldomains.org"[..]), 3),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"alldomains.org"[..]), 4),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"pouet.alldomains.org"[..]), 5),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"hello.com"[..]), 6),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"*.hello.com"[..]), 7),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"images./cdn[0-9]+/.hello.com"[..]), 8),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"/test[0-9]+/.www.hello.com"[..]), 9),
            InsertResult::Ok
        );
        root.print();

        assert_eq!(root.domain_lookup(&b"example.com"[..], true), None);
        assert_eq!(
            root.domain_lookup(&b"blah.test.example.com"[..], true),
            None
        );
        assert_eq!(
            root.domain_lookup(&b"www.example.com"[..], true),
            Some(&(b"www.example.com"[..].to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(&b"alldomains.org"[..], true),
            Some(&(b"alldomains.org"[..].to_vec(), 4))
        );
        assert_eq!(
            root.domain_lookup(&b"test.hello.com"[..], true),
            Some(&(b"*.hello.com"[..].to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(&b"images.cdn10.hello.com"[..], true),
            Some(&(b"images./cdn[0-9]+/.hello.com"[..].to_vec(), 8))
        );
        assert_eq!(
            root.domain_lookup(&b"test42.www.hello.com"[..], true),
            Some(&(b"/test[0-9]+/.www.hello.com"[..].to_vec(), 9))
        );
        assert_eq!(
            root.domain_lookup(&b"test.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"hello.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
        );
        // A wildcard only covers a single label.
        assert_eq!(
            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
            None
        );

        assert_eq!(
            root.domain_remove(&Vec::from(&b"alldomains.org"[..])),
            RemoveResult::Ok
        );
        println!("after remove");
        root.print();
        assert_eq!(root.domain_lookup(&b"alldomains.org"[..], true), None);
        assert_eq!(
            root.domain_lookup(&b"test.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"hello.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
        );
        assert_eq!(
            root.domain_lookup(&b"test.hello.com"[..], true),
            Some(&(b"*.hello.com"[..].to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
            None
        );
    }

    // The deepest matching wildcard wins over a wildcard on a parent domain.
    #[test]
    fn wildcard() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();
        root.domain_insert("*.clever-cloud.com".as_bytes().to_vec(), 2u8);
        root.domain_insert("services.clever-cloud.com".as_bytes().to_vec(), 0u8);
        root.domain_insert("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8);

        let res = root.domain_lookup(b"test.services.clever-cloud.com", true);
        println!("query result: {res:?}");

        assert_eq!(
            root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
            Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
        );
    }

    // Inserts every eligible key of `h`, then checks that `lookup` (wildcards disabled)
    // returns exactly that key and value. Keys that are empty, start with `.`, contain `/`
    // or equal `*` are skipped. Returns `false` on the first mismatch.
    fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
        let mut root: TrieNode<u32> = TrieNode::root();

        for (k, v) in h.iter() {
            if k.is_empty() {
                continue;
            }

            if k.as_bytes()[0] == b'.' {
                continue;
            }

            if k.contains('/') {
                continue;
            }

            if k == "*" {
                continue;
            }

            assert_eq!(
                root.insert(Vec::from(k.as_bytes()), *v),
                InsertResult::Ok,
                "could not insert ({k}, {v})"
            );
        }

        for (k, v) in h.iter() {
            if k.is_empty() {
                continue;
            }

            if k.as_bytes()[0] == b'.' {
                continue;
            }

            if k.contains('/') {
                continue;
            }

            if k == "*" {
                continue;
            }

            match root.lookup(k.as_bytes(), false) {
                None => {
                    println!("did not find key '{k}'");
                    return false;
                }
                Some(&(ref k1, v1)) => {
                    if k.as_bytes() != &k1[..] || *v != v1 {
                        println!(
                            "request ({}, {}), got ({}, {})",
                            k,
                            v,
                            str::from_utf8(&k1[..]).unwrap(),
                            v1
                        );
                        return false;
                    }
                }
            }
        }

        true
    }

    // Keys made of control characters that share a `\n` prefix must all remain reachable
    // (presumably a regression case for entries getting lost; see the test name).
    #[test]
    fn insert_disappearing_tree() {
        let h: std::collections::HashMap<String, u32> = [
            (String::from("\n\u{3}"), 0),
            (String::from("\n\u{0}"), 1),
            (String::from("\n"), 2),
        ]
        .iter()
        .cloned()
        .collect();
        assert!(hm_insert(h));
    }

    // `assert_size!` is defined elsewhere in the crate.
    #[test]
    fn size() {
        assert_size!(TrieNode<u32>, 136);
    }

    // A regex segment used as the left-most label:
    // - re-inserting it reports `Existing` and keeps the first value;
    // - `lookup_mut` and `lookup` both resolve a matching host;
    // - `remove` deletes the entry.
    #[test]
    fn leftmost_regex_segment_reinsert_and_lookup_mut_do_not_panic() {
        let mut root: TrieNode<u8> = TrieNode::root();

        assert_eq!(
            root.insert(Vec::from(&b"/test[0-9]/.example.com"[..]), 7),
            InsertResult::Ok
        );
        assert_eq!(
            root.insert(Vec::from(&b"/test[0-9]/.example.com"[..]), 8),
            InsertResult::Existing
        );

        let resolved = root.domain_lookup_mut(b"test4.example.com", false);
        assert_eq!(
            resolved.map(|(_, v)| *v),
            Some(7),
            "leftmost-regex host must resolve via lookup_mut without panicking",
        );

        assert_eq!(
            root.domain_lookup(b"test4.example.com", false),
            Some(&(b"/test[0-9]/.example.com"[..].to_vec(), 7))
        );

        assert_eq!(
            root.domain_remove(&Vec::from(&b"/test[0-9]/.example.com"[..])),
            RemoveResult::Ok
        );
        assert_eq!(root.domain_lookup(b"test4.example.com", false), None);
    }
}
1278}