spacegate_kernel/service/http_route/
match_hostname.rs

1//! # Match Hostnames
2//!
3//! ## Priority
4//!
5//! | Priority |   Rule             |  Example              |
6//! |:--------:|:-------------------|:----------------------|
7//! |  0       | Exact Host         |  example.com          |
8//! |  1       | Partial wildcard   |  *.example.com        |
//! |  2       | Wildcard           |  *                    |
10//!
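//! Hostname rules are stored in a tree keyed by their labels in reverse order (TLD first);
//! a `*` label becomes a wildcard child. For example, the rules `next.example.com`,
//! `*.example.com` and `*.com` produce this tree: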
12//!
13//! ```text
14//! com
15//! |
16//! +- example
17//! |  |
18//! |  +- next
19//! |  \- *
//! |
//! \- *
//! ```
25
26use std::{
27    collections::BTreeMap,
28    fmt::{self},
29    net::{Ipv4Addr, Ipv6Addr},
30};
31
/// Maps a host string (optionally followed by `:port`) to a value of type `T`.
///
/// Rules are kept by kind: IPv4 literals, bracketed IPv6 literals, hostname patterns (which
/// may contain `*` labels) and a single `*` fallback that is returned when nothing else matches.
#[derive(Debug, Clone)]
pub struct HostnameTree<T> {
    /// Rules whose host is an IPv4 address literal.
    pub(crate) ipv4: BTreeMap<Ipv4Addr, T>,
    /// Rules whose host is an IPv6 address literal (written as `[addr]` in `set`).
    pub(crate) ipv6: BTreeMap<Ipv6Addr, T>,
    /// Hostname rules, stored by reversed labels.
    pub(crate) host: HostnameMatcherNode<T>,
    /// Value registered for the bare `*` rule.
    pub(crate) fallback: Option<T>,
}
39
40impl<T> Default for HostnameTree<T> {
41    fn default() -> Self {
42        Self {
43            ipv4: BTreeMap::new(),
44            ipv6: BTreeMap::new(),
45            host: HostnameMatcherNode::new(),
46            fallback: None,
47        }
48    }
49}
50
51impl<T> HostnameTree<T> {
52    pub fn new() -> Self {
53        Self::default()
54    }
55    pub fn iter(&self) -> HostnameTreeIter<T> {
56        HostnameTreeIter {
57            ipv4: self.ipv4.values(),
58            ipv6: self.ipv6.values(),
59            host: self.host.iter(),
60            fallback: self.fallback.iter(),
61        }
62    }
63    pub fn iter_mut(&mut self) -> HostnameTreeIterMut<T> {
64        HostnameTreeIterMut {
65            ipv4: self.ipv4.values_mut(),
66            ipv6: self.ipv6.values_mut(),
67            host: self.host.iter_mut(),
68            fallback: self.fallback.iter_mut(),
69        }
70    }
71    #[allow(clippy::indexing_slicing)]
72    pub fn get(&self, host: &str) -> Option<&T> {
73        // trim port
74        let data = if host.starts_with('[') {
75            let bracket_end = host.find(']')?;
76            let ipv6 = host[1..bracket_end].parse::<Ipv6Addr>().ok()?;
77            self.ipv6.get(&ipv6)
78        } else {
79            let host = host.rsplit_once(':').map(|(host, _)| host).unwrap_or(host);
80            if let Ok(ipv4) = host.parse::<Ipv4Addr>() {
81                self.ipv4.get(&ipv4)
82            } else {
83                self.host.get(host)
84            }
85        };
86        data.or(self.fallback.as_ref())
87    }
88    #[allow(clippy::indexing_slicing)]
89    pub fn get_mut(&mut self, host: &str) -> Option<&mut T> {
90        // trim port
91        let data = if host.starts_with('[') {
92            let bracket_end = host.find(']')?;
93            let ipv6 = host[1..bracket_end].parse::<Ipv6Addr>().ok()?;
94            self.ipv6.get_mut(&ipv6)
95        } else {
96            let host = host.rsplit_once(':').map(|(host, _)| host).unwrap_or(host);
97            if let Ok(ipv4) = host.parse::<Ipv4Addr>() {
98                self.ipv4.get_mut(&ipv4)
99            } else {
100                self.host.get_mut(host)
101            }
102        };
103        data.or(self.fallback.as_mut())
104    }
105    pub fn set(&mut self, host: &str, data: T) {
106        if host == "*" {
107            self.fallback = Some(data);
108            return;
109        }
110        if host.starts_with('[') {
111            if let Some(ipv6) = host.strip_prefix('[').and_then(|host| host.strip_suffix(']')) {
112                if let Ok(ipv6) = ipv6.parse::<Ipv6Addr>() {
113                    self.ipv6.insert(ipv6, data);
114                }
115            }
116        } else if let Ok(ipv4) = host.parse::<Ipv4Addr>() {
117            self.ipv4.insert(ipv4, data);
118        } else {
119            self.host.set(host, data);
120        }
121    }
122}
123
124#[derive(Debug)]
125pub struct HostnameTreeIter<'a, T> {
126    ipv4: std::collections::btree_map::Values<'a, Ipv4Addr, T>,
127    ipv6: std::collections::btree_map::Values<'a, Ipv6Addr, T>,
128    host: HostnameMatcherNodeIter<'a, T>,
129    fallback: std::option::Iter<'a, T>,
130}
131
132impl<'a, T: 'a> Iterator for HostnameTreeIter<'a, T> {
133    type Item = &'a T;
134
135    fn next(&mut self) -> Option<Self::Item> {
136        if let Some(data) = self.ipv4.next() {
137            return Some(data);
138        }
139        if let Some(data) = self.ipv6.next() {
140            return Some(data);
141        }
142        if let Some(data) = self.host.next() {
143            return Some(data);
144        }
145        self.fallback.next()
146    }
147}
148
149#[derive(Debug)]
150pub struct HostnameTreeIterMut<'a, T> {
151    ipv4: std::collections::btree_map::ValuesMut<'a, Ipv4Addr, T>,
152    ipv6: std::collections::btree_map::ValuesMut<'a, Ipv6Addr, T>,
153    host: HostnameMatcherNodeIterMut<'a, T>,
154    fallback: std::option::IterMut<'a, T>,
155}
156
157impl<'a, T: 'a> Iterator for HostnameTreeIterMut<'a, T> {
158    type Item = &'a mut T;
159
160    fn next(&mut self) -> Option<Self::Item> {
161        if let Some(data) = self.ipv4.next() {
162            return Some(data);
163        }
164        if let Some(data) = self.ipv6.next() {
165            return Some(data);
166        }
167        if let Some(data) = self.host.next() {
168            return Some(data);
169        }
170        self.fallback.next()
171    }
172}
173
/// A node of the hostname tree.
///
/// Each level holds one hostname label, starting from the TLD end. A plain map per level is
/// enough (no radix tree is needed) because hostnames are short.
#[derive(Clone)]
pub struct HostnameMatcherNode<T> {
    /// Value of the rule that ends exactly at this node.
    data: Option<T>,
    /// Children for exact labels, keyed by the lowercased label.
    children: BTreeMap<String, HostnameMatcherNode<T>>,
    /// Child for a `*` label.
    else_node: Option<Box<HostnameMatcherNode<T>>>,
}

/// Pre-order iterator over the values stored in a [`HostnameMatcherNode`] subtree.
///
/// Each node yields its own value first, then the subtrees of its exact-label children in key
/// order, then its `*` subtree.
#[derive(Debug)]
pub struct HostnameMatcherNodeIter<'a, T: 'a> {
    /// Nodes still to visit; the next one is on top.
    stack: Vec<&'a HostnameMatcherNode<T>>,
}

impl<'a, T: 'a> Iterator for HostnameMatcherNodeIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(node) = self.stack.pop() {
            // Push in reverse visiting order: `*` subtree at the bottom, first child on top.
            if let Some(wildcard) = node.else_node.as_deref() {
                self.stack.push(wildcard);
            }
            self.stack.extend(node.children.values().rev());
            if let Some(data) = node.data.as_ref() {
                return Some(data);
            }
        }
        None
    }
}

/// Mutable pre-order iterator over a [`HostnameMatcherNode`] subtree.
///
/// Visits values in the same order as [`HostnameMatcherNodeIter`].
#[derive(Debug)]
pub struct HostnameMatcherNodeIterMut<'a, T: 'a> {
    /// Nodes still to visit; the next one is on top.
    stack: Vec<&'a mut HostnameMatcherNode<T>>,
}

impl<'a, T: 'a> Iterator for HostnameMatcherNodeIterMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(node) = self.stack.pop() {
            // Split the node into disjoint field borrows so its value can be handed out while
            // its subtrees are queued.
            let HostnameMatcherNode { data, children, else_node } = node;
            if let Some(wildcard) = else_node.as_deref_mut() {
                self.stack.push(wildcard);
            }
            self.stack.extend(children.values_mut().rev());
            if let Some(value) = data.as_mut() {
                return Some(value);
            }
        }
        None
    }
}

impl<T> Default for HostnameMatcherNode<T> {
    fn default() -> Self {
        Self {
            data: None,
            children: BTreeMap::new(),
            else_node: None,
        }
    }
}

impl<T: fmt::Debug> fmt::Debug for HostnameMatcherNode<T> {
    /// Prints the node's own value as `_`, each exact-label child under its label, and the
    /// wildcard child as `*`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut df = f.debug_struct("HostnameMatcherNode");
        if let Some(data) = &self.data {
            df.field("_", data);
        }
        for (key, node) in &self.children {
            df.field(key, node);
        }
        if let Some(node) = &self.else_node {
            df.field("*", node);
        }
        df.finish()
    }
}

impl<T> HostnameMatcherNode<T> {
    /// Create an empty node with no value and no children.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert `data` under the path of labels yielded by `host`, TLD first.
    ///
    /// A `*` label is stored in the wildcard child; every other label is lowercased. An
    /// existing value at the final node is replaced.
    pub fn append_by_iter<'a, I>(&mut self, host: I, data: T)
    where
        I: Iterator<Item = &'a str>,
    {
        let mut node = self;
        for label in host {
            node = match label {
                "*" => &mut **node.else_node.get_or_insert_with(|| Box::new(Self::new())),
                label => node.children.entry(label.to_ascii_lowercase()).or_default(),
            };
        }
        node.data = Some(data);
    }

    /// Insert `data` for the dotted rule `host` (for example `api.*.example.com`).
    pub fn set(&mut self, host: &str, data: T) {
        self.append_by_iter(host.split('.').rev(), data);
    }

    /// Look up the value for the labels yielded by `host`, TLD first.
    ///
    /// Labels are compared as given, so callers should lowercase them first (as
    /// [`HostnameMatcherNode::get`] does). An exact label is preferred over `*`. When the
    /// labels after a `*` match nothing deeper, the `*` node's own value is used, so a leading
    /// `*` (as in `*.example.com`) matches one or more labels.
    pub fn get_by_iter<'a, I>(&self, mut host: I) -> Option<&T>
    where
        I: Iterator<Item = &'a str> + Clone,
    {
        let label = match host.next() {
            Some(label) => label,
            None => return self.data.as_ref(),
        };
        let exact = self.children.get(label).and_then(|child| child.get_by_iter(host.clone()));
        if exact.is_some() {
            return exact;
        }
        let wildcard = self.else_node.as_ref()?;
        wildcard.get_by_iter(host).or(wildcard.data.as_ref())
    }

    /// Mutable counterpart of [`HostnameMatcherNode::get_by_iter`], with identical matching.
    pub fn get_mut_by_iter<'a, 'b, I>(&'b mut self, mut host: I) -> Option<&'b mut T>
    where
        I: Iterator<Item = &'a str> + Clone,
    {
        match host.next() {
            None => self.data.as_mut(),
            Some(label) => Self::match_label_mut(&mut self.children, &mut self.else_node, label, host),
        }
    }

    /// Match `label` (followed by `rest`) against one node's exact children and wildcard child.
    ///
    /// Takes the two fields separately so the exact-child attempt and the wildcard fallback
    /// borrow disjoint data, which lets the search hand out `&mut T` without `unsafe`.
    fn match_label_mut<'a, 'b, I>(
        children: &'b mut BTreeMap<String, Self>,
        else_node: &'b mut Option<Box<Self>>,
        label: &str,
        rest: I,
    ) -> Option<&'b mut T>
    where
        I: Iterator<Item = &'a str> + Clone,
    {
        if let Some(child) = children.get_mut(label) {
            if let Some(found) = child.get_mut_by_iter(rest.clone()) {
                return Some(found);
            }
        }
        let wildcard = else_node.as_deref_mut()?;
        let HostnameMatcherNode { data, children, else_node } = wildcard;
        let mut rest = rest;
        match rest.next() {
            None => data.as_mut(),
            // Nothing deeper matched: the `*` swallows the remaining labels.
            Some(next) => Self::match_label_mut(children, else_node, next, rest).or(data.as_mut()),
        }
    }

    /// Look up the dotted hostname `host`, case-insensitively.
    pub fn get(&self, host: &str) -> Option<&T> {
        let host = host.to_ascii_lowercase();
        self.get_by_iter(host.split('.').rev())
    }

    /// Mutable counterpart of [`HostnameMatcherNode::get`].
    pub fn get_mut(&mut self, host: &str) -> Option<&mut T> {
        let host = host.to_ascii_lowercase();
        self.get_mut_by_iter(host.split('.').rev())
    }

    /// Iterate over every value in this subtree (see [`HostnameMatcherNodeIter`] for the order).
    pub fn iter(&self) -> HostnameMatcherNodeIter<'_, T> {
        HostnameMatcherNodeIter { stack: vec![self] }
    }

    /// Mutably iterate over every value in this subtree, in the same order as `iter`.
    pub fn iter_mut(&mut self) -> HostnameMatcherNodeIterMut<'_, T> {
        HostnameMatcherNodeIterMut { stack: vec![self] }
    }
}
341
#[cfg(test)]
mod test {
    use super::*;
    // `test_cases!` drives a `HostnameTree` whose values are the rule strings themselves:
    // - every `[hosts...] => rule` line registers `rule` with the value `rule`;
    // - each host in the optional leading `![...]` list must resolve to `None`;
    // - each host in a `[...]` list must resolve to the rule on its right.
    // The hostname part of the tree is printed first to help debug a failing case.
    macro_rules! test_cases {
        ($tree: ident
            $(![$($unmatched_case: literal),*])?
            $([$($case: literal),*] => $rule:literal)*
        ) => {
            $($tree.set($rule, $rule);)*
            println!("{:#?}", $tree.host);
            $(
                $(
                    assert_eq!($tree.get($unmatched_case), None);
                )*
            )?
            $(
                $(
                    assert_eq!($tree.get($case).cloned(), Some($rule));
                )*
            )*
        };
    }
    // Without a `*` rule, hosts that no specific rule covers (a bare TLD, an unregistered
    // IPv4 address) resolve to nothing.
    #[test]
    fn test_hostname_matcher_without_fallback() {
        let mut tree = HostnameTree::new();
        test_cases! {
            tree
            !["com", "127.0.0.23"]
            ["[::0]", "[::0]:80", "[::]"] => "[::0]"
            ["192.168.0.1"] => "192.168.0.1"
            ["example.com", "example.com:80"] => "example.com"
            // `apL` lowercases to `apl`, not `api`, so it falls through to `*.example.com`.
            ["api.example.com", "apL.v1.example.com:1000"] => "*.example.com"
            ["api.v1.example.com", "api.v2.example.com"] => "api.*.example.com"
            ["baidu.com"] => "*.com"
        }
    }
    // Same rules plus `*`: every host that no specific rule matches falls back to `*`.
    #[test]
    fn test_hostname_matcher_node() {
        let mut tree = HostnameTree::new();
        test_cases! {
            tree
            ["[::0]", "[::0]:80", "[::]"] => "[::0]"
            ["192.168.0.1"] => "192.168.0.1"
            ["example.com", "example.com:80"] => "example.com"
            ["api.example.com", "apL.v1.example.com:1000"] => "*.example.com"
            ["api.v1.example.com", "api.v2.example.com"] => "api.*.example.com"
            ["baidu.com"] => "*.com"
            ["[::1]", "127.0.0.1", "com", "example.org", "example.org:80", "example.org:443", "localhost:8080"] => "*"
        }
    }
    // A tree holding only `*` matches every host, with or without a port.
    #[test]
    fn test_any_match() {
        let mut tree = HostnameTree::new();
        test_cases! {
            tree
            ["com", "example.org", "example.org:80", "example.org:443", "localhost:8080", "127.0.0.1:9090"] => "*"
        }
    }
}