Skip to main content

wildcard_trie/
lib.rs

1//! # Radix Trie for URL Routing
2//!
3//! A space-efficient trie that compresses paths by storing common prefixes in single nodes.
4//! This prevents DoS attacks from extremely long segmented paths while maintaining fast lookups.
5//!
6//! This crate supports:
7//! - Wildcard Support: Routes ending in `/*` match any sub-path  
//! - Fast Lookups: `O(path_length)` instead of `O(number_of_routes)`
9//! - DoS Resistant: Long paths don't create excessive nodes
10//! - Compressed representation: `/api/v1/users` and `/api/v1/posts` share the `/api/v1/` prefix
11//!
12//! ## Example
13//! ```rust
14//! use wildcard_trie::Trie;
15//!
16//! let mut trie = Trie::new();
17//! trie.insert("/api/*", "api_handler");           // Wildcard
18//! trie.insert("/api/users", "users_handler");     // Exact (takes precedence)
19//!
20//! assert_eq!(trie.get("/api/users"), Some(&"users_handler"));  // Exact match
21//! assert_eq!(trie.get("/api/posts"), Some(&"api_handler"));    // Wildcard match
22//! ```
23
24#[cfg(feature = "debug")]
25mod prettyprint;
26
27use std::collections::HashMap;
28
/// Suffix that marks a route as a wildcard (one that matches any sub-path).
///
/// `Trie::parse_path` looks for this suffix and strips it, so the tree itself only ever sees
/// the path that precedes it together with an "is wildcard" flag.
const WILDCARD_SUFFIX: &str = "/*";
31
/// A node in the radix trie that stores a compressed path prefix.
///
/// The path a node stands for is the concatenation of the prefixes met on the way down from
/// the root, so the values held here belong to that full path.
#[derive(Debug, Clone)]
struct RadixNode<T> {
    /// The path fragment stored at this node (e.g., "/api/v1")
    prefix: String,
    /// Child nodes, indexed by the first character of their prefix
    children: HashMap<char, RadixNode<T>>,
    /// Value for exact path matches at this node
    exact_value: Option<T>,
    /// Value for wildcard matches (/*) at this node
    wildcard_value: Option<T>,
}

impl<T> RadixNode<T> {
    /// Creates a new node with the given prefix, no children and no values.
    fn new(prefix: String) -> Self {
        Self {
            prefix,
            children: HashMap::new(),
            exact_value: None,
            wildcard_value: None,
        }
    }

    /// Inserts `value` at `path`, where `path` is relative to this node's parent.
    ///
    /// `is_wildcard` selects the slot: `true` stores the value as the wildcard handler for
    /// everything below `path`, `false` stores it as the exact handler for `path` itself.
    /// An existing value in the same slot is overwritten.
    fn insert(&mut self, path: &str, value: T, is_wildcard: bool) {
        if path.is_empty() {
            self.store_value(value, is_wildcard);
            return;
        }

        let common_length = self.common_prefix_len(path);

        // Split this node if the path diverges from our prefix
        if common_length < self.prefix.len() {
            self.split_at(common_length);
        }

        // Continue to child or store at current node.
        // `common_length` is a byte offset on a char boundary of `path`, so slicing is safe.
        if common_length < path.len() {
            self.insert_in_child(&path[common_length..], value, is_wildcard);
        } else {
            self.store_value(value, is_wildcard);
        }
    }

    /// Retrieves a value for the given path, considering wildcards.
    ///
    /// An exact value wins over a wildcard; among wildcards the deepest applicable one wins.
    fn get(&self, path: &str) -> Option<&T> {
        self.get_with_fallback(path, None)
    }

    /// Removes and returns the value stored at `path` in the slot chosen by `is_wildcard`.
    ///
    /// Returns `None` when the path is not present or the slot is empty.
    fn remove(&mut self, path: &str, is_wildcard: bool) -> Option<T> {
        if path.is_empty() {
            return self.take_value(is_wildcard);
        }

        let common_length = self.common_prefix_len(path);
        if common_length != self.prefix.len() {
            return None; // Path doesn't exist
        }

        let remaining_path = &path[common_length..];
        if remaining_path.is_empty() {
            self.take_value(is_wildcard)
        } else {
            self.remove_from_child(remaining_path, is_wildcard)
        }
    }

