// rustio_core/router.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use hyper::Method;
6
7use crate::error::Error;
8use crate::http::{Request, Response};
9use crate::middleware::{MiddlewareFn, Next};
10
11type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
12type HandlerFn =
13    Arc<dyn Fn(Request, Params) -> BoxFuture<Result<Response, Error>> + Send + Sync>;
14
/// Path parameters captured while matching a route pattern such as
/// `/users/:id`.
///
/// Values are stored in the order the parameters appear in the pattern and
/// are kept exactly as they appeared in the request path (no decoding is
/// performed when they are captured).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Params {
    // (parameter name, captured value) in pattern order.
    pairs: Vec<(String, String)>,
}

impl Params {
    /// Returns a parameter set with no captures.
    fn empty() -> Self {
        Self::default()
    }

    /// Looks up the value captured for `name`.
    ///
    /// Returns `None` if the matched pattern has no parameter called `name`.
    /// If the same name appears more than once, the first capture wins.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.pairs
            .iter()
            .find_map(|(k, v)| (k == name).then_some(v.as_str()))
    }

    /// Iterates over all `(name, value)` captures in pattern order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
        self.pairs.iter().map(|(k, v)| (k.as_str(), v.as_str()))
    }

    /// Number of captured parameters.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// `true` when the matched route captured no parameters.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}

/// One `/`-separated piece of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Must equal the corresponding request path segment exactly.
    Literal(String),
    /// Matches any single segment and captures it under this name
    /// (written `:name` in the pattern).
    Param(String),
}

