Skip to main content

rustio_admin/
router.rs

//! A small, opinionated router.
//!
//! - Path segments starting with `:` are captured and exposed via `req.param(name)`.
//! - Middleware is a chain of `async fn(Request, Next) -> Result<Response>`.
//! - 404 and 405 are distinguished: if some route matches the path but none is
//!   registered for the request method, the response is 405; otherwise it is 404.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
/// A pinned, heap-allocated, `Send` future yielding `T` that may borrow for `'a`.
///
/// Used as the common return type of type-erased handlers and middleware.
pub type BoxFuture<'a, T> = Pin<Box<dyn Send + Future<Output = T> + 'a>>;
17
18pub type HandlerFn =
19    Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
20
21pub type MiddlewareFn =
22    Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
23
24pub struct Next {
25    chain: Vec<MiddlewareFn>,
26    handler: HandlerFn,
27    index: usize,
28}
29
30impl Next {
31    pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
32        Box::pin(async move {
33            if self.index < self.chain.len() {
34                let mw = self.chain[self.index].clone();
35                self.index += 1;
36                mw(req, self).await
37            } else {
38                (self.handler)(req).await
39            }
40        })
41    }
42}
43
44struct Route {
45    method: Method,
46    segments: Vec<Segment>,
47    handler: HandlerFn,
48}
49
/// One `/`-separated piece of a route pattern.
enum Segment {
    /// A `:name` placeholder. Holds `name` without the colon; the request
    /// segment at this position is captured under that name.
    Param(String),
    /// Literal text that must equal the request segment exactly.
    Static(String),
}
54
55pub struct Router {
56    routes: Vec<Route>,
57    middleware: Vec<MiddlewareFn>,
58}
59
60impl Default for Router {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl Router {
67    pub fn new() -> Self {
68        Self {
69            routes: Vec::new(),
70            middleware: Vec::new(),
71        }
72    }
73
74    pub fn middleware<F, Fut>(mut self, mw: F) -> Self
75    where
76        F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
77        Fut: Future<Output = Result<Response>> + Send + 'static,
78    {
79        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
80        self.middleware.push(wrapped);
81        self
82    }
83
84    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
85    where
86        F: Fn(Request) -> Fut + Send + Sync + 'static,
87        Fut: Future<Output = Result<Response>> + Send + 'static,
88    {
89        self.route(Method::GET, path, handler)
90    }
91
92    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
93    where
94        F: Fn(Request) -> Fut + Send + Sync + 'static,
95        Fut: Future<Output = Result<Response>> + Send + 'static,
96    {
97        self.route(Method::POST, path, handler)
98    }
99
100    pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
101    where
102        F: Fn(Request) -> Fut + Send + Sync + 'static,
103        Fut: Future<Output = Result<Response>> + Send + 'static,
104    {
105        let segments = parse_path(path);
106        let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
107        self.routes.push(Route {
108            method,
109            segments,
110            handler,
111        });
112        self
113    }
114
115    /// Look up a handler for the given method+path.
116    fn find(&self, method: &Method, path: &str) -> MatchResult {
117        let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
118        // Normalise trailing slash: `/admin/posts/` and `/admin/posts`
119        // address the same handler. Skip when the only segment is empty
120        // (the root path `/`), so root-route lookups still work.
121        if path_segs.len() > 1 && path_segs.last() == Some(&"") {
122            path_segs.pop();
123        }
124        let mut path_matched = false;
125
126        for route in &self.routes {
127            if !segments_match(&route.segments, &path_segs) {
128                continue;
129            }
130            path_matched = true;
131            if route.method == *method {
132                let params = extract_params(&route.segments, &path_segs);
133                return MatchResult::Ok {
134                    handler: route.handler.clone(),
135                    params,
136                };
137            }
138        }
139
140        if path_matched {
141            MatchResult::MethodNotAllowed
142        } else {
143            MatchResult::NotFound
144        }
145    }
146
147    pub async fn dispatch(&self, mut req: Request) -> Response {
148        let matched = self.find(req.method(), req.path());
149
150        let outcome = match matched {
151            MatchResult::Ok { handler, params } => {
152                req.set_params(params);
153                let next = Next {
154                    chain: self.middleware.clone(),
155                    handler,
156                    index: 0,
157                };
158                next.run(req).await
159            }
160            MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
161            MatchResult::MethodNotAllowed => Err(Error::MethodNotAllowed(format!(
162                "{} not allowed",
163                req.method()
164            ))),
165        };
166
167        match outcome {
168            Ok(resp) => resp,
169            Err(err) => response_from_error(&err),
170        }
171    }
172}
173
174enum MatchResult {
175    Ok {
176        handler: HandlerFn,
177        params: std::collections::HashMap<String, String>,
178    },
179    NotFound,
180    MethodNotAllowed,
181}
182
183fn parse_path(path: &str) -> Vec<Segment> {
184    path.trim_start_matches('/')
185        .split('/')
186        .map(|seg| {
187            if let Some(name) = seg.strip_prefix(':') {
188                Segment::Param(name.to_string())
189            } else {
190                Segment::Static(seg.to_string())
191            }
192        })
193        .collect()
194}
195
196fn segments_match(route: &[Segment], path: &[&str]) -> bool {
197    if route.len() != path.len() {
198        return false;
199    }
200    for (r, p) in route.iter().zip(path.iter()) {
201        match r {
202            Segment::Static(s) if s != p => return false,
203            _ => {}
204        }
205    }
206    true
207}
208
209fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
210    let mut out = std::collections::HashMap::new();
211    for (r, p) in route.iter().zip(path.iter()) {
212        if let Segment::Param(name) = r {
213            out.insert(name.clone(), (*p).to_string());
214        }
215    }
216    out
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `method` + `path`, with every other constructor
    /// argument left empty or defaulted.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        for path in ["/admin/posts", "/admin/posts/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts", "GET {path} body");
        }

        for path in ["/admin/posts/1/edit", "/admin/posts/1/edit/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts/1", "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        // Regression check: stripping a trailing slash must not remove the
        // single empty segment that stands for the root path.
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}