    /// Stores a value in the appropriate slot (exact or wildcard)
    fn store_value(&mut self, value: T, is_wildcard: bool) {
        let slot = if is_wildcard {
            &mut self.wildcard_value
        } else {
            &mut self.exact_value
        };
        *slot = Some(value);
    }

    /// Takes a value from the appropriate slot (exact or wildcard)
    fn take_value(&mut self, is_wildcard: bool) -> Option<T> {
        if is_wildcard {
            self.wildcard_value.take()
        } else {
            self.exact_value.take()
        }
    }

    /// Returns `true` when this node holds no value and has no children.
    fn is_vacant(&self) -> bool {
        self.exact_value.is_none() && self.wildcard_value.is_none() && self.children.is_empty()
    }

    /// Returns the length in bytes of the longest prefix shared by this node's prefix and
    /// `path`.
    ///
    /// The comparison is done per `char`, but the result is a byte offset because every caller
    /// uses it to slice or split a `str`. Counting characters instead would land inside a
    /// multi-byte character and panic or mis-compare.
    fn common_prefix_len(&self, path: &str) -> usize {
        self.prefix
            .chars()
            .zip(path.chars())
            .take_while(|(ours, theirs)| ours == theirs)
            .map(|(ours, _)| ours.len_utf8())
            .sum()
    }

    /// Retrieves value with wildcard fallback support.
    ///
    /// `fallback` is the closest wildcard value found on the way down, if any. It is returned
    /// whenever the search below this node comes up empty.
    fn get_with_fallback<'a>(&'a self, path: &str, fallback: Option<&'a T>) -> Option<&'a T> {
        if path.is_empty() {
            return self
                .exact_value
                .as_ref()
                .or(self.wildcard_value.as_ref())
                .or(fallback);
        }

        let common_length = self.common_prefix_len(path);
        if common_length != self.prefix.len() {
            // Partial match - our own wildcard does not apply, only the inherited one
            return fallback;
        }

        let remaining_path = &path[common_length..];
        if remaining_path.is_empty() {
            // The path ends exactly at this node
            return self
                .exact_value
                .as_ref()
                .or(self.wildcard_value.as_ref())
                .or(fallback);
        }

        // Our wildcard covers sub-paths only: what follows must begin a new segment, otherwise
        // "/api/*" would also claim "/apix".
        let current_fallback = if remaining_path.starts_with('/') {
            self.wildcard_value.as_ref().or(fallback)
        } else {
            fallback
        };

        self.search_in_child(remaining_path, current_fallback)
    }

    /// Inserts value in the child selected by the first character of `remaining_path`,
    /// creating that child when it does not exist yet.
    fn insert_in_child(&mut self, remaining_path: &str, value: T, is_wildcard: bool) {
        match remaining_path.chars().next() {
            Some(first_char) => self
                .children
                .entry(first_char)
                .or_insert_with(|| RadixNode::new(remaining_path.to_string()))
                .insert(remaining_path, value, is_wildcard),
            // Nothing left to descend into: the value belongs to this node
            None => self.store_value(value, is_wildcard),
        }
    }

