1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
/// A pinned, heap-allocated, type-erased future.
///
/// The future must be `Send` and may borrow data for at most `'a`.
pub type BoxFuture<'a, T> =
    Pin<Box<dyn Future<Output = T> + Send + 'a>>;
17
18pub type HandlerFn =
19 Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
20
21pub type MiddlewareFn =
22 Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
23
24pub struct Next {
25 chain: Vec<MiddlewareFn>,
26 handler: HandlerFn,
27 index: usize,
28}
29
30impl Next {
31 pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
32 Box::pin(async move {
33 if self.index < self.chain.len() {
34 let mw = self.chain[self.index].clone();
35 self.index += 1;
36 mw(req, self).await
37 } else {
38 (self.handler)(req).await
39 }
40 })
41 }
42}
43
44struct Route {
45 method: Method,
46 segments: Vec<Segment>,
47 handler: HandlerFn,
48}
49
/// One component of a route pattern, as produced by `parse_path`.
///
/// Derives are provided so patterns can be inspected in logs and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Segment {
    /// Literal text that must equal the corresponding request path segment exactly.
    Static(String),
    /// A `:name` placeholder.
    ///
    /// It captures the corresponding request segment under `name`, the text after the
    /// colon.
    Param(String),
}
54
55pub struct Router {
56 routes: Vec<Route>,
57 middleware: Vec<MiddlewareFn>,
58}
59
60impl Default for Router {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Router {
67 pub fn new() -> Self {
68 Self {
69 routes: Vec::new(),
70 middleware: Vec::new(),
71 }
72 }
73
74 pub fn middleware<F, Fut>(mut self, mw: F) -> Self
75 where
76 F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
77 Fut: Future<Output = Result<Response>> + Send + 'static,
78 {
79 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
80 self.middleware.push(wrapped);
81 self
82 }
83
84 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
85 where
86 F: Fn(Request) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<Response>> + Send + 'static,
88 {
89 self.route(Method::GET, path, handler)
90 }
91
92 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
93 where
94 F: Fn(Request) -> Fut + Send + Sync + 'static,
95 Fut: Future<Output = Result<Response>> + Send + 'static,
96 {
97 self.route(Method::POST, path, handler)
98 }
99
100 pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
101 where
102 F: Fn(Request) -> Fut + Send + Sync + 'static,
103 Fut: Future<Output = Result<Response>> + Send + 'static,
104 {
105 let segments = parse_path(path);
106 let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
107 self.routes.push(Route {
108 method,
109 segments,
110 handler,
111 });
112 self
113 }
114
115 fn find(&self, method: &Method, path: &str) -> MatchResult {
119 let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
120 if path_segs.len() > 1 && path_segs.last() == Some(&"") {
124 path_segs.pop();
125 }
126 let mut path_matched = false;
127
128 for route in &self.routes {
129 if !segments_match(&route.segments, &path_segs) {
130 continue;
131 }
132 path_matched = true;
133 if route.method == *method {
134 let params = extract_params(&route.segments, &path_segs);
135 return MatchResult::Ok {
136 handler: route.handler.clone(),
137 params,
138 };
139 }
140 }
141
142 if path_matched {
143 MatchResult::MethodNotAllowed
144 } else {
145 MatchResult::NotFound
146 }
147 }
148
149 pub async fn dispatch(&self, mut req: Request) -> Response {
150 let matched = self.find(req.method(), req.path());
151
152 let outcome = match matched {
153 MatchResult::Ok { handler, params } => {
154 req.set_params(params);
155 let next = Next {
156 chain: self.middleware.clone(),
157 handler,
158 index: 0,
159 };
160 next.run(req).await
161 }
162 MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
163 MatchResult::MethodNotAllowed => {
164 Err(Error::MethodNotAllowed(format!("{} not allowed", req.method())))
165 }
166 };
167
168 match outcome {
169 Ok(resp) => resp,
170 Err(err) => response_from_error(&err),
171 }
172 }
173}
174
175enum MatchResult {
176 Ok {
177 handler: HandlerFn,
178 params: std::collections::HashMap<String, String>,
179 },
180 NotFound,
181 MethodNotAllowed,
182}
183
184fn parse_path(path: &str) -> Vec<Segment> {
185 path.trim_start_matches('/')
186 .split('/')
187 .map(|seg| {
188 if let Some(name) = seg.strip_prefix(':') {
189 Segment::Param(name.to_string())
190 } else {
191 Segment::Static(seg.to_string())
192 }
193 })
194 .collect()
195}
196
197fn segments_match(route: &[Segment], path: &[&str]) -> bool {
198 if route.len() != path.len() {
199 return false;
200 }
201 for (r, p) in route.iter().zip(path.iter()) {
202 match r {
203 Segment::Static(s) if s != p => return false,
204 _ => {}
205 }
206 }
207 true
208}
209
210fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
211 let mut out = std::collections::HashMap::new();
212 for (r, p) in route.iter().zip(path.iter()) {
213 if let Segment::Param(name) = r {
214 out.insert(name.clone(), (*p).to_string());
215 }
216 }
217 out
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `path` with an empty query string, default headers and no body.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        // Each path with and without a trailing slash must reach the same handler.
        let cases: [(&'static str, &[u8]); 4] = [
            ("/admin/posts", b"posts"),
            ("/admin/posts/", b"posts"),
            ("/admin/posts/1/edit", b"posts/1"),
            ("/admin/posts/1/edit/", b"posts/1"),
        ];
        for (path, expected) in cases {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], expected, "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}