// sozu_lib/router/pattern_trie.rs

1use std::{collections::HashMap, fmt::Debug, iter, str};
2
3use regex::bytes::Regex;
4
/// Raw bytes of a complete key, as handed to `TrieNode::insert`.
pub type Key = Vec<u8>;
/// A stored entry: a key paired with the value attached to it.
pub type KeyValue<K, V> = (K, V);
7
/// Outcome of inserting a key into a `TrieNode`.
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
    /// The entry was stored.
    Ok,
    /// The slot for this key (literal child or `*` wildcard) is already
    /// occupied; nothing was stored.
    Existing,
    /// The key was rejected: `insert` returns this for an empty key or a lone
    /// `.`, `insert_recursive` for a malformed or uncompilable `/regex/`
    /// segment.
    Failed,
}
14
/// Outcome of removing a key from a `TrieNode`.
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
    /// An entry was found for the key and removed.
    Ok,
    /// Nothing was stored under the key.
    NotFound,
}
20
/// Returns the index of the rightmost `.` byte in `input`, or `None` when the
/// slice contains no dot.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
25
/// Returns the index of the rightmost `/` byte in `input`, or `None` when the
/// slice contains no slash.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
30
/// Implementation of a trie tree structure.
/// In Sozu this is used to store and lookup domains recursively.
/// Each node represents a "level domain".
/// A leaf node (leftmost label) can be a wildcard, a regex pattern or a plain string.
/// Leaves also store a value associated with the complete domain.
/// For Sozu it is a list of (PathRule, MethodRule, ClusterId). See the Router structure.
#[derive(Debug, Default)]
pub struct TrieNode<V> {
    /// Entry returned when a lookup has consumed the whole key at this node.
    key_value: Option<KeyValue<Key, V>>,
    /// Entry stored for a `*` label directly below this node.
    wildcard: Option<KeyValue<Key, V>>,
    /// Literal sub-nodes, keyed by label bytes. Labels are split at the last
    /// `.`, so every label except the leftmost keeps its leading dot.
    children: HashMap<Key, TrieNode<V>>,
    /// Sub-nodes reached through a `/regex/` segment, in insertion order.
    /// Each regex is stored anchored (`\A...\z`).
    regexps: Vec<(Regex, TrieNode<V>)>,
}
44
/// One step of a trie traversal where a non-literal segment matched.
///
/// `Wildcard` carries the actual segment bytes consumed by a `*` wildcard
/// (so a router that wants to capture them can splice them into a rewrite
/// template). `Regexp` carries both the matched bytes and the regex itself
/// so the caller can re-run `Regex::captures` to pull explicit groups.
#[derive(Debug)]
pub enum TrieSubMatch<'a, 'b> {
    /// Bytes of the label (leading `.` stripped) that the `*` stood in for.
    Wildcard(&'a [u8]),
    /// Bytes of the label (leading `.` stripped) and the regex that matched it.
    Regexp(&'a [u8], &'b Regex),
}

/// Ordered list of non-literal trie segments visited during a successful
/// `lookup_with_path` traversal. Routers feed the entries into rewrite
/// templates (`$HOST[n]`) so frontend rewrites can reach into the matched
/// segments. Empty when only literal segments matched.
pub type TrieMatches<'a, 'b> = Vec<TrieSubMatch<'a, 'b>>;
62
63impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
64    fn eq(&self, other: &Self) -> bool {
65        self.key_value == other.key_value
66            && self.wildcard == other.wildcard
67            && self.children == other.children
68            && self.regexps.len() == other.regexps.len()
69            && self
70                .regexps
71                .iter()
72                .zip(other.regexps.iter())
73                .fold(true, |b, (left, right)| {
74                    b && left.0.as_str() == right.0.as_str() && left.1 == right.1
75                })
76    }
77}
78
79impl<V: Debug + Clone> TrieNode<V> {
80    pub fn new(key: Key, value: V) -> TrieNode<V> {
81        TrieNode {
82            key_value: Some((key, value)),
83            wildcard: None,
84            children: HashMap::new(),
85            regexps: Vec::new(),
86        }
87    }
88
89    pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
90        TrieNode {
91            key_value: None,
92            wildcard: Some((key, value)),
93            children: HashMap::new(),
94            regexps: Vec::new(),
95        }
96    }
97
98    pub fn root() -> TrieNode<V> {
99        TrieNode {
100            key_value: None,
101            wildcard: None,
102            children: HashMap::new(),
103            regexps: Vec::new(),
104        }
105    }
106
107    pub fn is_empty(&self) -> bool {
108        self.key_value.is_none()
109            && self.wildcard.is_none()
110            && self.regexps.is_empty()
111            && self.children.is_empty()
112    }
113
114    pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
115        //println!("insert: key == {}", std::str::from_utf8(&key).unwrap());
116        if key.is_empty() {
117            return InsertResult::Failed;
118        }
119        if key[..] == b"."[..] {
120            return InsertResult::Failed;
121        }
122
123        let insert_result = self.insert_recursive(&key, &key, value);
124        assert_ne!(insert_result, InsertResult::Failed);
125        insert_result
126    }
127
128    pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
129        //println!("insert_rec: key == {}", std::str::from_utf8(partial_key).unwrap());
130        assert_ne!(partial_key, &b""[..]);
131
132        if partial_key[partial_key.len() - 1] == b'/' {
133            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
134
135            if let Some(pos) = pos {
136                if pos > 0 && partial_key[pos - 1] != b'.' {
137                    return InsertResult::Failed;
138                }
139
140                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
141                    let anchored_s = format!("\\A{s}\\z");
142                    for t in self.regexps.iter_mut() {
143                        if t.0.as_str() == anchored_s {
144                            return t.1.insert_recursive(&partial_key[..pos - 1], key, value);
145                        }
146                    }
147
148                    // Anchor segment regexes so they only match the entire
149                    // segment, not partial overlaps. Without `\A...\z`, a
150                    // pattern like `cdn[0-9]+` would match `cdn123xxx`,
151                    // which silently widens the routing surface.
152                    let anchored = format!("\\A{s}\\z");
153                    if let Ok(r) = Regex::new(&anchored) {
154                        if pos > 0 {
155                            let mut node = TrieNode::root();
156                            let pos = pos - 1;
157
158                            let res = node.insert_recursive(&partial_key[..pos], key, value);
159
160                            if res == InsertResult::Ok {
161                                self.regexps.push((r, node));
162                            }
163
164                            return res;
165                        } else {
166                            let node = TrieNode::new(key.to_vec(), value);
167                            self.regexps.push((r, node));
168                            return InsertResult::Ok;
169                        }
170                    }
171                }
172            }
173
174            return InsertResult::Failed;
175        }
176
177        let pos = find_last_dot(partial_key);
178        match pos {
179            None => {
180                if self.children.contains_key(partial_key) {
181                    InsertResult::Existing
182                } else if partial_key == &b"*"[..] {
183                    if self.wildcard.is_some() {
184                        InsertResult::Existing
185                    } else {
186                        self.wildcard = Some((key.to_vec(), value));
187                        InsertResult::Ok
188                    }
189                } else {
190                    let node = TrieNode::new(key.to_vec(), value);
191                    self.children.insert(partial_key.to_vec(), node);
192                    InsertResult::Ok
193                }
194            }
195            Some(pos) => {
196                if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
197                    return child.insert_recursive(&partial_key[..pos], key, value);
198                }
199
200                let mut node = TrieNode::root();
201                let res = node.insert_recursive(&partial_key[..pos], key, value);
202
203                if res == InsertResult::Ok {
204                    self.children.insert(partial_key[pos..].to_vec(), node);
205                }
206
207                res
208            }
209        }
210    }
211
    /// Removes the entry stored under `key`; delegates to `remove_recursive`
    /// with the complete key.
    pub fn remove(&mut self, key: &Key) -> RemoveResult {
        self.remove_recursive(key)
    }
