1use std::{collections::HashMap, fmt::Debug, iter, str};
2
3use regex::bytes::Regex;
4
/// Raw bytes of a key stored in the trie.
pub type Key = Vec<u8>;
/// A key paired with the value stored under it.
pub type KeyValue<K, V> = (K, V);
7
/// Outcome of inserting a key into a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
    /// The key was stored.
    Ok,
    /// An entry already occupies the slot for this key; nothing was changed.
    Existing,
    /// The key was rejected, for instance an empty key or a malformed regex segment.
    Failed,
}
14
/// Outcome of removing a key from a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
    /// An entry was found for the key and removed.
    Ok,
    /// No entry was stored for the key.
    NotFound,
}
20
/// Returns the index of the last `.` byte in `input`, or `None` when there is none.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
25
/// Returns the index of the last `/` byte in `input`, or `None` when there is none.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
30
/// One node of a trie whose keys are consumed label by label, from the rightmost label of the
/// key to the leftmost one.
#[derive(Debug, Default)]
pub struct TrieNode<V> {
    /// Entry stored exactly at this node: the full key it was inserted under and its value.
    key_value: Option<KeyValue<Key, V>>,
    /// Entry registered for a `*` label directly below this node.
    wildcard: Option<KeyValue<Key, V>>,
    /// Sub-tries for literal labels. Every label except the leftmost one of a key is stored
    /// with its leading `.`.
    children: HashMap<Key, TrieNode<V>>,
    /// Sub-tries reached through a regex segment (written `/pattern/` in keys), in insertion
    /// order, which is also the order lookups try them in.
    regexps: Vec<(Regex, TrieNode<V>)>,
}
44
/// A non-literal match recorded while resolving a key with `TrieNode::lookup_with_path`.
#[derive(Debug)]
pub enum TrieSubMatch<'a, 'b> {
    /// The label (without its leading `.`) that was accepted by a wildcard entry.
    Wildcard(&'a [u8]),
    /// The label (without its leading `.`) that matched, and the regex that matched it.
    Regexp(&'a [u8], &'b Regex),
}
56
/// Wildcard and regex matches collected during a lookup, in the order they were encountered.
pub type TrieMatches<'a, 'b> = Vec<TrieSubMatch<'a, 'b>>;
62
63impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
64 fn eq(&self, other: &Self) -> bool {
65 self.key_value == other.key_value
66 && self.wildcard == other.wildcard
67 && self.children == other.children
68 && self.regexps.len() == other.regexps.len()
69 && self
70 .regexps
71 .iter()
72 .zip(other.regexps.iter())
73 .fold(true, |b, (left, right)| {
74 b && left.0.as_str() == right.0.as_str() && left.1 == right.1
75 })
76 }
77}
78
79impl<V: Debug + Clone> TrieNode<V> {
80 pub fn new(key: Key, value: V) -> TrieNode<V> {
81 TrieNode {
82 key_value: Some((key, value)),
83 wildcard: None,
84 children: HashMap::new(),
85 regexps: Vec::new(),
86 }
87 }
88
89 pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
90 TrieNode {
91 key_value: None,
92 wildcard: Some((key, value)),
93 children: HashMap::new(),
94 regexps: Vec::new(),
95 }
96 }
97
98 pub fn root() -> TrieNode<V> {
99 TrieNode {
100 key_value: None,
101 wildcard: None,
102 children: HashMap::new(),
103 regexps: Vec::new(),
104 }
105 }
106
107 pub fn is_empty(&self) -> bool {
108 self.key_value.is_none()
109 && self.wildcard.is_none()
110 && self.regexps.is_empty()
111 && self.children.is_empty()
112 }
113
114 pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
115 if key.is_empty() {
117 return InsertResult::Failed;
118 }
119 if key[..] == b"."[..] {
120 return InsertResult::Failed;
121 }
122
123 let insert_result = self.insert_recursive(&key, &key, value);
124 assert_ne!(insert_result, InsertResult::Failed);
125 insert_result
126 }
127
128 pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
129 assert_ne!(partial_key, &b""[..]);
131
132 if partial_key[partial_key.len() - 1] == b'/' {
133 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
134
135 if let Some(pos) = pos {
136 if pos > 0 && partial_key[pos - 1] != b'.' {
137 return InsertResult::Failed;
138 }
139
140 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
141 let anchored_s = format!("\\A{s}\\z");
142 for t in self.regexps.iter_mut() {
143 if t.0.as_str() == anchored_s {
144 return t.1.insert_recursive(&partial_key[..pos - 1], key, value);
145 }
146 }
147
148 let anchored = format!("\\A{s}\\z");
153 if let Ok(r) = Regex::new(&anchored) {
154 if pos > 0 {
155 let mut node = TrieNode::root();
156 let pos = pos - 1;
157
158 let res = node.insert_recursive(&partial_key[..pos], key, value);
159
160 if res == InsertResult::Ok {
161 self.regexps.push((r, node));
162 }
163
164 return res;
165 } else {
166 let node = TrieNode::new(key.to_vec(), value);
167 self.regexps.push((r, node));
168 return InsertResult::Ok;
169 }
170 }
171 }
172 }
173
174 return InsertResult::Failed;
175 }
176
177 let pos = find_last_dot(partial_key);
178 match pos {
179 None => {
180 if self.children.contains_key(partial_key) {
181 InsertResult::Existing
182 } else if partial_key == &b"*"[..] {
183 if self.wildcard.is_some() {
184 InsertResult::Existing
185 } else {
186 self.wildcard = Some((key.to_vec(), value));
187 InsertResult::Ok
188 }
189 } else {
190 let node = TrieNode::new(key.to_vec(), value);
191 self.children.insert(partial_key.to_vec(), node);
192 InsertResult::Ok
193 }
194 }
195 Some(pos) => {
196 if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
197 return child.insert_recursive(&partial_key[..pos], key, value);
198 }
199
200 let mut node = TrieNode::root();
201 let res = node.insert_recursive(&partial_key[..pos], key, value);
202
203 if res == InsertResult::Ok {
204 self.children.insert(partial_key[pos..].to_vec(), node);
205 }
206
207 res
208 }
209 }
210 }
211
212 pub fn remove(&mut self, key: &Key) -> RemoveResult {
213 self.remove_recursive(key)
214 }
215
216 pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
217 if partial_key.is_empty() {
220 if self.key_value.is_some() {
221 self.key_value = None;
222 return RemoveResult::Ok;
223 } else {
224 return RemoveResult::NotFound;
225 }
226 }
227
228 if partial_key == &b"*"[..] {
229 if self.wildcard.is_some() {
230 self.wildcard = None;
231 return RemoveResult::Ok;
232 } else {
233 return RemoveResult::NotFound;
234 }
235 }
236
237 if partial_key[partial_key.len() - 1] == b'/' {
238 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
239
240 if let Some(pos) = pos {
241 if pos > 0 && partial_key[pos - 1] != b'.' {
242 return RemoveResult::NotFound;
243 }
244
245 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
246 let anchored_s = format!("\\A{s}\\z");
247 if pos > 0 {
248 let mut remove_result = RemoveResult::NotFound;
249 for t in self.regexps.iter_mut() {
250 if t.0.as_str() == anchored_s
251 && t.1.remove_recursive(&partial_key[..pos - 1]) == RemoveResult::Ok
252 {
253 remove_result = RemoveResult::Ok;
254 }
255 }
256 return remove_result;
257 } else {
258 let len = self.regexps.len();
259 self.regexps.retain(|(r, _)| r.as_str() != anchored_s);
260 if len > self.regexps.len() {
261 return RemoveResult::Ok;
262 }
263 }
264 }
265 }
266
267 return RemoveResult::NotFound;
268 }
269
270 let pos = find_last_dot(partial_key);
271 let (prefix, suffix) = match pos {
272 None => (&b""[..], partial_key),
273 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
274 };
275 match self.children.get_mut(suffix) {
278 Some(child) => match child.remove_recursive(prefix) {
279 RemoveResult::NotFound => RemoveResult::NotFound,
280 RemoveResult::Ok => {
281 if child.is_empty() {
282 self.children.remove(suffix);
283 }
284 RemoveResult::Ok
285 }
286 },
287 None => RemoveResult::NotFound,
288 }
289 }
290
291 pub fn lookup_with_path<'a, 'b>(
300 &'b self,
301 partial_key: &'a [u8],
302 accept_wildcard: bool,
303 mut trace: TrieMatches<'a, 'b>,
304 ) -> Option<(&'b KeyValue<Key, V>, TrieMatches<'a, 'b>)> {
305 if partial_key.is_empty() {
306 return self.key_value.as_ref().map(|kv| (kv, trace));
307 }
308
309 let pos = find_last_dot(partial_key);
310 let (prefix, suffix) = match pos {
311 None => (&b""[..], partial_key),
312 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
313 };
314
315 match self.children.get(suffix) {
316 Some(child) => child.lookup_with_path(prefix, accept_wildcard, trace),
317 None => {
318 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
319 let segment = if !suffix.is_empty() && suffix[0] == b'.' {
320 &suffix[1..]
321 } else {
322 suffix
323 };
324 trace.push(TrieSubMatch::Wildcard(segment));
325 self.wildcard.as_ref().map(|kv| (kv, trace))
326 } else {
327 for (regexp, child) in self.regexps.iter() {
328 let segment = if !suffix.is_empty() && suffix[0] == b'.' {
329 &suffix[1..]
330 } else {
331 suffix
332 };
333 if regexp.is_match(segment) {
334 let mut next = trace;
335 next.push(TrieSubMatch::Regexp(segment, regexp));
336 return child.lookup_with_path(prefix, accept_wildcard, next);
337 }
338 }
339 None
340 }
341 }
342 }
343 }
344
345 pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
346 if partial_key.is_empty() {
349 return self.key_value.as_ref();
350 }
351
352 let pos = find_last_dot(partial_key);
353 let (prefix, suffix) = match pos {
354 None => (&b""[..], partial_key),
355 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
356 };
357 match self.children.get(suffix) {
360 Some(child) => child.lookup(prefix, accept_wildcard),
361 None => {
362 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
365 self.wildcard.as_ref()
367 } else {
368 for (regexp, child) in self.regexps.iter() {
371 let suffix = if suffix[0] == b'.' {
372 &suffix[1..]
373 } else {
374 suffix
375 };
376 if regexp.is_match(suffix) {
379 return child.lookup(prefix, accept_wildcard);
381 }
382 }
383
384 None
385 }
386 }
387 }
388 }
389
390 pub fn lookup_mut(
391 &mut self,
392 partial_key: &[u8],
393 accept_wildcard: bool,
394 ) -> Option<&mut KeyValue<Key, V>> {
395 if partial_key.is_empty() {
398 return self.key_value.as_mut();
399 }
400
401 if partial_key == &b"*"[..] {
402 return self.wildcard.as_mut();
403 }
404
405 if partial_key[partial_key.len() - 1] == b'/' {
406 let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
407
408 if let Some(pos) = pos {
409 if pos > 0 && partial_key[pos - 1] != b'.' {
410 return None;
411 }
412
413 if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
414 let anchored_s = format!("\\A{s}\\z");
415 for t in self.regexps.iter_mut() {
416 if t.0.as_str() == anchored_s {
417 return t.1.lookup_mut(&partial_key[..pos - 1], accept_wildcard);
418 }
419 }
420 }
421 }
422
423 return None;
424 }
425
426 let pos = find_last_dot(partial_key);
427 let (prefix, suffix) = match pos {
428 None => (&b""[..], partial_key),
429 Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
430 };
431 match self.children.get_mut(suffix) {
434 Some(child) => child.lookup_mut(prefix, accept_wildcard),
435 None => {
436 if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
439 self.wildcard.as_mut()
441 } else {
442 for &mut (ref regexp, ref mut child) in self.regexps.iter_mut() {
445 let suffix = if suffix[0] == b'.' {
446 &suffix[1..]
447 } else {
448 suffix
449 };
450 if regexp.is_match(suffix) {
453 return child.lookup_mut(prefix, accept_wildcard);
455 }
456 }
457
458 None
459 }
460 }
461 }
462 }
463
464 pub fn print(&self) {
465 self.print_recursive(b"", 0)
466 }
467
468 pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
469 let raw_prefix: Vec<u8> = iter::repeat_n(b' ', 2 * indent as usize).collect();
470 let prefix = str::from_utf8(&raw_prefix).unwrap();
471
472 print!("{}{}: ", prefix, str::from_utf8(partial_key).unwrap());
473 if let Some((ref key, ref value)) = self.key_value {
474 print!("({}, {:?}) | ", str::from_utf8(key).unwrap(), value);
475 } else {
476 print!("None | ");
477 }
478
479 if let Some((key, value)) = &self.wildcard {
480 println!("({}, {:?})", str::from_utf8(key).unwrap(), value);
481 } else {
482 println!("None");
483 }
484
485 for (child_key, child) in self.children.iter() {
486 child.print_recursive(child_key, indent + 1);
487 }
488
489 for (regexp, child) in self.regexps.iter() {
490 child.print_recursive(regexp.as_str().as_bytes(), indent + 1);
492 }
493 }
494
495 pub fn for_each_value_mut<F: FnMut(&mut V)>(&mut self, f: &mut F) {
501 if let Some((_, ref mut value)) = self.key_value {
502 f(value);
503 }
504 if let Some((_, ref mut value)) = self.wildcard {
505 f(value);
506 }
507 for child in self.children.values_mut() {
508 child.for_each_value_mut(f);
509 }
510 for (_, child) in self.regexps.iter_mut() {
511 child.for_each_value_mut(f);
512 }
513 }
514
515 pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
516 self.insert(key, value)
517 }
518
519 pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
520 self.remove(key)
521 }
522
523 pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
524 self.lookup(key, accept_wildcard)
525 }
526
527 pub fn domain_lookup_mut(
528 &mut self,
529 key: &[u8],
530 accept_wildcard: bool,
531 ) -> Option<&mut KeyValue<Key, V>> {
532 self.lookup_mut(key, accept_wildcard)
533 }
534
535 pub fn size(&self) -> usize {
536 ::std::mem::size_of::<TrieNode<V>>()
537 + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2
538 + self
539 .children
540 .iter()
541 .fold(0, |acc, c| acc + c.0.len() + c.1.size())
542 }
543
544 pub fn to_hashmap(&self) -> HashMap<Key, V> {
545 let mut h = HashMap::new();
546
547 self.to_hashmap_recursive(&mut h);
548
549 h
550 }
551
552 pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
553 if let Some((key, value)) = &self.key_value {
554 h.insert(key.clone(), value.clone());
555 }
556
557 if let Some((key, value)) = &self.wildcard {
558 h.insert(key.clone(), value.clone());
559 }
560
561 for child in self.children.values() {
562 child.to_hashmap_recursive(h);
563 }
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Literal keys sharing a prefix are stored side by side and found again.
    #[test]
    fn insert() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        for (key, value) in [(b"abcd", 1), (b"abce", 2), (b"abgh", 3)] {
            assert_eq!(root.domain_insert(key.to_vec(), value), InsertResult::Ok);
            root.print();
        }

        assert_eq!(
            root.domain_lookup(b"abce", true),
            Some(&(b"abce".to_vec(), 2))
        );
    }

    /// Removing keys leaves a trie equal to one built without them.
    #[test]
    fn remove() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (abcd, 1)");
        assert_eq!(root.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        root.print();
        println!("adding (abce, 2)");
        assert_eq!(root.insert(b"abce".to_vec(), 2), InsertResult::Ok);
        root.print();
        println!("adding (abgh, 3)");
        assert_eq!(root.insert(b"abgh".to_vec(), 3), InsertResult::Ok);
        root.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        assert_eq!(root2.insert(b"abgh".to_vec(), 3), InsertResult::Ok);

        println!("before remove");
        root.print();
        assert_eq!(root.remove(&b"abce".to_vec()), RemoveResult::Ok);
        println!("after remove");
        root.print();

        println!("expected");
        root2.print();
        assert_eq!(root, root2);

        assert_eq!(root.remove(&b"abgh".to_vec()), RemoveResult::Ok);
        println!("after remove");
        root.print();
        println!("expected");
        let mut root3: TrieNode<u8> = TrieNode::root();
        assert_eq!(root3.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        root3.print();
        assert_eq!(root, root3);
    }

    /// Two keys going through the same regex segment can be removed independently.
    #[test]
    fn insert_remove_through_regex() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (www./.*/.com, 1)");
        assert_eq!(root.insert(b"www./.*/.com".to_vec(), 1), InsertResult::Ok);
        root.print();
        println!("adding (www.doc./.*/.com, 2)");
        assert_eq!(
            root.insert(b"www.doc./.*/.com".to_vec(), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_lookup(b"www.sozu.com", false),
            Some(&(b"www./.*/.com".to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com", false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );

        assert_eq!(
            root.domain_remove(&b"www./.*/.com".to_vec()),
            RemoveResult::Ok
        );
        root.print();
        assert_eq!(root.domain_lookup(b"www.sozu.com", false), None);
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com", false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );
    }

    /// A regex segment has to match a whole label, not just a part of it.
    #[test]
    fn segment_regex_rejects_partial_matches() {
        let pattern = b"/cdn[0-9]+/.example.com".to_vec();
        let mut root: TrieNode<u8> = TrieNode::root();
        assert_eq!(root.insert(pattern.clone(), 7), InsertResult::Ok);

        let hit = (pattern, 7);
        assert_eq!(root.domain_lookup(b"cdn1.example.com", false), Some(&hit));
        assert_eq!(
            root.domain_lookup(b"cdn123.example.com", false),
            Some(&hit)
        );

        assert_eq!(root.domain_lookup(b"cdn1xxx.example.com", false), None);
        assert_eq!(root.domain_lookup(b"xxxcdn1.example.com", false), None);
        assert_eq!(root.domain_lookup(b"cdnabc.example.com", false), None);
    }

    /// A key that is a prefix of other keys can be added and removed without disturbing them.
    #[test]
    fn add_child_to_leaf() {
        let mut root1: TrieNode<u8> = TrieNode::root();

        println!("creating root1:");
        root1.print();
        println!("adding (abcd, 1)");
        assert_eq!(root1.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        root1.print();
        println!("adding (abce, 2)");
        assert_eq!(root1.insert(b"abce".to_vec(), 2), InsertResult::Ok);
        root1.print();
        println!("adding (abc, 3)");
        assert_eq!(root1.insert(b"abc".to_vec(), 3), InsertResult::Ok);

        println!("root1:");
        root1.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(b"abc".to_vec(), 3), InsertResult::Ok);
        assert_eq!(root2.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        assert_eq!(root2.insert(b"abce".to_vec(), 2), InsertResult::Ok);

        println!("root2:");
        root2.print();
        assert_eq!(root2.remove(&b"abc".to_vec()), RemoveResult::Ok);

        println!("root2 after,remove:");
        root2.print();
        let mut expected: TrieNode<u8> = TrieNode::root();

        assert_eq!(expected.insert(b"abcd".to_vec(), 1), InsertResult::Ok);
        assert_eq!(expected.insert(b"abce".to_vec(), 2), InsertResult::Ok);

        println!("root2 after insert");
        root2.print();
        println!("expected");
        expected.print();
        assert_eq!(root2, expected);
    }

    /// Literal, wildcard and regex domain patterns resolve as expected, before and after a
    /// removal.
    #[test]
    fn domains() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        assert_eq!(
            root.domain_insert(b"www.example.com".to_vec(), 1),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(b"test.example.com".to_vec(), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(b"*.alldomains.org".to_vec(), 3),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(b"alldomains.org".to_vec(), 4),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(b"pouet.alldomains.org".to_vec(), 5),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(b"hello.com".to_vec(), 6),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(b"*.hello.com".to_vec(), 7),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(b"images./cdn[0-9]+/.hello.com".to_vec(), 8),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(b"/test[0-9]+/.www.hello.com".to_vec(), 9),
            InsertResult::Ok
        );
        root.print();

        assert_eq!(root.domain_lookup(b"example.com", true), None);
        assert_eq!(root.domain_lookup(b"blah.test.example.com", true), None);
        assert_eq!(
            root.domain_lookup(b"www.example.com", true),
            Some(&(b"www.example.com".to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(b"alldomains.org", true),
            Some(&(b"alldomains.org".to_vec(), 4))
        );
        assert_eq!(
            root.domain_lookup(b"test.hello.com", true),
            Some(&(b"*.hello.com".to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(b"images.cdn10.hello.com", true),
            Some(&(b"images./cdn[0-9]+/.hello.com".to_vec(), 8))
        );
        assert_eq!(
            root.domain_lookup(b"test42.www.hello.com", true),
            Some(&(b"/test[0-9]+/.www.hello.com".to_vec(), 9))
        );
        assert_eq!(
            root.domain_lookup(b"test.alldomains.org", true),
            Some(&(b"*.alldomains.org".to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(b"hello.alldomains.org", true),
            Some(&(b"*.alldomains.org".to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(b"pouet.alldomains.org", true),
            Some(&(b"pouet.alldomains.org".to_vec(), 5))
        );
        assert_eq!(root.domain_lookup(b"blah.test.alldomains.org", true), None);

        assert_eq!(
            root.domain_remove(&b"alldomains.org".to_vec()),
            RemoveResult::Ok
        );
        println!("after remove");
        root.print();
        assert_eq!(root.domain_lookup(b"alldomains.org", true), None);
        assert_eq!(
            root.domain_lookup(b"test.alldomains.org", true),
            Some(&(b"*.alldomains.org".to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(b"hello.alldomains.org", true),
            Some(&(b"*.alldomains.org".to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(b"pouet.alldomains.org", true),
            Some(&(b"pouet.alldomains.org".to_vec(), 5))
        );
        assert_eq!(
            root.domain_lookup(b"test.hello.com", true),
            Some(&(b"*.hello.com".to_vec(), 7))
        );
        assert_eq!(root.domain_lookup(b"blah.test.alldomains.org", true), None);
    }

    /// The wildcard registered on the most specific matching domain is the one returned.
    #[test]
    fn wildcard() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();
        root.domain_insert(b"*.clever-cloud.com".to_vec(), 2u8);
        root.domain_insert(b"services.clever-cloud.com".to_vec(), 0u8);
        root.domain_insert(b"*.services.clever-cloud.com".to_vec(), 1u8);

        let res = root.domain_lookup(b"test.services.clever-cloud.com", true);
        println!("query result: {res:?}");

        assert_eq!(
            root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
            Some(&(b"*.services.clever-cloud.com".to_vec(), 1u8))
        );
    }

    /// Keys that `hm_insert` leaves out: empty keys, keys starting with `.`, keys containing
    /// `/`, and the lone `*`.
    fn is_skipped(k: &str) -> bool {
        k.is_empty() || k.starts_with('.') || k.contains('/') || k == "*"
    }

    /// Inserts every usable key of `h` into a fresh trie, then checks that each of them can be
    /// looked up again with its own key and value.
    fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
        let mut root: TrieNode<u32> = TrieNode::root();

        for (k, v) in h.iter().filter(|(k, _)| !is_skipped(k)) {
            assert_eq!(
                root.insert(k.as_bytes().to_vec(), *v),
                InsertResult::Ok,
                "could not insert ({k}, {v})"
            );
        }

        for (k, v) in h.iter().filter(|(k, _)| !is_skipped(k)) {
            let Some((k1, v1)) = root.lookup(k.as_bytes(), false) else {
                println!("did not find key '{k}'");
                return false;
            };

            if k.as_bytes() != k1.as_slice() || v != v1 {
                println!(
                    "request ({}, {}), got ({}, {})",
                    k,
                    v,
                    str::from_utf8(k1).unwrap(),
                    v1
                );
                return false;
            }
        }

        true
    }

    /// Keys sharing a prefix, one of them being the prefix itself, all remain reachable.
    #[test]
    fn insert_disappearing_tree() {
        let h: std::collections::HashMap<String, u32> = [("\n\u{3}", 0), ("\n\u{0}", 1), ("\n", 2)]
            .into_iter()
            .map(|(key, value)| (String::from(key), value))
            .collect();
        assert!(hm_insert(h));
    }

    /// Guards against unnoticed changes of the in-memory size of a node.
    #[test]
    fn size() {
        assert_size!(TrieNode<u32>, 136);
    }
}