Skip to main content

sozu_lib/router/
pattern_trie.rs

1use std::{collections::HashMap, fmt::Debug, iter, str};
2
3use regex::bytes::Regex;
4
/// A route key: the raw bytes of a domain, possibly containing a `*`
/// wildcard label or `/regex/` labels.
pub type Key = Vec<u8>;
/// A stored entry: the complete key it was inserted under, paired with its value.
pub type KeyValue<K, V> = (K, V);

/// Outcome of inserting a key into a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult {
    /// The value was stored under a key that was not present before.
    Ok,
    /// The key was already present; the trie was left unchanged.
    Existing,
    /// The key could not be inserted (for example, it was empty or malformed).
    Failed,
}

/// Outcome of removing a key from a [`TrieNode`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult {
    /// A value was stored under the key and has been removed.
    Ok,
    /// No value was stored under the key; the trie was left unchanged.
    NotFound,
}
20
/// Returns the index of the last `.` byte in `input`, or `None` if there is none.
///
/// The trie consumes domains right to left, so the last dot marks the
/// boundary between the rightmost label and the rest of the key.
fn find_last_dot(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'.')
}
25
/// Returns the index of the last `/` byte in `input`, or `None` if there is none.
///
/// Used to locate the opening delimiter of a trailing `/regex/` label once
/// its closing `/` has been stripped off.
fn find_last_slash(input: &[u8]) -> Option<usize> {
    input.iter().rposition(|&byte| byte == b'/')
}
30
/// Implementation of a trie tree structure.
/// In Sozu this is used to store and look up domains recursively.
/// Each node represents a "level domain" (one dot-separated label); keys are
/// consumed from their rightmost label towards their leftmost one.
/// A leaf node (leftmost label) can be a wildcard, a regex pattern or a plain string.
/// Leaves also store a value associated with the complete domain.
/// For Sozu it is a list of (PathRule, MethodRule, ClusterId). See the Router structure.
#[derive(Debug, Default)]
pub struct TrieNode<V> {
    /// Value for a key that ends exactly at this node, stored together with
    /// the complete key it was inserted under.
    key_value: Option<KeyValue<Key, V>>,
    /// Value for a `*` label at this position, stored together with the
    /// complete key it was inserted under.
    wildcard: Option<KeyValue<Key, V>>,
    /// Literal subtrees keyed by the next label. Inner labels keep their
    /// leading `.` separator (e.g. `.com`); a leftmost label without a dot
    /// is stored as-is.
    children: HashMap<Key, TrieNode<V>>,
    /// Subtrees for `/regex/` labels. Patterns are anchored with `\A...\z`,
    /// and during lookup they are tried in insertion order after the literal
    /// children.
    regexps: Vec<(Regex, TrieNode<V>)>,
}
44
/// One step of a trie traversal where a non-literal segment matched.
///
/// `Wildcard` carries the actual segment bytes consumed by a `*` wildcard
/// (so a router that wants to capture them can splice them into a rewrite
/// template). `Regexp` carries both the matched bytes and the regex itself
/// so the caller can re-run `Regex::captures` to pull explicit groups.
///
/// Lifetime `'a` borrows from the key being looked up (the matched bytes);
/// `'b` borrows from the trie (the stored regex).
#[derive(Debug)]
pub enum TrieSubMatch<'a, 'b> {
    /// Label matched by the wildcard slot, without its `.` separator.
    Wildcard(&'a [u8]),
    /// Label matched by a regex subtree (without its `.` separator), plus
    /// the anchored regex that matched it.
    Regexp(&'a [u8], &'b Regex),
}

/// Ordered list of non-literal trie segments visited during a successful
/// `lookup_with_path` traversal, in the order they were matched (rightmost
/// label first). Routers feed the entries into rewrite templates (`$HOST[n]`)
/// so frontend rewrites can reach into the matched segments. Empty when only
/// literal segments matched.
pub type TrieMatches<'a, 'b> = Vec<TrieSubMatch<'a, 'b>>;
62
63impl<V: PartialEq> std::cmp::PartialEq for TrieNode<V> {
64    fn eq(&self, other: &Self) -> bool {
65        self.key_value == other.key_value
66            && self.wildcard == other.wildcard
67            && self.children == other.children
68            && self.regexps.len() == other.regexps.len()
69            && self
70                .regexps
71                .iter()
72                .zip(other.regexps.iter())
73                .fold(true, |b, (left, right)| {
74                    b && left.0.as_str() == right.0.as_str() && left.1 == right.1
75                })
76    }
77}
78
79impl<V: Debug + Clone> TrieNode<V> {
80    pub fn new(key: Key, value: V) -> TrieNode<V> {
81        TrieNode {
82            key_value: Some((key, value)),
83            wildcard: None,
84            children: HashMap::new(),
85            regexps: Vec::new(),
86        }
87    }
88
89    pub fn wildcard(key: Key, value: V) -> TrieNode<V> {
90        TrieNode {
91            key_value: None,
92            wildcard: Some((key, value)),
93            children: HashMap::new(),
94            regexps: Vec::new(),
95        }
96    }
97
98    pub fn root() -> TrieNode<V> {
99        TrieNode {
100            key_value: None,
101            wildcard: None,
102            children: HashMap::new(),
103            regexps: Vec::new(),
104        }
105    }
106
107    pub fn is_empty(&self) -> bool {
108        self.key_value.is_none()
109            && self.wildcard.is_none()
110            && self.regexps.is_empty()
111            && self.children.is_empty()
112    }
113
114    pub fn insert(&mut self, key: Key, value: V) -> InsertResult {
115        //println!("insert: key == {}", std::str::from_utf8(&key).unwrap());
116        if key.is_empty() {
117            return InsertResult::Failed;
118        }
119        if key[..] == b"."[..] {
120            return InsertResult::Failed;
121        }
122
123        #[cfg(debug_assertions)]
124        let before = self.count_values();
125
126        let insert_result = self.insert_recursive(&key, &key, value);
127        assert_ne!(insert_result, InsertResult::Failed);
128
129        // Post: the value count grows by exactly one on a fresh insert
130        // and is unchanged when the key already existed. `Failed` is
131        // ruled out above, so these two are the only reachable cases.
132        #[cfg(debug_assertions)]
133        {
134            let after = self.count_values();
135            match insert_result {
136                InsertResult::Ok => debug_assert_eq!(
137                    after,
138                    before + 1,
139                    "a fresh insert must add exactly one value to the trie",
140                ),
141                InsertResult::Existing => debug_assert_eq!(
142                    after, before,
143                    "an Existing insert must not change the trie value count",
144                ),
145                InsertResult::Failed => {
146                    unreachable!("insert_recursive returned Failed after key validation")
147                }
148            }
149            self.check_invariants();
150        }
151
152        insert_result
153    }
154
155    pub fn insert_recursive(&mut self, partial_key: &[u8], key: &Key, value: V) -> InsertResult {
156        //println!("insert_rec: key == {}", std::str::from_utf8(partial_key).unwrap());
157        assert_ne!(partial_key, &b""[..]);
158        // `partial_key` is always a suffix of the full `key` being
159        // inserted — the recursion only ever shrinks the head, never
160        // rewrites the tail.
161        debug_assert!(
162            partial_key.len() <= key.len(),
163            "insert recursion must consume the key, never grow past it",
164        );
165
166        if partial_key[partial_key.len() - 1] == b'/' {
167            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
168
169            if let Some(pos) = pos {
170                if pos > 0 && partial_key[pos - 1] != b'.' {
171                    return InsertResult::Failed;
172                }
173
174                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
175                    let anchored_s = format!("\\A{s}\\z");
176                    debug_assert!(
177                        anchored_s.starts_with("\\A") && anchored_s.ends_with("\\z"),
178                        "segment regex must be fully anchored so it matches the whole segment only",
179                    );
180                    for t in self.regexps.iter_mut() {
181                        if t.0.as_str() == anchored_s {
182                            // `pos > 0`: there is a `.`-separated prefix
183                            // before this regex segment; recurse on it
184                            // (dropping the leading `.` via `pos - 1`).
185                            // `pos == 0`: the regex is the leftmost/only
186                            // segment, so its subtree is already a
187                            // value-bearing leaf (the create-path below
188                            // built it via `TrieNode::new`); re-inserting
189                            // the same host is `Existing`. The pre-fix code
190                            // did `partial_key[..pos - 1]` unconditionally,
191                            // underflowing to `usize::MAX` and panicking on
192                            // `pos == 0` (same latent bug as `lookup_mut`).
193                            if pos > 0 {
194                                return t.1.insert_recursive(&partial_key[..pos - 1], key, value);
195                            } else {
196                                return InsertResult::Existing;
197                            }
198                        }
199                    }
200
201                    // Anchor segment regexes so they only match the entire
202                    // segment, not partial overlaps. Without `\A...\z`, a
203                    // pattern like `cdn[0-9]+` would match `cdn123xxx`,
204                    // which silently widens the routing surface.
205                    let anchored = format!("\\A{s}\\z");
206                    if let Ok(r) = Regex::new(&anchored) {
207                        if pos > 0 {
208                            let mut node = TrieNode::root();
209                            let pos = pos - 1;
210
211                            let res = node.insert_recursive(&partial_key[..pos], key, value);
212
213                            if res == InsertResult::Ok {
214                                self.regexps.push((r, node));
215                            }
216
217                            return res;
218                        } else {
219                            let node = TrieNode::new(key.to_vec(), value);
220                            self.regexps.push((r, node));
221                            return InsertResult::Ok;
222                        }
223                    }
224                }
225            }
226
227            return InsertResult::Failed;
228        }
229
230        let pos = find_last_dot(partial_key);
231        match pos {
232            None => {
233                if self.children.contains_key(partial_key) {
234                    InsertResult::Existing
235                } else if partial_key == &b"*"[..] {
236                    if self.wildcard.is_some() {
237                        InsertResult::Existing
238                    } else {
239                        self.wildcard = Some((key.to_vec(), value));
240                        InsertResult::Ok
241                    }
242                } else {
243                    let node = TrieNode::new(key.to_vec(), value);
244                    self.children.insert(partial_key.to_vec(), node);
245                    InsertResult::Ok
246                }
247            }
248            Some(pos) => {
249                // The dot at `pos` is kept on the child key (suffix) and
250                // stripped from the recursive prefix; the two slices
251                // partition `partial_key` exactly.
252                debug_assert_eq!(
253                    partial_key[..pos].len() + partial_key[pos..].len(),
254                    partial_key.len(),
255                    "dot-split must partition partial_key without losing bytes",
256                );
257                debug_assert_eq!(
258                    partial_key[pos], b'.',
259                    "find_last_dot must point at a '.' byte",
260                );
261                if let Some(child) = self.children.get_mut(&partial_key[pos..]) {
262                    return child.insert_recursive(&partial_key[..pos], key, value);
263                }
264
265                let mut node = TrieNode::root();
266                let res = node.insert_recursive(&partial_key[..pos], key, value);
267
268                if res == InsertResult::Ok {
269                    self.children.insert(partial_key[pos..].to_vec(), node);
270                }
271
272                res
273            }
274        }
275    }
276
277    pub fn remove(&mut self, key: &Key) -> RemoveResult {
278        #[cfg(debug_assertions)]
279        let before = self.count_values();
280
281        let remove_result = self.remove_recursive(key);
282
283        // Post: a successful remove drops exactly one value; a NotFound
284        // is a no-op on the value count. The structural invariants then
285        // guarantee no emptied subtree was stranded by the prune.
286        #[cfg(debug_assertions)]
287        {
288            let after = self.count_values();
289            match remove_result {
290                RemoveResult::Ok => debug_assert_eq!(
291                    after + 1,
292                    before,
293                    "a successful remove must drop exactly one value from the trie",
294                ),
295                RemoveResult::NotFound => debug_assert_eq!(
296                    after, before,
297                    "a NotFound remove must not change the trie value count",
298                ),
299            }
300            self.check_invariants();
301        }
302
303        remove_result
304    }
305
306    pub fn remove_recursive(&mut self, partial_key: &[u8]) -> RemoveResult {
307        //println!("remove: key == {}", std::str::from_utf8(partial_key).unwrap());
308
309        if partial_key.is_empty() {
310            if self.key_value.is_some() {
311                self.key_value = None;
312                return RemoveResult::Ok;
313            } else {
314                return RemoveResult::NotFound;
315            }
316        }
317
318        if partial_key == &b"*"[..] {
319            if self.wildcard.is_some() {
320                self.wildcard = None;
321                return RemoveResult::Ok;
322            } else {
323                return RemoveResult::NotFound;
324            }
325        }
326
327        if partial_key[partial_key.len() - 1] == b'/' {
328            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
329
330            if let Some(pos) = pos {
331                if pos > 0 && partial_key[pos - 1] != b'.' {
332                    return RemoveResult::NotFound;
333                }
334
335                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
336                    let anchored_s = format!("\\A{s}\\z");
337                    if pos > 0 {
338                        let mut remove_result = RemoveResult::NotFound;
339                        for t in self.regexps.iter_mut() {
340                            if t.0.as_str() == anchored_s
341                                && t.1.remove_recursive(&partial_key[..pos - 1]) == RemoveResult::Ok
342                            {
343                                remove_result = RemoveResult::Ok;
344                            }
345                        }
346                        return remove_result;
347                    } else {
348                        let len = self.regexps.len();
349                        self.regexps.retain(|(r, _)| r.as_str() != anchored_s);
350                        if len > self.regexps.len() {
351                            return RemoveResult::Ok;
352                        }
353                    }
354                }
355            }
356
357            return RemoveResult::NotFound;
358        }
359
360        let pos = find_last_dot(partial_key);
361        let (prefix, suffix) = match pos {
362            None => (&b""[..], partial_key),
363            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
364        };
365        //println!("remove: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
366        debug_assert_eq!(
367            prefix.len() + suffix.len(),
368            partial_key.len(),
369            "dot-split must partition the key without losing or duplicating bytes",
370        );
371
372        match self.children.get_mut(suffix) {
373            Some(child) => match child.remove_recursive(prefix) {
374                RemoveResult::NotFound => RemoveResult::NotFound,
375                RemoveResult::Ok => {
376                    // An emptied child subtree MUST be pruned here so the
377                    // parent never strands a node with no value. After the
378                    // prune the suffix key is gone from `children`.
379                    if child.is_empty() {
380                        self.children.remove(suffix);
381                        debug_assert!(
382                            !self.children.contains_key(suffix),
383                            "an emptied child subtree must be removed from the parent",
384                        );
385                    } else {
386                        // `count_values` is debug-only; gate the whole
387                        // assert so the call does not have to compile in
388                        // release (HARD RULE 2 — E0425 guard).
389                        #[cfg(debug_assertions)]
390                        debug_assert!(
391                            child.count_values() > 0,
392                            "a retained child subtree must still hold at least one value",
393                        );
394                    }
395                    RemoveResult::Ok
396                }
397            },
398            None => RemoveResult::NotFound,
399        }
400    }
401
402    /// Look up `partial_key` and additionally collect the non-literal segments
403    /// that matched along the way (`TrieMatches`).
404    ///
405    /// Equivalent to `lookup` for callers that don't need the captures, but
406    /// frontends with `$HOST[n]` rewrite templates need the matched segments
407    /// to fill the placeholders. The accumulator is passed in by value so
408    /// callers can pre-size it (`Vec::with_capacity`) and we own the path
409    /// returned alongside the value.
410    pub fn lookup_with_path<'a, 'b>(
411        &'b self,
412        partial_key: &'a [u8],
413        accept_wildcard: bool,
414        mut trace: TrieMatches<'a, 'b>,
415    ) -> Option<(&'b KeyValue<Key, V>, TrieMatches<'a, 'b>)> {
416        if partial_key.is_empty() {
417            return self.key_value.as_ref().map(|kv| (kv, trace));
418        }
419
420        let pos = find_last_dot(partial_key);
421        let (prefix, suffix) = match pos {
422            None => (&b""[..], partial_key),
423            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
424        };
425        // The dot-split partitions the key exactly: prefix ++ suffix is
426        // the whole input, and a dotted split puts the `.` at the head
427        // of the suffix (this is the byte the wildcard/regex arms strip).
428        debug_assert_eq!(
429            prefix.len() + suffix.len(),
430            partial_key.len(),
431            "dot-split must partition the key without losing or duplicating bytes",
432        );
433        debug_assert!(
434            pos.is_none() || suffix.first() == Some(&b'.'),
435            "a dotted split must place the separator at the head of the suffix",
436        );
437
438        match self.children.get(suffix) {
439            Some(child) => child.lookup_with_path(prefix, accept_wildcard, trace),
440            None => {
441                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
442                    let segment = if !suffix.is_empty() && suffix[0] == b'.' {
443                        &suffix[1..]
444                    } else {
445                        suffix
446                    };
447                    trace.push(TrieSubMatch::Wildcard(segment));
448                    self.wildcard.as_ref().map(|kv| (kv, trace))
449                } else {
450                    for (regexp, child) in self.regexps.iter() {
451                        let segment = if !suffix.is_empty() && suffix[0] == b'.' {
452                            &suffix[1..]
453                        } else {
454                            suffix
455                        };
456                        if regexp.is_match(segment) {
457                            let mut next = trace;
458                            next.push(TrieSubMatch::Regexp(segment, regexp));
459                            return child.lookup_with_path(prefix, accept_wildcard, next);
460                        }
461                    }
462                    None
463                }
464            }
465        }
466    }
467
468    pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
469        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
470
471        if partial_key.is_empty() {
472            return self.key_value.as_ref();
473        }
474
475        let pos = find_last_dot(partial_key);
476        let (prefix, suffix) = match pos {
477            None => (&b""[..], partial_key),
478            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
479        };
480        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
481        debug_assert_eq!(
482            prefix.len() + suffix.len(),
483            partial_key.len(),
484            "dot-split must partition the key without losing or duplicating bytes",
485        );
486        debug_assert!(
487            !suffix.is_empty(),
488            "the suffix the trie matches children against must be non-empty",
489        );
490
491        match self.children.get(suffix) {
492            Some(child) => child.lookup(prefix, accept_wildcard),
493            None => {
494                //println!("no child found, testing wildcard and regexps");
495
496                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
497                    //println!("no dot, wildcard applies");
498                    self.wildcard.as_ref()
499                } else {
500                    //println!("there's still a subdomain, wildcard does not apply");
501
502                    for (regexp, child) in self.regexps.iter() {
503                        let suffix = if suffix[0] == b'.' {
504                            &suffix[1..]
505                        } else {
506                            suffix
507                        };
508                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
509
510                        if regexp.is_match(suffix) {
511                            //println!("matched");
512                            return child.lookup(prefix, accept_wildcard);
513                        }
514                    }
515
516                    None
517                }
518            }
519        }
520    }
521
522    pub fn lookup_mut(
523        &mut self,
524        partial_key: &[u8],
525        accept_wildcard: bool,
526    ) -> Option<&mut KeyValue<Key, V>> {
527        //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap());
528
529        if partial_key.is_empty() {
530            return self.key_value.as_mut();
531        }
532
533        if partial_key == &b"*"[..] {
534            return self.wildcard.as_mut();
535        }
536
537        if partial_key[partial_key.len() - 1] == b'/' {
538            let pos = find_last_slash(&partial_key[..partial_key.len() - 1]);
539
540            if let Some(pos) = pos {
541                if pos > 0 && partial_key[pos - 1] != b'.' {
542                    return None;
543                }
544
545                if let Ok(s) = str::from_utf8(&partial_key[pos + 1..partial_key.len() - 1]) {
546                    let anchored_s = format!("\\A{s}\\z");
547                    for t in self.regexps.iter_mut() {
548                        if t.0.as_str() == anchored_s {
549                            // `pos == 0` means the regex is the leftmost
550                            // segment and its subtree is a value-bearing
551                            // leaf, reachable via the empty-prefix recursion
552                            // (`lookup_mut(b"")` returns `key_value`). The
553                            // pre-fix `partial_key[..pos - 1]` underflowed to
554                            // `usize::MAX` and panicked on `pos == 0` — the
555                            // same latent bug as the insert dedup loop. Drop
556                            // the leading `.` only when there is one.
557                            let rest = if pos > 0 {
558                                &partial_key[..pos - 1]
559                            } else {
560                                &partial_key[..0]
561                            };
562                            return t.1.lookup_mut(rest, accept_wildcard);
563                        }
564                    }
565                }
566            }
567
568            return None;
569        }
570
571        let pos = find_last_dot(partial_key);
572        let (prefix, suffix) = match pos {
573            None => (&b""[..], partial_key),
574            Some(pos) => (&partial_key[..pos], &partial_key[pos..]),
575        };
576        //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap());
577        debug_assert_eq!(
578            prefix.len() + suffix.len(),
579            partial_key.len(),
580            "dot-split must partition the key without losing or duplicating bytes",
581        );
582        debug_assert!(
583            !suffix.is_empty(),
584            "the suffix the trie matches children against must be non-empty",
585        );
586
587        match self.children.get_mut(suffix) {
588            Some(child) => child.lookup_mut(prefix, accept_wildcard),
589            None => {
590                //println!("no child found, testing wildcard and regexps");
591
592                if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard {
593                    //println!("no dot, wildcard applies");
594                    self.wildcard.as_mut()
595                } else {
596                    //println!("there's still a subdomain, wildcard does not apply");
597
598                    for &mut (ref regexp, ref mut child) in self.regexps.iter_mut() {
599                        let suffix = if suffix[0] == b'.' {
600                            &suffix[1..]
601                        } else {
602                            suffix
603                        };
604                        //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap());
605
606                        if regexp.is_match(suffix) {
607                            //println!("matched");
608                            return child.lookup_mut(prefix, accept_wildcard);
609                        }
610                    }
611
612                    None
613                }
614            }
615        }
616    }
617
618    pub fn print(&self) {
619        self.print_recursive(b"", 0)
620    }
621
622    pub fn print_recursive(&self, partial_key: &[u8], indent: u8) {
623        let raw_prefix: Vec<u8> = iter::repeat_n(b' ', 2 * indent as usize).collect();
624        let prefix = str::from_utf8(&raw_prefix).unwrap();
625
626        print!("{}{}: ", prefix, str::from_utf8(partial_key).unwrap());
627        if let Some((ref key, ref value)) = self.key_value {
628            print!("({}, {:?}) | ", str::from_utf8(key).unwrap(), value);
629        } else {
630            print!("None | ");
631        }
632
633        if let Some((key, value)) = &self.wildcard {
634            println!("({}, {:?})", str::from_utf8(key).unwrap(), value);
635        } else {
636            println!("None");
637        }
638
639        for (child_key, child) in self.children.iter() {
640            child.print_recursive(child_key, indent + 1);
641        }
642
643        for (regexp, child) in self.regexps.iter() {
644            //print!("{}{}:", prefix, regexp.as_str());
645            child.print_recursive(regexp.as_str().as_bytes(), indent + 1);
646        }
647    }
648
649    /// Visit every stored value in the trie (the literal `key_value` and
650    /// the leftmost `wildcard` slot of every node, plus all
651    /// regex-subtree leaves) and invoke `f` on each. Used by the router
652    /// to walk all routes for cross-cutting refreshes (e.g. listener-
653    /// default HSTS reflow) without rebuilding the trie.
654    pub fn for_each_value_mut<F: FnMut(&mut V)>(&mut self, f: &mut F) {
655        if let Some((_, ref mut value)) = self.key_value {
656            f(value);
657        }
658        if let Some((_, ref mut value)) = self.wildcard {
659            f(value);
660        }
661        for child in self.children.values_mut() {
662            child.for_each_value_mut(f);
663        }
664        for (_, child) in self.regexps.iter_mut() {
665            child.for_each_value_mut(f);
666        }
667    }
668
669    /// Count every value slot reachable from this node: the literal
670    /// `key_value`, the leftmost `wildcard`, plus all values stored in
671    /// child subtrees and regex subtrees. Used only by the
672    /// `#[cfg(debug_assertions)]` invariant checks as the leaf-count
673    /// accounting (`inserts − removes`); never called in release.
674    #[cfg(debug_assertions)]
675    fn count_values(&self) -> usize {
676        let local = self.key_value.is_some() as usize + self.wildcard.is_some() as usize;
677        let in_children: usize = self.children.values().map(TrieNode::count_values).sum();
678        let in_regexps: usize = self.regexps.iter().map(|(_, c)| c.count_values()).sum();
679        local + in_children + in_regexps
680    }
681
682    /// Full structural invariant sweep for the trie, asserted as a
683    /// run-to-completion postcondition at the end of every mutating
684    /// public operation. Encodes the cross-field invariants that the
685    /// recursive insert/remove logic must preserve:
686    ///
687    /// - **No stranded interior node**: every non-root node reachable
688    ///   through `children` / `regexps` must hold a value somewhere in
689    ///   its subtree (`!is_empty()` and `count_values() > 0`). A node
690    ///   that holds neither a value nor any descendant value is a leak —
691    ///   `remove_recursive` is supposed to prune it via `is_empty()`.
692    /// - **Unique regex segments**: the anchored pattern strings stored
693    ///   in `regexps` are unique within a node (insert dedups by
694    ///   `as_str()` before pushing a new subtree).
695    /// - **Child-key invariant**: no child is keyed by the empty slice.
696    ///
697    /// `debug_assertions`-only; compiled out of release builds.
698    #[cfg(debug_assertions)]
699    fn check_invariants(&self) {
700        // Regex segment patterns are unique per node.
701        for i in 0..self.regexps.len() {
702            for j in (i + 1)..self.regexps.len() {
703                debug_assert_ne!(
704                    self.regexps[i].0.as_str(),
705                    self.regexps[j].0.as_str(),
706                    "trie node must not hold two subtrees for the same regex segment",
707                );
708            }
709        }
710
711        for (child_key, child) in self.children.iter() {
712            debug_assert!(
713                !child_key.is_empty(),
714                "trie child must not be keyed by the empty segment",
715            );
716            // A child subtree that has been fully emptied must have been
717            // pruned by remove_recursive; reaching it here means a
718            // subtree was stranded.
719            debug_assert!(
720                !child.is_empty(),
721                "trie must not strand an empty child subtree (remove must prune)",
722            );
723            debug_assert!(
724                child.count_values() > 0,
725                "trie child subtree must lead to at least one value",
726            );
727            child.check_invariants();
728        }
729
730        for (_, child) in self.regexps.iter() {
731            debug_assert!(
732                !child.is_empty(),
733                "trie must not strand an empty regex subtree (remove must prune)",
734            );
735            debug_assert!(
736                child.count_values() > 0,
737                "trie regex subtree must lead to at least one value",
738            );
739            child.check_invariants();
740        }
741    }
742
743    pub fn domain_insert(&mut self, key: Key, value: V) -> InsertResult {
744        self.insert(key, value)
745    }
746
747    pub fn domain_remove(&mut self, key: &Key) -> RemoveResult {
748        self.remove(key)
749    }
750
751    pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue<Key, V>> {
752        self.lookup(key, accept_wildcard)
753    }
754
755    pub fn domain_lookup_mut(
756        &mut self,
757        key: &[u8],
758        accept_wildcard: bool,
759    ) -> Option<&mut KeyValue<Key, V>> {
760        self.lookup_mut(key, accept_wildcard)
761    }
762
763    pub fn size(&self) -> usize {
764        ::std::mem::size_of::<TrieNode<V>>()
765            + ::std::mem::size_of::<Option<KeyValue<Key, V>>>() * 2
766            + self
767                .children
768                .iter()
769                .fold(0, |acc, c| acc + c.0.len() + c.1.size())
770    }
771
772    pub fn to_hashmap(&self) -> HashMap<Key, V> {
773        let mut h = HashMap::new();
774
775        self.to_hashmap_recursive(&mut h);
776
777        h
778    }
779
780    pub fn to_hashmap_recursive(&self, h: &mut HashMap<Key, V>) {
781        if let Some((key, value)) = &self.key_value {
782            h.insert(key.clone(), value.clone());
783        }
784
785        if let Some((key, value)) = &self.wildcard {
786            h.insert(key.clone(), value.clone());
787        }
788
789        for child in self.children.values() {
790            child.to_hashmap_recursive(h);
791        }
792    }
793}
794
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys that `hm_insert` deliberately skips: the empty key, keys that
    /// start with a `.` label separator, keys containing the `/` regex
    /// delimiter, and the bare wildcard `*`. Keeping the predicate in one
    /// place guarantees the insert pass and the lookup pass agree.
    fn is_skipped_key(k: &str) -> bool {
        k.is_empty() || k.starts_with('.') || k.contains('/') || k == "*"
    }

    #[test]
    fn insert() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        assert_eq!(
            root.domain_insert(Vec::from(&b"abcd"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"abce"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"abgh"[..]), 3),
            InsertResult::Ok
        );
        root.print();

        assert_eq!(
            root.domain_lookup(&b"abce"[..], true),
            Some(&(b"abce"[..].to_vec(), 2))
        );
    }

    #[test]
    fn remove() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (abcd, 1)");
        assert_eq!(root.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root.print();
        println!("adding (abce, 2)");
        assert_eq!(root.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
        root.print();
        println!("adding (abgh, 3)");
        assert_eq!(root.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);
        root.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abgh"[..]), 3), InsertResult::Ok);

        println!("before remove");
        root.print();
        assert_eq!(root.remove(&Vec::from(&b"abce"[..])), RemoveResult::Ok);
        println!("after remove");
        root.print();

        println!("expected");
        root2.print();
        assert_eq!(root, root2);

        assert_eq!(root.remove(&Vec::from(&b"abgh"[..])), RemoveResult::Ok);
        println!("after remove");
        root.print();
        println!("expected");
        let mut root3: TrieNode<u8> = TrieNode::root();
        assert_eq!(root3.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root3.print();
        assert_eq!(root, root3);
    }

    #[test]
    fn insert_remove_through_regex() {
        let mut root: TrieNode<u8> = TrieNode::root();
        println!("creating root:");
        root.print();

        println!("adding (www./.*/.com, 1)");
        assert_eq!(
            root.insert(Vec::from(&b"www./.*/.com"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        println!("adding (www.doc./.*/.com, 2)");
        assert_eq!(
            root.insert(Vec::from(&b"www.doc./.*/.com"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_lookup(b"www.sozu.com".as_ref(), false),
            Some(&(b"www./.*/.com".to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );

        assert_eq!(
            root.domain_remove(&b"www./.*/.com".to_vec()),
            RemoveResult::Ok
        );
        root.print();
        assert_eq!(root.domain_lookup(b"www.sozu.com".as_ref(), false), None);
        assert_eq!(
            root.domain_lookup(b"www.doc.sozu.com".as_ref(), false),
            Some(&(b"www.doc./.*/.com".to_vec(), 2))
        );
    }

    /// Segment regexes must match the entire segment, not just a prefix.
    /// Without `\A...\z` anchoring, the trie matched any segment whose
    /// prefix satisfied the pattern, which silently widened the routing
    /// surface. This regression test pins the exact-match guarantee that
    /// anchoring provides.
    #[test]
    fn segment_regex_rejects_partial_matches() {
        let mut root: TrieNode<u8> = TrieNode::root();
        // The regex segment `cdn[0-9]+` must match `cdn1`, `cdn99`, etc.
        // exactly — never `cdn1xxx` or `xxxcdn1` as a prefix/suffix.
        assert_eq!(
            root.insert(Vec::from(&b"/cdn[0-9]+/.example.com"[..]), 7),
            InsertResult::Ok
        );

        // Exact-match cases still resolve.
        assert_eq!(
            root.domain_lookup(b"cdn1.example.com".as_ref(), false),
            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(b"cdn123.example.com".as_ref(), false),
            Some(&(b"/cdn[0-9]+/.example.com".to_vec(), 7))
        );

        // Trailing characters past the digit run must fail. Before
        // anchoring, `cdn1xxx` matched because `cdn[0-9]+` consumed the
        // `cdn1` prefix; with `\A...\z` the segment is rejected.
        assert_eq!(
            root.domain_lookup(b"cdn1xxx.example.com".as_ref(), false),
            None
        );
        // Leading characters likewise must fail.
        assert_eq!(
            root.domain_lookup(b"xxxcdn1.example.com".as_ref(), false),
            None
        );
        // Non-digit middle bytes break the digit run and the segment.
        assert_eq!(
            root.domain_lookup(b"cdnabc.example.com".as_ref(), false),
            None
        );
    }

    #[test]
    fn add_child_to_leaf() {
        let mut root1: TrieNode<u8> = TrieNode::root();

        println!("creating root1:");
        root1.print();
        println!("adding (abcd, 1)");
        assert_eq!(root1.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        root1.print();
        println!("adding (abce, 2)");
        assert_eq!(root1.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);
        root1.print();
        println!("adding (abc, 3)");
        assert_eq!(root1.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);

        println!("root1:");
        root1.print();

        let mut root2: TrieNode<u8> = TrieNode::root();

        assert_eq!(root2.insert(Vec::from(&b"abc"[..]), 3), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abcd"[..]), 1), InsertResult::Ok);
        assert_eq!(root2.insert(Vec::from(&b"abce"[..]), 2), InsertResult::Ok);

        println!("root2:");
        root2.print();
        assert_eq!(root2.remove(&Vec::from(&b"abc"[..])), RemoveResult::Ok);

        println!("root2 after,remove:");
        root2.print();
        let mut expected: TrieNode<u8> = TrieNode::root();

        assert_eq!(
            expected.insert(Vec::from(&b"abcd"[..]), 1),
            InsertResult::Ok
        );
        assert_eq!(
            expected.insert(Vec::from(&b"abce"[..]), 2),
            InsertResult::Ok
        );

        println!("root2 after insert");
        root2.print();
        println!("expected");
        expected.print();
        assert_eq!(root2, expected);
    }

    #[test]
    fn domains() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();

        assert_eq!(
            root.domain_insert(Vec::from(&b"www.example.com"[..]), 1),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"test.example.com"[..]), 2),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"*.alldomains.org"[..]), 3),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"alldomains.org"[..]), 4),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"pouet.alldomains.org"[..]), 5),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"hello.com"[..]), 6),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"*.hello.com"[..]), 7),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(Vec::from(&b"images./cdn[0-9]+/.hello.com"[..]), 8),
            InsertResult::Ok
        );
        root.print();
        assert_eq!(
            root.domain_insert(Vec::from(&b"/test[0-9]+/.www.hello.com"[..]), 9),
            InsertResult::Ok
        );
        root.print();

        assert_eq!(root.domain_lookup(&b"example.com"[..], true), None);
        assert_eq!(
            root.domain_lookup(&b"blah.test.example.com"[..], true),
            None
        );
        assert_eq!(
            root.domain_lookup(&b"www.example.com"[..], true),
            Some(&(b"www.example.com"[..].to_vec(), 1))
        );
        assert_eq!(
            root.domain_lookup(&b"alldomains.org"[..], true),
            Some(&(b"alldomains.org"[..].to_vec(), 4))
        );
        assert_eq!(
            root.domain_lookup(&b"test.hello.com"[..], true),
            Some(&(b"*.hello.com"[..].to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(&b"images.cdn10.hello.com"[..], true),
            Some(&(b"images./cdn[0-9]+/.hello.com"[..].to_vec(), 8))
        );
        assert_eq!(
            root.domain_lookup(&b"test42.www.hello.com"[..], true),
            Some(&(b"/test[0-9]+/.www.hello.com"[..].to_vec(), 9))
        );
        assert_eq!(
            root.domain_lookup(&b"test.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"hello.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
        );
        assert_eq!(
            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
            None
        );

        assert_eq!(
            root.domain_remove(&Vec::from(&b"alldomains.org"[..])),
            RemoveResult::Ok
        );
        println!("after remove");
        root.print();
        assert_eq!(root.domain_lookup(&b"alldomains.org"[..], true), None);
        assert_eq!(
            root.domain_lookup(&b"test.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"hello.alldomains.org"[..], true),
            Some(&(b"*.alldomains.org"[..].to_vec(), 3))
        );
        assert_eq!(
            root.domain_lookup(&b"pouet.alldomains.org"[..], true),
            Some(&(b"pouet.alldomains.org"[..].to_vec(), 5))
        );
        assert_eq!(
            root.domain_lookup(&b"test.hello.com"[..], true),
            Some(&(b"*.hello.com"[..].to_vec(), 7))
        );
        assert_eq!(
            root.domain_lookup(&b"blah.test.alldomains.org"[..], true),
            None
        );
    }

    /// Nested wildcards: a host under `services.clever-cloud.com` must
    /// resolve to `*.services.clever-cloud.com`, not to the broader
    /// `*.clever-cloud.com`.
    #[test]
    fn wildcard() {
        let mut root: TrieNode<u8> = TrieNode::root();
        root.print();
        assert_eq!(
            root.domain_insert("*.clever-cloud.com".as_bytes().to_vec(), 2u8),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert("services.clever-cloud.com".as_bytes().to_vec(), 0u8),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8),
            InsertResult::Ok
        );

        assert_eq!(
            root.domain_lookup(b"test.services.clever-cloud.com", true),
            Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
        );
        assert_eq!(
            root.domain_lookup(b"pgstudio.services.clever-cloud.com", true),
            Some(&("*.services.clever-cloud.com".as_bytes().to_vec(), 1u8))
        );
    }

    /// `to_hashmap` returns every literal and wildcard entry, keyed by the
    /// full pattern it was inserted under.
    #[test]
    fn to_hashmap_collects_literal_and_wildcard_entries() {
        let mut root: TrieNode<u8> = TrieNode::root();
        assert_eq!(
            root.domain_insert(b"www.example.com".to_vec(), 1),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(b"*.example.com".to_vec(), 2),
            InsertResult::Ok
        );
        assert_eq!(
            root.domain_insert(b"example.org".to_vec(), 3),
            InsertResult::Ok
        );

        let expected: std::collections::HashMap<Key, u8> = std::collections::HashMap::from([
            (b"www.example.com".to_vec(), 1),
            (b"*.example.com".to_vec(), 2),
            (b"example.org".to_vec(), 3),
        ]);
        assert_eq!(root.to_hashmap(), expected);
    }

    /// Inserts every non-skipped key of `h` into a fresh trie, then checks
    /// that each one resolves (with `accept_wildcard = false`) to exactly
    /// its own key and value.
    ///
    /// Panics if an insert does not return `InsertResult::Ok`. Returns
    /// `false`, after printing the mismatch, on the first key that is not
    /// found or resolves to a different entry.
    fn hm_insert(h: std::collections::HashMap<String, u32>) -> bool {
        let mut root: TrieNode<u32> = TrieNode::root();

        for (k, v) in h.iter().filter(|(k, _)| !is_skipped_key(k.as_str())) {
            assert_eq!(
                root.insert(Vec::from(k.as_bytes()), *v),
                InsertResult::Ok,
                "could not insert ({k}, {v})"
            );
        }

        for (k, v) in h.iter().filter(|(k, _)| !is_skipped_key(k.as_str())) {
            match root.lookup(k.as_bytes(), false) {
                None => {
                    println!("did not find key '{k}'");
                    return false;
                }
                Some((found_key, found_value)) => {
                    if found_key.as_slice() != k.as_bytes() || found_value != v {
                        println!(
                            "request ({}, {}), got ({}, {})",
                            k,
                            v,
                            str::from_utf8(found_key).unwrap(),
                            found_value
                        );
                        return false;
                    }
                }
            }
        }

        true
    }

    // FIXME: this property test randomly fails, so it stays disabled until
    // the failing inputs are understood.
    /*
    quickcheck! {
      fn qc_insert(h: std::collections::HashMap<String, u32>) -> bool {
        hm_insert(h)
      }
    }
    */

    #[test]
    fn insert_disappearing_tree() {
        let h: std::collections::HashMap<String, u32> = std::collections::HashMap::from([
            (String::from("\n\u{3}"), 0),
            (String::from("\n\u{0}"), 1),
            (String::from("\n"), 2),
        ]);
        assert!(hm_insert(h));
    }

    #[test]
    fn size() {
        assert_size!(TrieNode<u32>, 136);
    }

    /// Regression: a hostname whose LEFTMOST segment is a regex
    /// (`/test[0-9]/.example.com`) used to underflow `pos - 1` to
    /// `usize::MAX`. That panicked on a second insert of the same host (in
    /// the dedup loop) and on any `lookup_mut`. Both paths now special-case
    /// `pos == 0`: the regex is the leftmost/only segment, so it is a
    /// value-bearing leaf. This test asserts the panic is gone and the entry
    /// resolves correctly.
    #[test]
    fn leftmost_regex_segment_reinsert_and_lookup_mut_do_not_panic() {
        let mut root: TrieNode<u8> = TrieNode::root();

        assert_eq!(
            root.insert(Vec::from(&b"/test[0-9]/.example.com"[..]), 7),
            InsertResult::Ok
        );
        // Second insert of the SAME leftmost-regex host: dedup loop with
        // pos == 0. Previously panicked; must now report Existing.
        assert_eq!(
            root.insert(Vec::from(&b"/test[0-9]/.example.com"[..]), 8),
            InsertResult::Existing
        );

        // lookup_mut on the existing leftmost-regex host previously
        // panicked at `partial_key[..pos - 1]`. It must now resolve the
        // leaf, with the value from the first insert (Existing did not
        // overwrite it).
        let resolved = root.domain_lookup_mut(b"test4.example.com", false);
        assert_eq!(
            resolved.map(|(_, v)| *v),
            Some(7),
            "leftmost-regex host must resolve via lookup_mut without panicking",
        );

        // The immutable lookup path (never buggy) agrees.
        assert_eq!(
            root.domain_lookup(b"test4.example.com", false),
            Some(&(b"/test[0-9]/.example.com"[..].to_vec(), 7))
        );

        // Removing the last rule clears the host.
        assert_eq!(
            root.domain_remove(&Vec::from(&b"/test[0-9]/.example.com"[..])),
            RemoveResult::Ok
        );
        assert_eq!(root.domain_lookup(b"test4.example.com", false), None);
    }
}
1278}