215
216    pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
217        //println!("remove: key == {}", std::str::from_utf8(partial_key).unwrap());
218
219        if partial_key.is_empty() {
220            if self.key_value.is_some() {
221                self.key_value = None;
222                return RemoveResult::Ok;
223            } else {
224                return RemoveResult::NotFound;
225            }
226        }
227
228        if partial_key == &b"*"[..] {
229            if self.wildcard.is_some() {
230                self.wildcard = None;
231                return RemoveResult::Ok;
232            } else {
233                return RemoveResult::NotFound;
234            }
235        }
236
237        if partial_key[partial_key.len() - 1] == b'/' {
238            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
239
240            if let Some(pos) = pos {
241                if pos > 0 && partial_key[pos - 1] != b'.' {
242                    return RemoveResult::NotFound;
243                }
244
245                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
246                    let anchored_s = format!("\\A{s}\\z");
247                    if pos > 0 {
248                        let mut remove_result = RemoveResult::NotFound;
249                        for t in self.regexps.iter_mut() {
250                            if t.0.as_str() == anchored_s
251                                && t.1.remove_recursive(&partial_key[..pos - 1]) == RemoveResult::Ok
252                            {
253                                remove_result = RemoveResult::Ok;
254                            }
255                        }
256                        return remove_result;
257                    } else {
258                        let len = self.regexps.len();
259                        self.regexps.retain(|(r, _)| r.as_str() != anchored_s);
260                        if len > self.regexps.len() {
261                            return RemoveResult::Ok;
262                        }
263                    }
264                }
265            }
266
267            return RemoveResult::NotFound;
268        }
269
270        let pos = find_last_dot(partial_key);
271        let (prefix, suffix) = match pos {
272            None => (&b""[..], partial_key),
273            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
274        };
275        //println!("remove: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
276
277        match self.children.get_mut(suffix) {
278            Some(child) => match child.remove_recursive(prefix) {
279                RemoveResult::NotFound => RemoveResult::NotFound,
280                RemoveResult::Ok => {
281                    if child.is_empty() {
282                        self.children.remove(suffix);
283                    }
284                    RemoveResult::Ok
285                }
286            },
287            None => RemoveResult::NotFound,
288        }
289    }
290
291    /// Look up `partial_key` and additionally collect the non-literal segments
292    /// that matched along the way (`TrieMatches`).
293    ///
294    /// Equivalent to `lookup` for callers that don't need the captures, but
295    /// frontends with `$HOST[n]` rewrite templates need the matched segments
296    /// to fill the placeholders. The accumulator is passed in by value so
297    /// callers can pre-size it (`Vec::with_capacity`) and we own the path
298    /// returned alongside the value.
299    pub fn lookup_with_path<'a, 'b>(
300        &'b self,
301        partial_key: &'a [u8],
302        accept_wildcard: bool,
303        mut trace: TrieMatches<'a, 'b>,
304    ) -> Option<(&'b KeyValue<Key, V>, TrieMatches<'a, 'b>)> {
305        if partial_key.is_empty() {
306            return self.key_value.as_ref().map(|kv| (kv, trace));
307        }
308
309        let pos = find_last_dot(partial_key);
310        let (prefix, suffix) = match pos {
311            None => (&b""[..], partial_key),
312            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
313        };
314
315        match self.children.get(suffix) {
316            Some(child) => child.lookup_with_path(prefix, accept_wildcard, trace),
317            None => {
318                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
319                    let segment = if !suffix.is_empty() && suffix[0] == b'.' {
320                        &suffix[1..]
321                    } else {
322                        suffix
323                    };
324                    trace.push(TrieSubMatch::Wildcard(segment));
325                    self.wildcard.as_ref().map(|kv| (kv, trace))
326                } else {
327                    for (regexp, child) in self.regexps.iter() {
328                        let segment = if !suffix.is_empty() && suffix[0] == b'.' {
329                            &suffix[1..]
330                        } else {
331                            suffix
332                        };
333                        if regexp.is_match(segment) {
334                            let mut next = trace;
335                            next.push(TrieSubMatch::Regexp(segment, regexp));
336                            return child.lookup_with_path(prefix, accept_wildcard, next);
337                        }
338                    }
339                    None
340                }
341            }
342        }
343    }
344
345    pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
346        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
347
348        if partial_key.is_empty() {
349            return self.key_value.as_ref();
350        }
351
352        let pos = find_last_dot(partial_key);
353        let (prefix, suffix) = match pos {
354            None => (&b""[..], partial_key),
355            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
356        };
357        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
358
359        match self.children.get(suffix) {
360            Some(child) => child.lookup(prefix, accept_wildcard),
361            None => {
362                //println!("no child found, testing wildcard and regexps");
363
364                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
365                    //println!("no dot, wildcard applies");
366                    self.wildcard.as_ref()
367                } else {
368                    //println!("there's still a subdomain, wildcard does not apply");
369
370                    for (regexp, child) in self.regexps.iter() {
371                        let suffix = if suffix[0] == b'.' {
372                            &suffix[1..]
373                        } else {
374                            suffix
375                        };
376                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
377
378                        if regexp.is_match(suffix) {
379                            //println!("matched");
380                            return child.lookup(prefix, accept_wildcard);
381                        }
382                    }
383
384                    None
385                }
386            }
387        }
388    }
389
390    pub fn lookup_mut(
391        &mut self,
392        partial_key: &[u8],
393        accept_wildcard: bool,
394    ) -> Option<&mut KeyValue<Key, V>> {
395        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
396
397        if partial_key.is_empty() {
398            return self.key_value.as_mut();
399        }
400
401        if partial_key == &b"*"[..] {
402            return self.wildcard.as_mut();
403        }
404
405        if partial_key[partial_key.len() - 1] == b'/' {
406            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
407
408            if let Some(pos) = pos {
409                if pos > 0 && partial_key[pos - 1] != b'.' {
410                    return None;
411                }
412
413                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
414                    let anchored_s = format!("\\A{s}\\z");
415                    for t in self.regexps.iter_mut() {
416                        if t.0.as_str() == anchored_s {
417                            return t.1.lookup_mut(&partial_key[..pos - 1], accept_wildcard);
418                        }
419                    }
420                }
421            }
422
423            return None;
424        }
425
426        let pos = find_last_dot(partial_key);
427        let (prefix, suffix) = match pos {
428            None => (&b""[..], partial_key),
429            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
430        };
431        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
432
433        match self.children.get_mut(suffix) {
434            Some(child) => child.lookup_mut(prefix, accept_wildcard),
435            None => {
436                //println!("no child found, testing wildcard and regexps");
437
438                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
439                    //println!("no dot, wildcard applies");
440                    self.wildcard.as_mut()
441                } else {
442                    //println!("there's still a subdomain, wildcard does not apply");
443
444                    for &mut (ref regexp, ref mut child) in self.regexps.iter_mut() {
445                        let suffix = if suffix[0] == b'.' {
446                            &suffix[1..]
447                        } else {
448                            suffix
449                        };
450                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
451
452                        if regexp.is_match(suffix) {
453                            //println!("matched");
454                            return child.lookup_mut(prefix, accept_wildcard);
455                        }
456                    }
457
458                    None
459                }
460            }
461        }
462    }
463
    /// Prints the whole subtree to stdout, starting at indentation level 0
    /// with an empty label for this node.
    pub fn print(&self) {
        self.print_recursive(b"", 0)
    }