44struct Route {
45    method: Method,
46    segments: Vec<Segment>,
47    handler: HandlerFn,
48}
49
50pub struct Router {
51    routes: Vec<Route>,
52    middlewares: Vec<MiddlewareFn>,
53}
54
55impl Router {
56    pub fn new() -> Self {
57        Self {
58            routes: Vec::new(),
59            middlewares: Vec::new(),
60        }
61    }
62
63    pub fn wrap<F, Fut>(mut self, middleware: F) -> Self
64    where
65        F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
66        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
67    {
68        self.middlewares
69            .push(Arc::new(move |req, next| Box::pin(middleware(req, next))));
70        self
71    }
72
73    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
74    where
75        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
76        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
77    {
78        self.route(Method::GET, path, handler)
79    }
80
81    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
82    where
83        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
84        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
85    {
86        self.route(Method::POST, path, handler)
87    }
88
89    fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
90    where
91        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
92        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
93    {
94        let handler: HandlerFn = Arc::new(move |req, params| Box::pin(handler(req, params)));
95        self.routes.push(Route {
96            method,
97            segments: parse_path(path),
98            handler,
99        });
100        self
101    }
102
103    pub async fn dispatch(&self, req: Request) -> Response {
104        let path = req.uri().path().to_owned();
105        let actual: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
106        let method = req.method().clone();
107
108        let mut found: Option<(HandlerFn, Params)> = None;
109        let mut path_matched = false;
110
111        for route in &self.routes {
112            if let Some(params) = match_segments(&route.segments, &actual) {
113                path_matched = true;
114                if route.method == method {
115                    found = Some((route.handler.clone(), params));
116                    break;
117                }
118            }
119        }
120
121        let (handler, params) = found.unwrap_or_else(|| {
122            let method_not_allowed = path_matched;
123            let fallback: HandlerFn = Arc::new(move |_req, _params| {
124                let err = if method_not_allowed {
125                    Error::MethodNotAllowed
126                } else {
127                    Error::NotFound
128                };
129                Box::pin(async move { Err(err) })
130            });
131            (fallback, Params::empty())
132        });
133
134        let chain = build_chain(&self.middlewares, handler, params);
135        match chain(req).await {
136            Ok(resp) => resp,
137            Err(err) => err.into_response(),
138        }
139    }
140}
141
142impl Default for Router {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148fn build_chain(
149    middlewares: &[MiddlewareFn],
150    handler: HandlerFn,
151    params: Params,
152) -> Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> {
153    let mut chain: Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> =
154        Box::new(move |req| handler(req, params));
155
156    for mw in middlewares.iter().rev() {
157        let mw = mw.clone();
158        let inner = chain;
159        chain = Box::new(move |req| {
160            let next = Next::new(inner);
161            mw(req, next)
162        });
163    }
164    chain
165}
166
167fn parse_path(path: &str) -> Vec<Segment> {
168    path.split('/')
169        .filter(|s| !s.is_empty())
170        .map(|s| match s.strip_prefix(':') {
171            Some(name) => Segment::Param(name.to_owned()),
172            None => Segment::Literal(s.to_owned()),
173        })
174        .collect()
175}
176
177fn match_segments(patterns: &[Segment], actual: &[&str]) -> Option<Params> {
178    if patterns.len() != actual.len() {
179        return None;
180    }
181    let mut params = Params::empty();
182    for (pat, seg) in patterns.iter().zip(actual.iter()) {
183        match pat {
184            Segment::Literal(lit) => {
185                if lit != seg {
186                    return None;
187                }
188            }
189            Segment::Param(name) => {
190                params.pairs.push((name.clone(), (*seg).to_owned()));
191            }
192        }
193    }
194    Some(params)
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a route pattern (thin alias kept for readability in tests).
    fn segs(path: &str) -> Vec<Segment> {
        parse_path(path)
    }

    /// Splits a request path the same way `dispatch` does.
    fn parts(path: &str) -> Vec<&str> {
        path.split('/').filter(|s| !s.is_empty()).collect()
    }

    #[test]
    fn root_path_is_empty_segment_list() {
        assert!(parse_path("/").is_empty());
    }

    #[test]
    fn root_route_matches_only_root() {
        assert!(match_segments(&segs("/"), &parts("/")).is_some());
        assert!(match_segments(&segs("/"), &parts("")).is_some());
        assert!(match_segments(&segs("/"), &parts("/users")).is_none());
    }

    #[test]
    fn literal_match() {
        assert!(match_segments(&segs("/users"), &parts("/users")).is_some());
        assert!(match_segments(&segs("/users"), &parts("/posts")).is_none());
    }

    #[test]
    fn literal_only_route_captures_no_params() {
        let params = match_segments(&segs("/users"), &parts("/users")).unwrap();
        assert!(params.is_empty());
        assert_eq!(params.len(), 0);
    }

    #[test]
    fn pattern_is_parsed_into_literals_and_params() {
        let parsed = segs("/users/:id");
        assert!(matches!(
            parsed.as_slice(),
            [Segment::Literal(a), Segment::Param(b)] if a == "users" && b == "id"
        ));
    }

    #[test]
    fn param_captures_value() {
        let params = match_segments(&segs("/users/:id"), &parts("/users/42")).unwrap();
        assert_eq!(params.get("id"), Some("42"));
        assert_eq!(params.get("name"), None);
        assert_eq!(params.len(), 1);
        assert!(!params.is_empty());
    }

    #[test]
    fn trailing_and_repeated_slashes_are_ignored() {
        assert!(match_segments(&segs("/users/:id"), &parts("/users/42/")).is_some());
        assert!(match_segments(&segs("/users/:id"), &parts("//users//42")).is_some());
        assert!(match_segments(&segs("/users/:id/"), &parts("/users/42")).is_some());
    }

    #[test]
    fn length_mismatch_does_not_match() {
        assert!(match_segments(&segs("/users/:id"), &parts("/users")).is_none());
        assert!(match_segments(&segs("/users"), &parts("/users/42")).is_none());
    }

    #[test]
    fn multiple_params_captured_by_name() {
        let params = match_segments(
            &segs("/a/:x/b/:y"),
            &parts("/a/first/b/second"),
        )
        .unwrap();
        assert_eq!(params.get("x"), Some("first"));
        assert_eq!(params.get("y"), Some("second"));
    }
}