Skip to main content

rustio_core/
router.rs

1//! A small, opinionated router.
2//!
3//! - Path segments starting with `:` are captured into `req.param(name)`.
4//! - Middleware is a chain of `async fn(Request, Next) -> Result<Response>`.
5//! - 404 vs 405 is distinguished (path matched but method didn't → 405).
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
16pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
17
18pub type HandlerFn =
19    Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
20
21pub type MiddlewareFn =
22    Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
23
24pub struct Next {
25    chain: Vec<MiddlewareFn>,
26    handler: HandlerFn,
27    index: usize,
28}
29
30impl Next {
31    pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
32        Box::pin(async move {
33            if self.index < self.chain.len() {
34                let mw = self.chain[self.index].clone();
35                self.index += 1;
36                mw(req, self).await
37            } else {
38                (self.handler)(req).await
39            }
40        })
41    }
42}
43
44struct Route {
45    method: Method,
46    segments: Vec<Segment>,
47    handler: HandlerFn,
48}
49
/// One `/`-separated piece of a route pattern.
///
/// Derives the standard value traits so patterns can be inspected with
/// `{:?}`, compared in assertions, and cloned.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Literal text that the request segment must equal exactly.
    Static(String),
    /// Written `:name` in the pattern; the request segment in this position
    /// is captured under `name`.
    Param(String),
}
54
55pub struct Router {
56    routes: Vec<Route>,
57    middleware: Vec<MiddlewareFn>,
58}
59
60impl Default for Router {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl Router {
67    pub fn new() -> Self {
68        Self {
69            routes: Vec::new(),
70            middleware: Vec::new(),
71        }
72    }
73
74    pub fn middleware<F, Fut>(mut self, mw: F) -> Self
75    where
76        F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
77        Fut: Future<Output = Result<Response>> + Send + 'static,
78    {
79        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
80        self.middleware.push(wrapped);
81        self
82    }
83
84    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
85    where
86        F: Fn(Request) -> Fut + Send + Sync + 'static,
87        Fut: Future<Output = Result<Response>> + Send + 'static,
88    {
89        self.route(Method::GET, path, handler)
90    }
91
92    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
93    where
94        F: Fn(Request) -> Fut + Send + Sync + 'static,
95        Fut: Future<Output = Result<Response>> + Send + 'static,
96    {
97        self.route(Method::POST, path, handler)
98    }
99
100    pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
101    where
102        F: Fn(Request) -> Fut + Send + Sync + 'static,
103        Fut: Future<Output = Result<Response>> + Send + 'static,
104    {
105        let segments = parse_path(path);
106        let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
107        self.routes.push(Route {
108            method,
109            segments,
110            handler,
111        });
112        self
113    }
114
115    /// Look up a handler for the given method+path. Returns a matched
116    /// handler with captured params, or an error describing why it
117    /// didn't match.
118    fn find(&self, method: &Method, path: &str) -> MatchResult {
119        let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
120        // Normalise trailing slash: `/admin/posts/` and `/admin/posts`
121        // address the same handler. Skip when the only segment is empty
122        // (the root path `/`), so root-route lookups still work.
123        if path_segs.len() > 1 && path_segs.last() == Some(&"") {
124            path_segs.pop();
125        }
126        let mut path_matched = false;
127
128        for route in &self.routes {
129            if !segments_match(&route.segments, &path_segs) {
130                continue;
131            }
132            path_matched = true;
133            if route.method == *method {
134                let params = extract_params(&route.segments, &path_segs);
135                return MatchResult::Ok {
136                    handler: route.handler.clone(),
137                    params,
138                };
139            }
140        }
141
142        if path_matched {
143            MatchResult::MethodNotAllowed
144        } else {
145            MatchResult::NotFound
146        }
147    }
148
149    pub async fn dispatch(&self, mut req: Request) -> Response {
150        let matched = self.find(req.method(), req.path());
151
152        let outcome = match matched {
153            MatchResult::Ok { handler, params } => {
154                req.set_params(params);
155                let next = Next {
156                    chain: self.middleware.clone(),
157                    handler,
158                    index: 0,
159                };
160                next.run(req).await
161            }
162            MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
163            MatchResult::MethodNotAllowed => {
164                Err(Error::MethodNotAllowed(format!("{} not allowed", req.method())))
165            }
166        };
167
168        match outcome {
169            Ok(resp) => resp,
170            Err(err) => response_from_error(&err),
171        }
172    }
173}
174
175enum MatchResult {
176    Ok {
177        handler: HandlerFn,
178        params: std::collections::HashMap<String, String>,
179    },
180    NotFound,
181    MethodNotAllowed,
182}
183
184fn parse_path(path: &str) -> Vec<Segment> {
185    path.trim_start_matches('/')
186        .split('/')
187        .map(|seg| {
188            if let Some(name) = seg.strip_prefix(':') {
189                Segment::Param(name.to_string())
190            } else {
191                Segment::Static(seg.to_string())
192            }
193        })
194        .collect()
195}
196
197fn segments_match(route: &[Segment], path: &[&str]) -> bool {
198    if route.len() != path.len() {
199        return false;
200    }
201    for (r, p) in route.iter().zip(path.iter()) {
202        match r {
203            Segment::Static(s) if s != p => return false,
204            _ => {}
205        }
206    }
207    true
208}
209
210fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
211    let mut out = std::collections::HashMap::new();
212    for (r, p) in route.iter().zip(path.iter()) {
213        if let Segment::Param(name) = r {
214            out.insert(name.clone(), (*p).to_string());
215        }
216    }
217    out
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `path` with an empty query string, default
    /// headers and an empty body, which is the shape every test below needs.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        // The path exists, but only for GET.
        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        // No route has this path at all.
        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        // Each path is tried with and without a trailing slash.
        let cases: [(&'static str, &[u8]); 4] = [
            ("/admin/posts", &b"posts"[..]),
            ("/admin/posts/", &b"posts"[..]),
            ("/admin/posts/1/edit", &b"posts/1"[..]),
            ("/admin/posts/1/edit/", &b"posts/1"[..]),
        ];
        for (path, expected_body) in cases {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], expected_body, "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        // Regression check: the trailing-slash strip must NOT collapse
        // the single empty segment that represents the root path.
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}