467
468    pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
469        let raw_prefix: Vec<u8> = iter::repeat_n(b' ', 2 * indent as usize).collect();
470        let prefix = str::from_utf8(&raw_prefix).unwrap();
471
472        print!("{}{}: ", prefix, str::from_utf8(partial_key).unwrap());
473        if let Some((ref key, ref value)) = self.key_value {
474            print!("({}, {:?}) | ", str::from_utf8(key).unwrap(), value);
475        } else {
476            print!("None | ");
477        }
478
479        if let Some((key, value)) = &self.wildcard {
480            println!("({}, {:?})", str::from_utf8(key).unwrap(), value);
481        } else {
482            println!("None");
483        }
484
485        for (child_key, child) in self.children.iter() {
486            child.print_recursive(child_key, indent + 1);
487        }
488
489        for (regexp, child) in self.regexps.iter() {
490            //print!("{}{}:", prefix, regexp.as_str());
491            child.print_recursive(regexp.as_str().as_bytes(), indent + 1);
492        }
493    }
494
495    /// Visit every stored value in the trie (the literal `key_value` and
496    /// the leftmost `wildcard` slot of every node, plus all
497    /// regex-subtree leaves) and invoke `f` on each. Used by the router
498    /// to walk all routes for cross-cutting refreshes (e.g. listener-
499    /// default HSTS reflow) without rebuilding the trie.
500    pub fn for_each_value_mut<F: FnMut(&mut V)>(&mut self, f: &mut F) {
501        if let Some((_, ref mut value)) = self.key_value {
502            f(value);
503        }
504        if let Some((_, ref mut value)) = self.wildcard {
505            f(value);
506        }
507        for child in self.children.values_mut() {
508            child.for_each_value_mut(f);
509        }
510        for (_, child) in self.regexps.iter_mut() {
511            child.for_each_value_mut(f);
512        }
513    }
514
    /// Domain-named alias for `insert`: forwards `key` and `value` unchanged.
    pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
        self.insert(key, value)
    }
