1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
/// A heap-allocated, type-erased future that is `Send` and may borrow data
/// living for `'fut`.
///
/// Pass `'static` when the future owns all of its state (as every handler and
/// middleware future in this module does).
pub type BoxFuture<'fut, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'fut>>;
18
19pub type HandlerFn =
21 Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
22
23pub type MiddlewareFn =
25 Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
26
27pub struct Next {
29 chain: Vec<MiddlewareFn>,
30 handler: HandlerFn,
31 index: usize,
32}
33
34impl Next {
35 pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
37 Box::pin(async move {
38 if self.index < self.chain.len() {
39 let mw = self.chain[self.index].clone();
40 self.index += 1;
41 mw(req, self).await
42 } else {
43 (self.handler)(req).await
44 }
45 })
46 }
47}
48
49struct Route {
50 method: Method,
51 segments: Vec<Segment>,
52 handler: HandlerFn,
53}
54
/// A single `/`-separated piece of a route pattern.
enum Segment {
    /// A `:name` placeholder; the captured text is stored under `name`.
    Param(String),
    /// Literal text that the request segment must equal exactly.
    Static(String),
}
59
60pub struct Router {
62 routes: Vec<Route>,
63 middleware: Vec<MiddlewareFn>,
64}
65
66impl Default for Router {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl Router {
73 pub fn new() -> Self {
75 Self {
76 routes: Vec::new(),
77 middleware: Vec::new(),
78 }
79 }
80
81 pub fn middleware<F, Fut>(mut self, mw: F) -> Self
83 where
84 F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
85 Fut: Future<Output = Result<Response>> + Send + 'static,
86 {
87 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
88 self.middleware.push(wrapped);
89 self
90 }
91
92 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
94 where
95 F: Fn(Request) -> Fut + Send + Sync + 'static,
96 Fut: Future<Output = Result<Response>> + Send + 'static,
97 {
98 self.route(Method::GET, path, handler)
99 }
100
101 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
103 where
104 F: Fn(Request) -> Fut + Send + Sync + 'static,
105 Fut: Future<Output = Result<Response>> + Send + 'static,
106 {
107 self.route(Method::POST, path, handler)
108 }
109
110 pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
112 where
113 F: Fn(Request) -> Fut + Send + Sync + 'static,
114 Fut: Future<Output = Result<Response>> + Send + 'static,
115 {
116 let segments = parse_path(path);
117 let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
118 self.routes.push(Route {
119 method,
120 segments,
121 handler,
122 });
123 self
124 }
125
126 fn find(&self, method: &Method, path: &str) -> MatchResult {
128 let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
129 if path_segs.len() > 1 && path_segs.last() == Some(&"") {
133 path_segs.pop();
134 }
135 let mut path_matched = false;
136
137 for route in &self.routes {
138 if !segments_match(&route.segments, &path_segs) {
139 continue;
140 }
141 path_matched = true;
142 if route.method == *method {
143 let params = extract_params(&route.segments, &path_segs);
144 return MatchResult::Ok {
145 handler: route.handler.clone(),
146 params,
147 };
148 }
149 }
150
151 if path_matched {
152 MatchResult::MethodNotAllowed
153 } else {
154 MatchResult::NotFound
155 }
156 }
157
158 pub async fn dispatch(&self, mut req: Request) -> Response {
160 let matched = self.find(req.method(), req.path());
161
162 let outcome = match matched {
163 MatchResult::Ok { handler, params } => {
164 req.set_params(params);
165 let next = Next {
166 chain: self.middleware.clone(),
167 handler,
168 index: 0,
169 };
170 next.run(req).await
171 }
172 MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
173 MatchResult::MethodNotAllowed => Err(Error::MethodNotAllowed(format!(
174 "{} not allowed",
175 req.method()
176 ))),
177 };
178
179 match outcome {
180 Ok(resp) => resp,
181 Err(err) => response_from_error(&err),
182 }
183 }
184}
185
186enum MatchResult {
187 Ok {
188 handler: HandlerFn,
189 params: std::collections::HashMap<String, String>,
190 },
191 NotFound,
192 MethodNotAllowed,
193}
194
195fn parse_path(path: &str) -> Vec<Segment> {
196 path.trim_start_matches('/')
197 .split('/')
198 .map(|seg| {
199 if let Some(name) = seg.strip_prefix(':') {
200 Segment::Param(name.to_string())
201 } else {
202 Segment::Static(seg.to_string())
203 }
204 })
205 .collect()
206}
207
208fn segments_match(route: &[Segment], path: &[&str]) -> bool {
209 if route.len() != path.len() {
210 return false;
211 }
212 for (r, p) in route.iter().zip(path.iter()) {
213 match r {
214 Segment::Static(s) if s != p => return false,
215 _ => {}
216 }
217 }
218 true
219}
220
221fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
222 let mut out = std::collections::HashMap::new();
223 for (r, p) in route.iter().zip(path.iter()) {
224 if let Segment::Param(name) = r {
225 out.insert(name.clone(), (*p).to_string());
226 }
227 }
228 out
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare request for `method` and `path`, passing the same
    /// empty/default values for the remaining constructor arguments that every
    /// test in this module uses.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        // (request path, expected body); each path is tried with and without
        // a trailing slash.
        let cases = [
            ("/admin/posts", "posts"),
            ("/admin/posts/", "posts"),
            ("/admin/posts/1/edit", "posts/1"),
            ("/admin/posts/1/edit/", "posts/1"),
        ];
        for (path, expected) in cases {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], expected.as_bytes(), "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}