sozu_lib/router/
pattern_trie.rs

1use std::{collections::HashMap, fmt::Debug, iter, str};
2
3use regex::bytes::Regex;
4
5pub type Key = Vec<u8>;
6pub type KeyValue<K, V> = (K, V);
7
/// Outcome of inserting a key into a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
    /// The key and its value were stored.
    Ok,
    /// The targeted slot is already occupied; the trie was left untouched.
    Existing,
    /// The key was rejected (e.g. empty key, lone `.`, malformed regex segment).
    Failed,
}
14
/// Outcome of removing a key from a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
    /// The entry was found and removed.
    Ok,
    /// No entry matched the key; nothing was removed.
    NotFound,
}
20
/// Returns the index of the rightmost `.` byte of `input`, or `None` when
/// `input` contains no dot.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
25
/// Returns the index of the rightmost `/` byte of `input`, or `None` when
/// `input` contains no slash.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
30
/// Implementation of a trie tree structure.
/// In Sozu this is used to store and lookup domains recursively.
/// Each node represents a "level domain".
/// A leaf node (leftmost label) can be a wildcard, a regex pattern or a plain string.
/// Leaves also store a value associated with the complete domain.
/// For Sozu it is a list of (PathRule, MethodRule, ClusterId). See the Router structure.
#[derive(Debug, Default)]
pub struct TrieNode<V> {
    /// Complete key and value of the entry that ends exactly at this node, if any.
    key_value: Option<KeyValue<Key, V>>,
    /// Complete key and value of the `*` entry registered directly under this node, if any.
    wildcard: Option<KeyValue<Key, V>>,
    /// Plain-string children indexed by key segment; every segment except the
    /// leftmost one is stored with its leading `.`.
    children: HashMap<Key, TrieNode<V>>,
    /// Regex children: each compiled pattern is paired with the subtree that is
    /// explored when the pattern matches a label.
    regexps: Vec<(Regex, TrieNode<V>)>,
}
44
45impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
46    fn eq(&self, other: &Self) -> bool {
47        self.key_value == other.key_value
48            && self.wildcard == other.wildcard
49            && self.children == other.children
50            && self.regexps.len() == other.regexps.len()
51            && self
52                .regexps
53                .iter()
54                .zip(other.regexps.iter())
55                .fold(true, |b, (left, right)| {
56                    b && left.0.as_str() == right.0.as_str() && left.1 == right.1
57                })
58    }
59}
60
61impl<V: Debug + Clone> TrieNode<V> {
62    pub fn new(key: Key, value: V) -> TrieNode<V> {
63        TrieNode {
64            key_value: Some((key, value)),
65            wildcard: None,
66            children: HashMap::new(),
67            regexps: Vec::new(),
68        }
69    }
70
71    pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
72        TrieNode {
73            key_value: None,
74            wildcard: Some((key, value)),
75            children: HashMap::new(),
76            regexps: Vec::new(),
77        }
78    }
79
80    pub fn root() -> TrieNode<V> {
81        TrieNode {
82            key_value: None,
83            wildcard: None,
84            children: HashMap::new(),
85            regexps: Vec::new(),
86        }
87    }
88
89    pub fn is_empty(&self) -> bool {
90        self.key_value.is_none()
91            && self.wildcard.is_none()
92            && self.regexps.is_empty()
93            && self.children.is_empty()
94    }
95
96    pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
97        //println!("insert: key == {}", std::str::from_utf8(&key).unwrap());
98        if key.is_empty() {
99            return InsertResult::Failed;
100        }
101        if key[..] == b"."[..] {
102            return InsertResult::Failed;
103        }
104
105        let insert_result = self.insert_recursive(&key, &key, value);
106        assert_ne!(insert_result, InsertResult::Failed);
107        insert_result
108    }
109
110    pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
111        //println!("insert_rec: key == {}", std::str::from_utf8(partial_key).unwrap());
112        assert_ne!(partial_key, &b""[..]);
113
114        if partial_key[partial_key.len() - 1] == b'/' {
115            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
116
117            if let Some(pos) = pos {
118                if pos > 0 && partial_key[pos - 1] != b'.' {
119                    return InsertResult::Failed;
120                }
121
122                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
123                    for t in self.regexps.iter_mut() {
124                        if t.0.as_str() == s {
125                            return t.1.insert_recursive(&partial_key[..pos - 1], key, value);
126                        }
127                    }
128
129                    if let Ok(r) = Regex::new(s) {
130                        if pos > 0 {
131                            let mut node = TrieNode::root();
132                            let pos = pos - 1;
133
134                            let res = node.insert_recursive(&partial_key[..pos], key, value);
135
136                            if res == InsertResult::Ok {
137                                self.regexps.push((r, node));
138                            }
139
140                            return res;
141                        } else {
142                            let node = TrieNode::new(key.to_vec(), value);
143                            self.regexps.push((r, node));
144                            return InsertResult::Ok;
145                        }
146                    }
147                }
148            }
149
150            return InsertResult::Failed;
151        }
152
153        let pos = find_last_dot(partial_key);
154        match pos {
155            None => {
156                if self.children.contains_key(partial_key) {
157                    InsertResult::Existing
158                } else if partial_key == &b"*"[..] {
159                    if self.wildcard.is_some() {
160                        InsertResult::Existing
161                    } else {
162                        self.wildcard = Some((key.to_vec(), value));
163                        InsertResult::Ok
164                    }
165                } else {
166                    let node = TrieNode::new(key.to_vec(), value);
167                    self.children.insert(partial_key.to_vec(), node);
168                    InsertResult::Ok
169                }
170            }
171            Some(pos) => {
172                if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
173                    return child.insert_recursive(&partial_key[..pos], key, value);
174                }
175
176                let mut node = TrieNode::root();
177                let res = node.insert_recursive(&partial_key[..pos], key, value);
178
179                if res == InsertResult::Ok {
180                    self.children.insert(partial_key[pos..].to_vec(), node);
181                }
182
183                res
184            }
185        }
186    }
187
188    pub fn remove(&mut self, key: &Key) -> RemoveResult {
189        self.remove_recursive(key)
190    }
191
192    pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
193        //println!("remove: key == {}", std::str::from_utf8(partial_key).unwrap());
194
195        if partial_key.is_empty() {
196            if self.key_value.is_some() {
197                self.key_value = None;
198                return RemoveResult::Ok;
199            } else {
200                return RemoveResult::NotFound;
201            }
202        }
203
204        if partial_key == &b"*"[..] {
205            if self.wildcard.is_some() {
206                self.wildcard = None;
207                return RemoveResult::Ok;
208            } else {
209                return RemoveResult::NotFound;
210            }
211        }
212
213        if partial_key[partial_key.len() - 1] == b'/' {
214            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
215
216            if let Some(pos) = pos {
217                if pos > 0 && partial_key[pos - 1] != b'.' {
218                    return RemoveResult::NotFound;
219                }
220
221                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
222                    if pos > 0 {
223                        let mut remove_result = RemoveResult::NotFound;
224                        for t in self.regexps.iter_mut() {
225                            if t.0.as_str() == s
226                                && t.1.remove_recursive(&partial_key[..pos - 1]) == RemoveResult::Ok
227                            {
228                                remove_result = RemoveResult::Ok;
229                            }
230                        }
231                        return remove_result;
232                    } else {
233                        let len = self.regexps.len();
234                        self.regexps.retain(|(r, _)| r.as_str() != s);
235                        if len > self.regexps.len() {
236                            return RemoveResult::Ok;
237                        }
238                    }
239                }
240            }
241
242            return RemoveResult::NotFound;
243        }
244
245        let pos = find_last_dot(partial_key);
246        let (prefix, suffix) = match pos {
247            None => (&b""[..], partial_key),
248            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
249        };
250        //println!("remove: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
251
252        match self.children.get_mut(suffix) {
253            Some(child) => match child.remove_recursive(prefix) {
254                RemoveResult::NotFound => RemoveResult::NotFound,
255                RemoveResult::Ok => {
256                    if child.is_empty() {
257                        self.children.remove(suffix);
258                    }
259                    RemoveResult::Ok
260                }
261            },
262            None => RemoveResult::NotFound,
263        }
264    }
265
266    pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
267        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
268
269        if partial_key.is_empty() {
270            return self.key_value.as_ref();
271        }
272
273        let pos = find_last_dot(partial_key);
274        let (prefix, suffix) = match pos {
275            None => (&b""[..], partial_key),
276            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
277        };
278        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
279
280        match self.children.get(suffix) {
281            Some(child) => child.lookup(prefix, accept_wildcard),
282            None => {
283                //println!("no child found, testing wildcard and regexps");
284
285                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
286                    //println!("no dot, wildcard applies");
287                    self.wildcard.as_ref()
288                } else {
289                    //println!("there's still a subdomain, wildcard does not apply");
290
291                    for (ref regexp, ref child) in self.regexps.iter() {
292                        let suffix = if suffix[0] == b'.' {
293                            &suffix[1..]
294                        } else {
295                            suffix
296                        };
297                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
298
299                        if regexp.is_match(suffix) {
300                            //println!("matched");
301                            return child.lookup(prefix, accept_wildcard);
302                        }
303                    }
304
305                    None
306                }
307            }
308        }
309    }
310
311    pub fn lookup_mut(
312        &mut self,
313        partial_key: &[u8],
314        accept_wildcard: bool,
315    ) -> Option<&mut KeyValue<Key, V>> {
316        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
317
318        if partial_key.is_empty() {
319            return self.key_value.as_mut();
320        }
321
322        if partial_key == &b"*"[..] {
323            return self.wildcard.as_mut();
324        }
325
326        if partial_key[partial_key.len() - 1] == b'/' {
327            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
328
329            if let Some(pos) = pos {
330                if pos > 0 && partial_key[pos - 1] != b'.' {
331                    return None;
332                }
333
334                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
335                    for t in self.regexps.iter_mut() {
336                        if t.0.as_str() == s {
337                            return t.1.lookup_mut(&partial_key[..pos - 1], accept_wildcard);
338                        }
339                    }
340                }
341            }
342
343            return None;
344        }
345
346        let pos = find_last_dot(partial_key);
347        let (prefix, suffix) = match pos {
348            None => (&b""[..], partial_key),
349            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
350        };
351        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
352
353        match self.children.get_mut(suffix) {
354            Some(child) => child.lookup_mut(prefix, accept_wildcard),
355            None => {
356                //println!("no child found, testing wildcard and regexps");
357
358                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
359                    //println!("no dot, wildcard applies");
360                    self.wildcard.as_mut()
361                } else {
362                    //println!("there's still a subdomain, wildcard does not apply");
363
364                    for (ref regexp, ref mut child) in self.regexps.iter_mut() {
365                        let suffix = if suffix[0] == b'.' {
366                            &suffix[1..]
367                        } else {
368                            suffix
369                        };
370                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
371
372                        if regexp.is_match(suffix) {
373                            //println!("matched");
374                            return child.lookup_mut(prefix, accept_wildcard);
375                        }
376                    }
377
378                    None
379                }
380            }
381        }
382    }
383
384    pub fn print(&self) {
385        self.print_recursive(b"", 0)
386    }
387
388    pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
389        let raw_prefix: Vec<u8> = iter::repeat(b' ').take(2 * indent as usize).collect();
390        let prefix = str::from_utf8(&raw_prefix).unwrap();
391
392        print!("{}{}: ", prefix, str::from_utf8(partial_key).unwrap());
393        if let Some((ref key, ref value)) = self.key_value {
394            print!("({}, {:?}) | ", str::from_utf8(key).unwrap(), value);
395        } else {
396            print!("None | ");
397        }
398
399        if let Some((key, value)) = &self.wildcard {
400            println!("({}, {:?})", str::from_utf8(key).unwrap(), value);
401        } else {
402            println!("None");
403        }
404
405        for (child_key, child) in self.children.iter() {
406            child.print_recursive(child_key, indent + 1);
407        }
408
409        for (regexp, child) in self.regexps.iter() {
410            //print!("{}{}:", prefix, regexp.as_str());
411            child.print_recursive(regexp.as_str().as_bytes(), indent + 1);
412        }
413    }
414
415    pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
416        self.insert(key, value)
417    }
418
419    pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
420        self.remove(key)
421    }
422
423    pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
424        self.lookup(key, accept_wildcard)
425    }
426
427    pub fn domain_lookup_mut(
428        &mut self,
429        key: &[u8],
430        accept_wildcard: bool,
431    ) -> Option<&mut KeyValue<Key, V>> {
432        self.lookup_mut(key, accept_wildcard)
433    }
434
435    pub fn size(&self) -> usize {
436        ::std::mem::size_of::<TrieNode<V>>()
437            + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2
438            + self
439                .children
440                .iter()
441                .fold(0, |acc, c| acc + c.0.len() + c.1.size())
442    }
443
444    pub fn to_hashmap(&self) -> HashMap<Key, V> {
445        let mut h = HashMap::new();
446
447        self.to_hashmap_recursive(&mut h);
448
449        h
450    }
451
452    pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
453        if let Some((key, value)) = &self.key_value {
454            h.insert(key.clone(), value.clone());
455        }
456
457        if let Some((key, value)) = &self.wildcard {
458            h.insert(key.clone(), value.clone());
459        }
460
461        for child in self.children.values() {
462            child.to_hashmap_recursive(h);
463        }
464    }
465}
466
467#[cfg(test)]
468mod tests {
469    use super::*;
470
471    #[test]
472    fn insert() {
473        let mut root: TrieNode<u8> = TrieNode::root();
474        root.print();
475
476        assert_eq!(
477            root.domain_insert(Vec::from(&b"abcd"[..]), 1),
478            InsertResult::Ok
479        );
480        root.print();
481        assert_eq!(
482            root.domain_insert(Vec::from(&b"abce"[..]), 2),
483            InsertResult::Ok
484        );
485        root.print();
486        assert_eq!(
487            root.domain_insert(Vec::from(&b"abgh"[..]), 3),
488            InsertResult::Ok
489        );
490        root.print();
491
492        assert_eq!(
493            root.domain_lookup(&b"abce"[..], true),
494            Some(&(b"abce"[..].to_vec(), 2))
495        );
496        //assert!(false);
497    }
498
499    #[test]
500    fn remove() {
501        let mut root: TrieNode<u8> = TrieNode::root();
502        println!("creating root:");
503        root.print();
504
505        println!("adding (abcd, 1)");
506        assert_eq!(root.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
507        root.print();
508        println!("adding (abce, 2)");
509        assert_eq!(root.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
510        root.print();
511        println!("adding (abgh, 3)");
512        assert_eq!(root.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
513        root.print();
514
515        let mut root2: TrieNode<u8> = TrieNode::root();
516
517        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
518        assert_eq!(root2.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
519
520        println!("before remove");
521        root.print();
522        assert_eq!(root.remove(&Vec::from(&b"abce"[..])), RemoveResult::Ok);
523        println!("after remove");
524        root.print();
525
526        println!("expected");
527        root2.print();
528        assert_eq!(root, root2);
529
530        assert_eq!(root.remove(&Vec::from(&b"abgh"[..])), RemoveResult::Ok);
531        println!("after remove");
532        root.print();
533        println!("expected");
534        let mut root3: TrieNode<u8> = TrieNode::root();
535        assert_eq!(root3.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
536        root3.print();
537        assert_eq!(root, root3);
538    }
539
540    #[test]
541    fn insert_remove_through_regex() {
542        let mut root: TrieNode<u8> = TrieNode::root();
543        println!("creating root:");
544        root.print();
545
546        println!("adding (www./.*/.com, 1)");
547        assert_eq!(
548            root.insert(Vec::from(&b"www./.*/.com"[..]), 1),
549            InsertResult::Ok
550        );
551        root.print();
552        println!("adding (www.doc./.*/.com, 2)");
553        assert_eq!(
554            root.insert(Vec::from(&b"www.doc./.*/.com"[..]), 2),
555            InsertResult::Ok
556        );
557        root.print();
558        assert_eq!(
559            root.domain_lookup(b"www.sozu.com".as_ref(), false),
560            Some(&(b"www./.*/.com".to_vec(), 1))
561        );
562        assert_eq!(
563            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
564            Some(&(b"www.doc./.*/.com".to_vec(), 2))
565        );
566
567        assert_eq!(
568            root.domain_remove(&b"www./.*/.com".to_vec()),
569            RemoveResult::Ok
570        );
571        root.print();
572        assert_eq!(root.domain_lookup(b"www.sozu.com".as_ref(), false), None);
573        assert_eq!(
574            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
575            Some(&(b"www.doc./.*/.com".to_vec(), 2))
576        );
577    }
578
579    #[test]
580    fn add_child_to_leaf() {
581        let mut root1: TrieNode<u8> = TrieNode::root();
582
583        println!("creating root1:");
584        root1.print();
585        println!("adding (abcd, 1)");
586        assert_eq!(root1.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
587        root1.print();
588        println!("adding (abce, 2)");
589        assert_eq!(root1.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
590        root1.print();
591        println!("adding (abc, 3)");
592        assert_eq!(root1.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
593
594        println!("root1:");
595        root1.print();
596
597        let mut root2: TrieNode<u8> = TrieNode::root();
598
599        assert_eq!(root2.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
600        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
601        assert_eq!(root2.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
602
603        println!("root2:");
604        root2.print();
605        assert_eq!(root2.remove(&Vec::from(&b"abc"[..])), RemoveResult::Ok);
606
607        println!("root2 after,remove:");
608        root2.print();
609        let mut expected: TrieNode<u8> = TrieNode::root();
610
611        assert_eq!(
612            expected.insert(Vec::from(&b"abcd"[..]), 1),
613            InsertResult::Ok
614        );
615        assert_eq!(
616            expected.insert(Vec::from(&b"abce"[..]), 2),
617            InsertResult::Ok
618        );
619
620        println!("root2 after insert");
621        root2.print();
622        println!("expected");
623        expected.print();
624        assert_eq!(root2, expected);
625    }
626
627    #[test]
628    fn domains() {
629        let mut root: TrieNode<u8> = TrieNode::root();
630        root.print();
631
632        assert_eq!(
633            root.domain_insert(Vec::from(&b"www.example.com"[..]), 1),
634            InsertResult::Ok
635        );
636        root.print();
637        assert_eq!(
638            root.domain_insert(Vec::from(&b"test.example.com"[..]), 2),
639            InsertResult::Ok
640        );
641        root.print();
642        assert_eq!(
643            root.domain_insert(Vec::from(&b"*.alldomains.org"[..]), 3),
644            InsertResult::Ok
645        );
646        root.print();
647        assert_eq!(
648            root.domain_insert(Vec::from(&b"alldomains.org"[..]), 4),
649            InsertResult::Ok
650        );
651        assert_eq!(
652            root.domain_insert(Vec::from(&b"pouet.alldomains.org"[..]), 5),
653            InsertResult::Ok
654        );
655        root.print();
656        assert_eq!(
657            root.domain_insert(Vec::from(&b"hello.com"[..]), 6),
658            InsertResult::Ok
659        );
660        assert_eq!(
661            root.domain_insert(Vec::from(&b"*.hello.com"[..]), 7),
662            InsertResult::Ok
663        );
664        assert_eq!(
665            root.domain_insert(Vec::from(&b"images./cdn[0-9]+/.hello.com"[..]), 8),
666            InsertResult::Ok
667        );
668        root.print();
669        assert_eq!(
670            root.domain_insert(Vec::from(&b"/test[0-9]+/.www.hello.com"[..]), 9),
671            InsertResult::Ok
672        );
673        root.print();
674
675        assert_eq!(root.domain_lookup(&b"example.com"[..], true), None);
676        assert_eq!(
677            root.domain_lookup(&b"blah.test.example.com"[..], true),
678            None
679        );
680        assert_eq!(
681            root.domain_lookup(&b"www.example.com"[..], true),
682            Some(&(b"www.example.com"[..].to_vec(), 1))
683        );
684        assert_eq!(
685            root.domain_lookup(&b"alldomains.org"[..], true),
686            Some(&(b"alldomains.org"[..].to_vec(), 4))
687        );
688        assert_eq!(
689            root.domain_lookup(&b"test.hello.com"[..], true),
690            Some(&(b"*.hello.com"[..].to_vec(), 7))
691        );
692        assert_eq!(
693            root.domain_lookup(&b"images.cdn10.hello.com"[..], true),
694            Some(&(b"images./cdn[0-9]+/.hello.com"[..].to_vec(), 8))
695        );
696        assert_eq!(
697            root.domain_lookup(&b"test42.www.hello.com"[..], true),
698            Some(&(b"/test[0-9]+/.www.hello.com"[..].to_vec(), 9))
699        );
700        assert_eq!(
701            root.domain_lookup(&b"test.alldomains.org"[..], true),
702            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
703        );
704        assert_eq!(
705            root.domain_lookup(&b"hello.alldomains.org"[..], true),
706            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
707        );
708        assert_eq!(
709            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
710            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
711        );
712        assert_eq!(
713            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
714            None
715        );
716
717        assert_eq!(
718            root.domain_remove(&Vec::from(&b"alldomains.org"[..])),
719            RemoveResult::Ok
720        );
721        println!("after remove");
722        root.print();
723        assert_eq!(root.domain_lookup(&b"alldomains.org"[..], true), None);
724        assert_eq!(
725            root.domain_lookup(&b"test.alldomains.org"[..], true),
726            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
727        );
728        assert_eq!(
729            root.domain_lookup(&b"hello.alldomains.org"[..], true),
730            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
731        );
732        assert_eq!(
733            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
734            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
735        );
736        assert_eq!(
737            root.domain_lookup(&b"test.hello.com"[..], true),
738            Some(&(b"*.hello.com"[..].to_vec(), 7))
739        );
740        assert_eq!(
741            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
742            None
743        );
744    }
745
746    #[test]
747    fn wildcard() {
748        let mut root: TrieNode<u8> = TrieNode::root();
749        root.print();
750        root.domain_insert("*.clever-cloud.com".as_bytes().to_vec(), 2u8);
751        root.domain_insert("services.clever-cloud.com".as_bytes().to_vec(), 0u8);
752        root.domain_insert("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8);
753
754        let res = root.domain_lookup(b"test.services.clever-cloud.com", true);
755        println!("query result: {res:?}");
756
757        assert_eq!(
758            root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
759            Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
760        );
761    }
762
763    fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
764        let mut root: TrieNode<u32> = TrieNode::root();
765
766        for (k, v) in h.iter() {
767            if k.is_empty() {
768                continue;
769            }
770
771            if k.as_bytes()[0] == b'.' {
772                continue;
773            }
774
775            if k.contains('/') {
776                continue;
777            }
778
779            if k == "*" {
780                continue;
781            }
782
783            //println!("inserting key: '{}', value: '{}'", k, v);
784            //assert_eq!(root.domain_insert(Vec::from(k.as_bytes()), *v), InsertResult::Ok);
785            assert_eq!(
786                root.insert(Vec::from(k.as_bytes()), *v),
787                InsertResult::Ok,
788                "could not insert ({k}, {v})"
789            );
790            //root.print();
791        }
792
793        //root.print();
794        for (k, v) in h.iter() {
795            if k.is_empty() {
796                continue;
797            }
798
799            if k.as_bytes()[0] == b'.' {
800                continue;
801            }
802
803            if k.contains('/') {
804                continue;
805            }
806
807            if k == "*" {
808                continue;
809            }
810
811            //match root.domain_lookup(k.as_bytes()) {
812            match root.lookup(k.as_bytes(), false) {
813                None => {
814                    println!("did not find key '{k}'");
815                    return false;
816                }
817                Some(&(ref k1, v1)) => {
818                    if k.as_bytes() != &k1[..] || *v != v1 {
819                        println!(
820                            "request ({}, {}), got ({}, {})",
821                            k,
822                            v,
823                            str::from_utf8(&k1[..]).unwrap(),
824                            v1
825                        );
826                        return false;
827                    }
828                }
829            }
830        }
831
832        true
833    }
834
835    /* FIXME: randomly fails
836    quickcheck! {
837      fn qc_insert(h: std::collections::HashMap<String, u32>) -> bool {
838        hm_insert(h)
839      }
840    }
841    */
842
843    #[test]
844    fn insert_disappearing_tree() {
845        let h: std::collections::HashMap<String, u32> = [
846            (String::from("\n\u{3}"), 0),
847            (String::from("\n\u{0}"), 1),
848            (String::from("\n"), 2),
849        ]
850        .iter()
851        .cloned()
852        .collect();
853        assert!(hm_insert(h));
854    }
855
    // NOTE(review): `assert_size!` is defined outside this file; presumably it
    // compares `size_of::<TrieNode<u32>>()` with the given byte count (136) so
    // that accidental growth of the node type is noticed — confirm.
    #[test]
    fn size() {
        assert_size!(TrieNode<u32>, 136);
    }
860}