518
    /// Domain-named alias for `remove`: forwards `key` unchanged.
    pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
        self.remove(key)
    }
522
    /// Domain-named alias for `lookup`: forwards both arguments unchanged.
    pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
        self.lookup(key, accept_wildcard)
    }
526
    /// Domain-named alias for `lookup_mut`: forwards both arguments unchanged.
    pub fn domain_lookup_mut(
        &mut self,
        key: &[u8],
        accept_wildcard: bool,
    ) -> Option<&mut KeyValue<Key, V>> {
        self.lookup_mut(key, accept_wildcard)
    }
534
535    pub fn size(&self) -> usize {
536        ::std::mem::size_of::<TrieNode<V>>()
537            + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2
538            + self
539                .children
540                .iter()
541                .fold(0, |acc, c| acc + c.0.len() + c.1.size())
542    }
543
    /// Returns a new map filled by `to_hashmap_recursive` with clones of the
    /// entries it visits in this subtree.
    pub fn to_hashmap(&self) -> HashMap<Key, V> {
        let mut h = HashMap::new();

        self.to_hashmap_recursive(&mut h);

        h
    }
551
552    pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
553        if let Some((key, value)) = &self.key_value {
554            h.insert(key.clone(), value.clone());
555        }
556
557        if let Some((key, value)) = &self.wildcard {
558            h.insert(key.clone(), value.clone());
559        }
560
561        for child in self.children.values() {
562            child.to_hashmap_recursive(h);
563        }
564    }
565}
566
567#[cfg(test)]
568mod tests {
569    use super::*;
570
571    #[test]
572    fn insert() {
573        let mut root: TrieNode<u8> = TrieNode::root();
574        root.print();
575
576        assert_eq!(
577            root.domain_insert(Vec::from(&b"abcd"[..]), 1),
578            InsertResult::Ok
579        );
580        root.print();
581        assert_eq!(
582            root.domain_insert(Vec::from(&b"abce"[..]), 2),
583            InsertResult::Ok
584        );
585        root.print();
586        assert_eq!(
587            root.domain_insert(Vec::from(&b"abgh"[..]), 3),
588            InsertResult::Ok
589        );
590        root.print();
591
592        assert_eq!(
593            root.domain_lookup(&b"abce"[..], true),
594            Some(&(b"abce"[..].to_vec(), 2))
595        );
596        //assert!(false);
597    }
598
599    #[test]
600    fn remove() {
601        let mut root: TrieNode<u8> = TrieNode::root();
602        println!("creating root:");
603        root.print();
604
605        println!("adding (abcd, 1)");
606        assert_eq!(root.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
607        root.print();
608        println!("adding (abce, 2)");
609        assert_eq!(root.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
610        root.print();
611        println!("adding (abgh, 3)");
612        assert_eq!(root.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
613        root.print();
614
615        let mut root2: TrieNode<u8> = TrieNode::root();
616
617        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
618        assert_eq!(root2.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
619
620        println!("before remove");
621        root.print();
622        assert_eq!(root.remove(&Vec::from(&b"abce"[..])), RemoveResult::Ok);
623        println!("after remove");
624        root.print();
625
626        println!("expected");
627        root2.print();
628        assert_eq!(root, root2);
629
630        assert_eq!(root.remove(&Vec::from(&b"abgh"[..])), RemoveResult::Ok);
631        println!("after remove");
632        root.print();
633        println!("expected");
634        let mut root3: TrieNode<u8> = TrieNode::root();
635        assert_eq!(root3.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
636        root3.print();
637        assert_eq!(root, root3);
638    }
639
640    #[test]
641    fn insert_remove_through_regex() {
642        let mut root: TrieNode<u8> = TrieNode::root();
643        println!("creating root:");
644        root.print();
645
646        println!("adding (www./.*/.com, 1)");
647        assert_eq!(
648            root.insert(Vec::from(&b"www./.*/.com"[..]), 1),
649            InsertResult::Ok
650        );
651        root.print();
652        println!("adding (www.doc./.*/.com, 2)");
653        assert_eq!(
654            root.insert(Vec::from(&b"www.doc./.*/.com"[..]), 2),
655            InsertResult::Ok
656        );
657        root.print();
658        assert_eq!(
659            root.domain_lookup(b"www.sozu.com".as_ref(), false),
660            Some(&(b"www./.*/.com".to_vec(), 1))
661        );
662        assert_eq!(
663            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
664            Some(&(b"www.doc./.*/.com".to_vec(), 2))
665        );
666
667        assert_eq!(
668            root.domain_remove(&b"www./.*/.com".to_vec()),
669            RemoveResult::Ok
670        );
671        root.print();
672        assert_eq!(root.domain_lookup(b"www.sozu.com".as_ref(), false), None);
673        assert_eq!(
674            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
675            Some(&(b"www.doc./.*/.com".to_vec(), 2))
676        );
677    }
678
679    /// Segment regexes must match the entire segment, not just a prefix.
680    /// Without `\A...\z` anchoring the previous behaviour matched any
681    /// segment whose prefix satisfied the pattern, silently widening the
682    /// routing surface. This regression test exercises the exact-match
683    /// invariant that anchoring guarantees.
684    #[test]
685    fn segment_regex_rejects_partial_matches() {
686        let mut root: TrieNode<u8> = TrieNode::root();
687        // The regex segment `cdn[0-9]+` must match `cdn1`, `cdn99`, etc.
688        // exactly — never `cdn1xxx` or `xxxcdn1` as a prefix/suffix.
689        assert_eq!(
690            root.insert(Vec::from(&b"/cdn[0-9]+/.example.com"[..]), 7),
691            InsertResult::Ok
692        );
693
694        // Exact-match cases still resolve.
695        assert_eq!(
696            root.domain_lookup(b"cdn1.example.com".as_ref(), false),
697            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
698        );
699        assert_eq!(
700            root.domain_lookup(b"cdn123.example.com".as_ref(), false),
701            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
702        );
703
704        // Trailing characters past the digit run must fail. Pre-anchoring
705        // the trie would have matched `cdn1xxx` because `cdn[0-9]+` ate
706        // the `cdn1` prefix; with `\A...\z` the segment is rejected.
707        assert_eq!(
708            root.domain_lookup(b"cdn1xxx.example.com".as_ref(), false),
709            None
710        );
711        // Leading characters likewise must fail.
712        assert_eq!(
713            root.domain_lookup(b"xxxcdn1.example.com".as_ref(), false),
714            None
715        );
716        // Non-digit middle bytes break the digit run and the segment.
717        assert_eq!(
718            root.domain_lookup(b"cdnabc.example.com".as_ref(), false),
719            None
720        );
721    }
722
723    #[test]
724    fn add_child_to_leaf() {
725        let mut root1: TrieNode<u8> = TrieNode::root();
726
727        println!("creating root1:");
728        root1.print();
729        println!("adding (abcd, 1)");
730        assert_eq!(root1.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
731        root1.print();
732        println!("adding (abce, 2)");
733        assert_eq!(root1.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
734        root1.print();
735        println!("adding (abc, 3)");
736        assert_eq!(root1.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
737
738        println!("root1:");
739        root1.print();
740
741        let mut root2: TrieNode<u8> = TrieNode::root();
742
743        assert_eq!(root2.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
744        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
745        assert_eq!(root2.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
746
747        println!("root2:");
748        root2.print();
749        assert_eq!(root2.remove(&Vec::from(&b"abc"[..])), RemoveResult::Ok);
750
751        println!("root2 after,remove:");
752        root2.print();
753        let mut expected: TrieNode<u8> = TrieNode::root();
754
755        assert_eq!(
756            expected.insert(Vec::from(&b"abcd"[..]), 1),
757            InsertResult::Ok
758        );
759        assert_eq!(
760            expected.insert(Vec::from(&b"abce"[..]), 2),
761            InsertResult::Ok
762        );
763
764        println!("root2 after insert");
765        root2.print();
766        println!("expected");
767        expected.print();
768        assert_eq!(root2, expected);
769    }
770
771    #[test]
772    fn domains() {
773        let mut root: TrieNode<u8> = TrieNode::root();
774        root.print();
775
776        assert_eq!(
777            root.domain_insert(Vec::from(&b"www.example.com"[..]), 1),
778            InsertResult::Ok
779        );
780        root.print();
781        assert_eq!(
782            root.domain_insert(Vec::from(&b"test.example.com"[..]), 2),
783            InsertResult::Ok
784        );
785        root.print();
786        assert_eq!(
787            root.domain_insert(Vec::from(&b"*.alldomains.org"[..]), 3),
788            InsertResult::Ok
789        );
790        root.print();
791        assert_eq!(
792            root.domain_insert(Vec::from(&b"alldomains.org"[..]), 4),
793            InsertResult::Ok
794        );
795        assert_eq!(
796            root.domain_insert(Vec::from(&b"pouet.alldomains.org"[..]), 5),
797            InsertResult::Ok
798        );
799        root.print();
800        assert_eq!(
801            root.domain_insert(Vec::from(&b"hello.com"[..]), 6),
802            InsertResult::Ok
803        );
804        assert_eq!(
805            root.domain_insert(Vec::from(&b"*.hello.com"[..]), 7),
806            InsertResult::Ok
807        );
808        assert_eq!(
809            root.domain_insert(Vec::from(&b"images./cdn[0-9]+/.hello.com"[..]), 8),
810            InsertResult::Ok
811        );
812        root.print();
813        assert_eq!(
814            root.domain_insert(Vec::from(&b"/test[0-9]+/.www.hello.com"[..]), 9),
815            InsertResult::Ok
816        );
817        root.print();
818
819        assert_eq!(root.domain_lookup(&b"example.com"[..], true), None);
820        assert_eq!(
821            root.domain_lookup(&b"blah.test.example.com"[..], true),
822            None
823        );
824        assert_eq!(
825            root.domain_lookup(&b"www.example.com"[..], true),
826            Some(&(b"www.example.com"[..].to_vec(), 1))
827        );
828        assert_eq!(
829            root.domain_lookup(&b"alldomains.org"[..], true),
830            Some(&(b"alldomains.org"[..].to_vec(), 4))
831        );
832        assert_eq!(
833            root.domain_lookup(&b"test.hello.com"[..], true),
834            Some(&(b"*.hello.com"[..].to_vec(), 7))
835        );
836        assert_eq!(
837            root.domain_lookup(&b"images.cdn10.hello.com"[..], true),
838            Some(&(b"images./cdn[0-9]+/.hello.com"[..].to_vec(), 8))
839        );
840        assert_eq!(
841            root.domain_lookup(&b"test42.www.hello.com"[..], true),
842            Some(&(b"/test[0-9]+/.www.hello.com"[..].to_vec(), 9))
843        );
844        assert_eq!(
845            root.domain_lookup(&b"test.alldomains.org"[..], true),
846            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
847        );
848        assert_eq!(
849            root.domain_lookup(&b"hello.alldomains.org"[..], true),
850            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
851        );
852        assert_eq!(
853            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
854            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
855        );
856        assert_eq!(
857            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
858            None
859        );
860
861        assert_eq!(
862            root.domain_remove(&Vec::from(&b"alldomains.org"[..])),
863            RemoveResult::Ok
864        );
865        println!("after remove");
866        root.print();
867        assert_eq!(root.domain_lookup(&b"alldomains.org"[..], true), None);
868        assert_eq!(
869            root.domain_lookup(&b"test.alldomains.org"[..], true),
870            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
871        );
872        assert_eq!(
873            root.domain_lookup(&b"hello.alldomains.org"[..], true),
874            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
875        );
876        assert_eq!(
877            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
878            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
879        );
880        assert_eq!(
881            root.domain_lookup(&b"test.hello.com"[..], true),
882            Some(&(b"*.hello.com"[..].to_vec(), 7))
883        );
884        assert_eq!(
885            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
886            None
887        );
888    }
889
890    #[test]
891    fn wildcard() {
892        let mut root: TrieNode<u8> = TrieNode::root();
893        root.print();
894        root.domain_insert("*.clever-cloud.com".as_bytes().to_vec(), 2u8);
895        root.domain_insert("services.clever-cloud.com".as_bytes().to_vec(), 0u8);
896        root.domain_insert("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8);
897
898        let res = root.domain_lookup(b"test.services.clever-cloud.com", true);
899        println!("query result: {res:?}");
900
901        assert_eq!(
902            root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
903            Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
904        );
905    }
906
907    fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
908        let mut root: TrieNode<u32> = TrieNode::root();
909
910        for (k, v) in h.iter() {
911            if k.is_empty() {
912                continue;
913            }
914
915            if k.as_bytes()[0] == b'.' {
916                continue;
917            }
918
919            if k.contains('/') {
920                continue;
921            }
922
923            if k == "*" {
924                continue;
925            }
926
927            //println!("inserting key: '{}', value: '{}'", k, v);
928            //assert_eq!(root.domain_insert(Vec::from(k.as_bytes()), *v), InsertResult::Ok);
929            assert_eq!(
930                root.insert(Vec::from(k.as_bytes()), *v),
931                InsertResult::Ok,
932                "could not insert ({k}, {v})"
933            );
934            //root.print();
935        }
936
937        //root.print();
938        for (k, v) in h.iter() {
939            if k.is_empty() {
940                continue;
941            }
942
943            if k.as_bytes()[0] == b'.' {
944                continue;
945            }
946
947            if k.contains('/') {
948                continue;
949            }
950
951            if k == "*" {
952                continue;
953            }
954
955            //match root.domain_lookup(k.as_bytes()) {
956            match root.lookup(k.as_bytes(), false) {
957                None => {
958                    println!("did not find key '{k}'");
959                    return false;
960                }
961                Some(&(ref k1, v1)) => {
962                    if k.as_bytes() != &k1[..] || *v != v1 {
963                        println!(
964                            "request ({}, {}), got ({}, {})",
965                            k,
966                            v,
967                            str::from_utf8(&k1[..]).unwrap(),
968                            v1
969                        );
970                        return false;
971                    }
972                }
973            }
974        }
975
976        true
977    }
978
979    /* FIXME: randomly fails
980    quickcheck! {
981      fn qc_insert(h: std::collections::HashMap<String, u32>) -> bool {
982        hm_insert(h)
983      }
984    }
985    */
986
987    #[test]
988    fn insert_disappearing_tree() {
989        let h: std::collections::HashMap<String, u32> = [
990            (String::from("\n\u{3}"), 0),
991            (String::from("\n\u{0}"), 1),
992            (String::from("\n"), 2),
993        ]
994        .iter()
995        .cloned()
996        .collect();
997        assert!(hm_insert(h));
998    }
999
    // Pins the size of `TrieNode<u32>` to 136 through the `assert_size!`
    // macro, so a change to the node's fields that alters its size makes this
    // test fail and has to be acknowledged by updating the constant.
    // NOTE(review): `assert_size!` is defined outside this chunk; presumably
    // it compares against `std::mem::size_of` in bytes — confirm, and check
    // whether the expected value is target-dependent.
    #[test]
    fn size() {
        assert_size!(TrieNode<u32>, 136);
    }
1004}