Skip to main content

rustio_core/
router.rs

//! Path router with `:param` support.
//!
//! Routes are registered against a [`Router`] and dispatched by path +
//! method. A request whose path matches a route but whose method does not
//! produces `405 Method Not Allowed` rather than collapsing to `404`.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::Error;
14use crate::http::{Request, Response};
15use crate::middleware::{MiddlewareFn, Next};
16
17type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
18type HandlerFn = Arc<dyn Fn(Request, Params) -> BoxFuture<Result<Response, Error>> + Send + Sync>;
19
/// Path parameters captured from the `:name` segments of a matched route.
///
/// Pairs are stored as `(name, value)` in the order the parameters appear
/// in the route pattern.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Params {
    pairs: Vec<(String, String)>,
}

impl Params {
    /// Creates a parameter set with no captures.
    fn empty() -> Self {
        Self::default()
    }

    /// Returns the value captured for `name`, or `None` if no parameter
    /// with that name was captured.
    ///
    /// If the same name was captured more than once, the first capture wins.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.pairs
            .iter()
            .find_map(|(k, v)| (k == name).then_some(v.as_str()))
    }

    /// Number of captured parameters.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// `true` when nothing was captured.
    pub fn is_empty(&self) -> bool {
        self.pairs.is_empty()
    }
}
43
/// One `/`-separated piece of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Must equal the request segment exactly.
    Literal(String),
    /// Written `:name` in the pattern; matches any single request segment
    /// and captures it under `name`.
    Param(String),
}
48
49struct Route {
50    method: Method,
51    segments: Vec<Segment>,
52    handler: HandlerFn,
53}
54
55pub struct Router {
56    routes: Vec<Route>,
57    middlewares: Vec<MiddlewareFn>,
58}
59
60impl Router {
61    pub fn new() -> Self {
62        Self {
63            routes: Vec::new(),
64            middlewares: Vec::new(),
65        }
66    }
67
68    pub fn wrap<F, Fut>(mut self, middleware: F) -> Self
69    where
70        F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
71        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
72    {
73        self.middlewares
74            .push(Arc::new(move |req, next| Box::pin(middleware(req, next))));
75        self
76    }
77
78    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
79    where
80        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
81        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
82    {
83        self.route(Method::GET, path, handler)
84    }
85
86    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
87    where
88        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
89        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
90    {
91        self.route(Method::POST, path, handler)
92    }
93
94    fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
95    where
96        F: Fn(Request, Params) -> Fut + Send + Sync + 'static,
97        Fut: Future<Output = Result<Response, Error>> + Send + 'static,
98    {
99        let handler: HandlerFn = Arc::new(move |req, params| Box::pin(handler(req, params)));
100        self.routes.push(Route {
101            method,
102            segments: parse_path(path),
103            handler,
104        });
105        self
106    }
107
108    /// `true` when a handler is already registered for `(method, path)`.
109    ///
110    /// Path comparison is literal — `:id` and other patterns are matched
111    /// by their textual form. Used by [`crate::defaults::with_defaults`]
112    /// to skip registering defaults the project already owns (e.g. a
113    /// custom `/` homepage) so that registration order doesn't silently
114    /// shadow user intent.
115    pub fn has_route(&self, method: &Method, path: &str) -> bool {
116        let target = parse_path(path);
117        self.routes
118            .iter()
119            .any(|r| r.method == *method && segments_equal(&r.segments, &target))
120    }
121
122    pub async fn dispatch(&self, req: Request) -> Response {
123        let path = req.uri().path().to_owned();
124        let actual: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
125        let method = req.method().clone();
126
127        let mut found: Option<(HandlerFn, Params)> = None;
128        let mut path_matched = false;
129
130        for route in &self.routes {
131            if let Some(params) = match_segments(&route.segments, &actual) {
132                path_matched = true;
133                if route.method == method {
134                    found = Some((route.handler.clone(), params));
135                    break;
136                }
137            }
138        }
139
140        let (handler, params) = found.unwrap_or_else(|| {
141            let method_not_allowed = path_matched;
142            let fallback: HandlerFn = Arc::new(move |_req, _params| {
143                let err = if method_not_allowed {
144                    Error::MethodNotAllowed
145                } else {
146                    Error::NotFound
147                };
148                Box::pin(async move { Err(err) })
149            });
150            (fallback, Params::empty())
151        });
152
153        let chain = build_chain(&self.middlewares, handler, params);
154        match chain(req).await {
155            Ok(resp) => resp,
156            Err(err) => err.into_response(),
157        }
158    }
159}
160
161impl Default for Router {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167fn build_chain(
168    middlewares: &[MiddlewareFn],
169    handler: HandlerFn,
170    params: Params,
171) -> Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> {
172    let mut chain: Box<dyn FnOnce(Request) -> BoxFuture<Result<Response, Error>> + Send> =
173        Box::new(move |req| handler(req, params));
174
175    for mw in middlewares.iter().rev() {
176        let mw = mw.clone();
177        let inner = chain;
178        chain = Box::new(move |req| {
179            let next = Next::new(inner);
180            mw(req, next)
181        });
182    }
183    chain
184}
185
186fn parse_path(path: &str) -> Vec<Segment> {
187    path.split('/')
188        .filter(|s| !s.is_empty())
189        .map(|s| match s.strip_prefix(':') {
190            Some(name) => Segment::Param(name.to_owned()),
191            None => Segment::Literal(s.to_owned()),
192        })
193        .collect()
194}
195
196fn segments_equal(a: &[Segment], b: &[Segment]) -> bool {
197    if a.len() != b.len() {
198        return false;
199    }
200    a.iter().zip(b.iter()).all(|(x, y)| match (x, y) {
201        (Segment::Literal(u), Segment::Literal(v)) => u == v,
202        (Segment::Param(u), Segment::Param(v)) => u == v,
203        _ => false,
204    })
205}
206
207fn match_segments(patterns: &[Segment], actual: &[&str]) -> Option<Params> {
208    if patterns.len() != actual.len() {
209        return None;
210    }
211    let mut params = Params::empty();
212    for (pat, seg) in patterns.iter().zip(actual.iter()) {
213        match pat {
214            Segment::Literal(lit) => {
215                if lit != seg {
216                    return None;
217                }
218            }
219            Segment::Param(name) => {
220                params.pairs.push((name.clone(), (*seg).to_owned()));
221            }
222        }
223    }
224    Some(params)
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a route pattern into segments.
    fn segs(path: &str) -> Vec<Segment> {
        parse_path(path)
    }

    /// Splits a request path the same way `Router::dispatch` does.
    fn parts(path: &str) -> Vec<&str> {
        path.split('/').filter(|s| !s.is_empty()).collect()
    }

    #[test]
    fn root_path_is_empty_segment_list() {
        assert!(parse_path("/").is_empty());
    }

    #[test]
    fn literal_match() {
        assert!(match_segments(&segs("/users"), &parts("/users")).is_some());
        assert!(match_segments(&segs("/users"), &parts("/posts")).is_none());
    }

    #[test]
    fn param_captures_value() {
        let params = match_segments(&segs("/users/:id"), &parts("/users/42")).unwrap();
        assert_eq!(params.get("id"), Some("42"));
    }

    #[test]
    fn length_mismatch_does_not_match() {
        assert!(match_segments(&segs("/users/:id"), &parts("/users")).is_none());
        assert!(match_segments(&segs("/users"), &parts("/users/42")).is_none());
    }

    #[test]
    fn multiple_params_captured_by_name() {
        let params = match_segments(&segs("/a/:x/b/:y"), &parts("/a/first/b/second")).unwrap();
        assert_eq!(params.get("x"), Some("first"));
        assert_eq!(params.get("y"), Some("second"));
    }

    #[test]
    fn has_route_detects_registered_path() {
        let router = Router::new()
            .get("/", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("home"))
            })
            .get("/users/:id", |_req, _p| async {
                Ok::<Response, Error>(crate::http::text("user"))
            });
        assert!(router.has_route(&Method::GET, "/"));
        assert!(router.has_route(&Method::GET, "/users/:id"));
        assert!(!router.has_route(&Method::GET, "/missing"));
        assert!(!router.has_route(&Method::POST, "/"));
        // Comparison is textual: the same shape with a different param name
        // is not considered the same route.
        assert!(!router.has_route(&Method::GET, "/users/:other"));
    }
}