    /// Searches for a value in child nodes, returning `fallback` when no child can match.
    fn search_in_child<'a>(
        &'a self,
        remaining_path: &str,
        fallback: Option<&'a T>,
    ) -> Option<&'a T> {
        let child = remaining_path
            .chars()
            .next()
            .and_then(|first_char| self.children.get(&first_char));

        match child {
            Some(child) => child.get_with_fallback(remaining_path, fallback),
            None => fallback,
        }
    }

    /// Removes value from the appropriate child node.
    ///
    /// A child left without values and without children is detached, so removed routes do not
    /// keep their nodes allocated.
    fn remove_from_child(&mut self, remaining_path: &str, is_wildcard: bool) -> Option<T> {
        let first_char = remaining_path.chars().next()?;
        let child = self.children.get_mut(&first_char)?;

        let removed = child.remove(remaining_path, is_wildcard);
        if removed.is_some() && child.is_vacant() {
            self.children.remove(&first_char);
        }
        removed
    }

    /// Splits this node at the given byte offset to accommodate path divergence.
    ///
    /// Everything after `split_position` moves into a new child together with this node's
    /// children and values; this node keeps only the leading part of the prefix. Does nothing
    /// when `split_position` is at or past the end of the prefix, or not on a char boundary.
    fn split_at(&mut self, split_position: usize) {
        let first_char = match self
            .prefix
            .get(split_position..)
            .and_then(|suffix| suffix.chars().next())
        {
            Some(first_char) => first_char,
            None => return,
        };

        // Move the tail of the prefix and all of our data into the new child
        let tail = RadixNode {
            prefix: self.prefix.split_off(split_position),
            children: std::mem::take(&mut self.children),
            exact_value: self.exact_value.take(),
            wildcard_value: self.wildcard_value.take(),
        };

        self.children.insert(first_char, tail);
    }
}
216
217/// A radix trie for efficient path-based routing with wildcard support
218#[derive(Debug)]
219pub struct Trie<T>(RadixNode<T>);
220
221impl<T> Default for Trie<T> {
222    fn default() -> Self {
223        Self(RadixNode::new(String::new()))
224    }
225}
226
227impl<T> Trie<T> {
228    /// Creates a new empty trie
229    pub fn new() -> Self {
230        Self::default()
231    }
232
233    /// Inserts a value at the given path
234    ///
235    /// Paths ending with `/*` are treated as wildcard routes that match any sub-path.
236    ///
237    /// # Examples
238    /// ```rust
239    /// # use wildcard_trie::Trie;
240    /// let mut trie = Trie::new();
241    /// trie.insert("/api/users", "users_handler");
242    /// trie.insert("/api/*", "api_fallback");
243    /// ```
244    pub fn insert(&mut self, path: &str, value: T) {
245        let (clean_path, is_wildcard) = Self::parse_path(path);
246        self.0.insert(clean_path, value, is_wildcard);
247    }
248
249    /// Retrieves a value for the given path, with exact > wildcard precedence.
250    ///
251    /// # Examples
252    /// ```rust
253    /// # use wildcard_trie::Trie;
254    /// # let mut trie = Trie::new();
255    /// # trie.insert("/api/users", "users_handler");
256    /// # trie.insert("/api/*", "api_fallback");
257    /// assert_eq!(trie.get("/api/users"), Some(&"users_handler"));  // Exact
258    /// assert_eq!(trie.get("/api/posts"), Some(&"api_fallback"));   // Wildcard
259    /// ```
260    pub fn get<'a>(&'a self, path: &str) -> Option<&'a T> {
261        self.0.get(path)
262    }
263
264    /// Removes a value at the given path, returning it if it existed
265    pub fn remove(&mut self, path: &str) -> Option<T> {
266        let (clean_path, is_wildcard) = Self::parse_path(path);
267        self.0.remove(clean_path, is_wildcard)
268    }
269
270    /// Parses a path to determine if it's a wildcard and extract the clean path
271    fn parse_path(path: &str) -> (&str, bool) {
272        if let Some(prefix) = path.strip_suffix(WILDCARD_SUFFIX) {
273            (prefix, true)
274        } else {
275            (path, false)
276        }
277    }
278
279    /// Checks if the trie is empty
280    fn is_empty(&self) -> bool {
281        self.0.children.is_empty()
282            && self.0.exact_value.is_none()
283            && self.0.wildcard_value.is_none()
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trie from `(route, handler)` pairs, inserted in the order given.
    fn trie_with(routes: &[(&'static str, &'static str)]) -> Trie<&'static str> {
        let mut trie = Trie::new();
        for &(route, handler) in routes {
            trie.insert(route, handler);
        }
        trie
    }

    /// Checks that each `(path, expected handler)` pair resolves as stated.
    fn assert_lookups(trie: &Trie<&'static str>, cases: &[(&str, Option<&'static str>)]) {
        for &(path, expected) in cases {
            assert_eq!(trie.get(path).copied(), expected, "lookup of {:?}", path);
        }
    }

    /// Exact routes resolve to their own handler and unknown paths resolve to nothing.
    #[test]
    fn test_exact_path_matching() {
        let trie = trie_with(&[
            ("/api/users", "users_handler"),
            ("/api/posts", "posts_handler"),
        ]);

        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("users_handler")),
                ("/api/posts", Some("posts_handler")),
                ("/api/other", None),
            ],
        );
    }

    /// A wildcard route covers its sub-paths at any depth, but not unrelated paths.
    #[test]
    fn test_wildcard_matching() {
        let trie = trie_with(&[("/api/*", "api_handler")]);

        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("api_handler")),
                ("/api/posts/123", Some("api_handler")),
                ("/auth/login", None),
            ],
        );
    }

    /// When both an exact and a wildcard route apply, the exact one is returned.
    #[test]
    fn test_exact_takes_precedence_over_wildcard() {
        let trie = trie_with(&[
            ("/api/*", "wildcard_handler"),
            ("/api/users", "exact_handler"),
        ]);

        assert_lookups(
            &trie,
            &[
                ("/api/users", Some("exact_handler")),
                ("/api/posts", Some("wildcard_handler")),
            ],
        );
    }

    /// Routes sharing long prefixes stay individually retrievable.
    #[test]
    fn test_path_compression() {
        let trie = trie_with(&[
            ("/api/v1/users", "v1_users"),
            ("/api/v1/posts", "v1_posts"),
            ("/api/v2/users", "v2_users"),
        ]);

        assert_lookups(
            &trie,
            &[
                ("/api/v1/users", Some("v1_users")),
                ("/api/v1/posts", Some("v1_posts")),
                ("/api/v2/users", Some("v2_users")),
            ],
        );
    }

    /// Removing an exact route hands back its value and makes the path unresolvable.
    #[test]
    fn test_removal() {
        let mut trie = trie_with(&[("/api/users", "handler")]);

        assert_lookups(&trie, &[("/api/users", Some("handler"))]);
        assert_eq!(trie.remove("/api/users"), Some("handler"));
        assert_lookups(&trie, &[("/api/users", None)]);
    }

    /// Removing a wildcard route hands back its value and stops it from matching.
    #[test]
    fn test_wildcard_removal() {
        let mut trie = trie_with(&[("/api/*", "handler")]);

        assert_lookups(&trie, &[("/api/users", Some("handler"))]);
        assert_eq!(trie.remove("/api/*"), Some("handler"));
        assert_lookups(&trie, &[("/api/users", None)]);
    }

    /// The bare root path can be registered and looked up.
    #[test]
    fn test_root_path() {
        let trie = trie_with(&[("/", "root_handler")]);
        assert_lookups(&trie, &[("/", Some("root_handler"))]);
    }

    /// A wildcard at the root also answers for the root path itself.
    #[test]
    fn test_root_wildcard() {
        let trie = trie_with(&[("/*", "root_handler")]);
        assert_lookups(&trie, &[("/", Some("root_handler"))]);
    }

    /// The empty string is a valid route.
    #[test]
    fn test_empty_path() {
        let trie = trie_with(&[("", "empty_handler")]);
        assert_lookups(&trie, &[("", Some("empty_handler"))]);
    }

    /// Keys without slashes that share a long prefix are kept apart correctly.
    #[test]
    fn test_common_prefix() {
        let trie = trie_with(&[
            ("long_prefix_one", "one"),
            ("long_prefix_two", "two"),
            ("long_prefix_three", "three"),
        ]);

        assert_lookups(
            &trie,
            &[
                ("long_prefix_one", Some("one")),
                ("long_prefix_two", Some("two")),
                ("long_prefix_three", Some("three")),
            ],
        );
    }
}