1use super::conversions::{IntoResponse, TryFromRequest, TryIntoRequest};
4use super::{responses, Method, Request, Response};
5use async_trait::async_trait;
6use routefinder::{Captures, Router as MethodRouter};
7use std::future::Future;
8use std::{collections::HashMap, fmt::Display};
9
10#[async_trait(?Send)]
15pub trait Handler {
16 async fn handle(&self, req: Request, params: Params) -> Response;
18}
19
20#[async_trait(?Send)]
21impl Handler for Box<dyn Handler> {
22 async fn handle(&self, req: Request, params: Params) -> Response {
23 self.as_ref().handle(req, params).await
24 }
25}
26
27#[async_trait(?Send)]
28impl<F, Fut> Handler for F
29where
30 F: Fn(Request, Params) -> Fut + 'static,
31 Fut: Future<Output = Response> + 'static,
32{
33 async fn handle(&self, req: Request, params: Params) -> Response {
34 let fut = (self)(req, params);
35 fut.await
36 }
37}
38
39pub type Params = Captures<'static, 'static>;
41
42pub struct Router {
158 methods_map: HashMap<Method, MethodRouter<Box<dyn Handler>>>,
159 any_methods: MethodRouter<Box<dyn Handler>>,
160}
161
162impl Default for Router {
163 fn default() -> Router {
164 Router::new()
165 }
166}
167
168impl Display for Router {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 writeln!(f, "Registered routes:")?;
171 for (method, router) in &self.methods_map {
172 for route in router.iter() {
173 writeln!(f, "- {}: {}", method, route.0)?;
174 }
175 }
176 Ok(())
177 }
178}
179
180struct RouteMatch<'a> {
181 params: Captures<'static, 'static>,
182 handler: &'a dyn Handler,
183}
184
185impl Router {
186 pub fn handle<R>(&self, request: R) -> Response
188 where
189 R: TryIntoRequest,
190 R::Error: IntoResponse,
191 {
192 crate::http::executor::run(self.handle_async(request))
193 }
194
195 pub async fn handle_async<R>(&self, request: R) -> Response
197 where
198 R: TryIntoRequest,
199 R::Error: IntoResponse,
200 {
201 let request = match R::try_into_request(request) {
202 Ok(r) => r,
203 Err(e) => return e.into_response(),
204 };
205 let method = request.method.clone();
206 let path = &request.path();
207 let RouteMatch { params, handler } = self.find(path, method);
208 handler.handle(request, params).await
209 }
210
211 fn find(&self, path: &str, method: Method) -> RouteMatch<'_> {
212 let best_match = self
213 .methods_map
214 .get(&method)
215 .and_then(|r| r.best_match(path));
216
217 if let Some(m) = best_match {
218 let params = m.captures().into_owned();
219 let handler = m.handler();
220 return RouteMatch { handler, params };
221 }
222
223 let best_match = self.any_methods.best_match(path);
224
225 match best_match {
226 Some(m) => {
227 let params = m.captures().into_owned();
228 let handler = m.handler();
229 RouteMatch { handler, params }
230 }
231 None if method == Method::Head => {
232 self.find(path, Method::Get)
235 }
236 None => {
237 self.fail(path, method)
239 }
240 }
241 }
242
243 fn fail(&self, path: &str, method: Method) -> RouteMatch<'_> {
245 let is_method_not_allowed = self
247 .methods_map
248 .iter()
249 .filter(|(k, _)| **k != method)
250 .any(|(_, r)| r.best_match(path).is_some());
251
252 if is_method_not_allowed {
253 RouteMatch {
256 handler: &method_not_allowed,
257 params: Captures::default(),
258 }
259 } else {
260 RouteMatch {
262 handler: ¬_found,
263 params: Captures::default(),
264 }
265 }
266 }
267
268 pub fn any<F, Req, Resp>(&mut self, path: &str, handler: F)
270 where
271 F: Fn(Req, Params) -> Resp + 'static,
272 Req: TryFromRequest + 'static,
273 Req::Error: IntoResponse + 'static,
274 Resp: IntoResponse + 'static,
275 {
276 let handler = move |req, params| {
277 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
278 async move {
279 match res {
280 Ok(res) => res.into_response(),
281 Err(e) => e.into_response(),
282 }
283 }
284 };
285
286 self.any_async(path, handler)
287 }
288
289 pub fn any_async<F, Fut, I, O>(&mut self, path: &str, handler: F)
291 where
292 F: Fn(I, Params) -> Fut + 'static,
293 Fut: Future<Output = O> + 'static,
294 I: TryFromRequest + 'static,
295 I::Error: IntoResponse + 'static,
296 O: IntoResponse + 'static,
297 {
298 let handler = move |req, params| {
299 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
300 async move {
301 match res {
302 Ok(f) => f.await.into_response(),
303 Err(e) => e.into_response(),
304 }
305 }
306 };
307
308 self.any_methods.add(path, Box::new(handler)).unwrap();
309 }
310
311 pub fn add<F, Req, Resp>(&mut self, path: &str, method: Method, handler: F)
313 where
314 F: Fn(Req, Params) -> Resp + 'static,
315 Req: TryFromRequest + 'static,
316 Req::Error: IntoResponse + 'static,
317 Resp: IntoResponse + 'static,
318 {
319 let handler = move |req, params| {
320 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
321 async move {
322 match res {
323 Ok(res) => res.into_response(),
324 Err(e) => e.into_response(),
325 }
326 }
327 };
328
329 self.add_async(path, method, handler)
330 }
331
332 pub fn add_async<F, Fut, I, O>(&mut self, path: &str, method: Method, handler: F)
334 where
335 F: Fn(I, Params) -> Fut + 'static,
336 Fut: Future<Output = O> + 'static,
337 I: TryFromRequest + 'static,
338 I::Error: IntoResponse + 'static,
339 O: IntoResponse + 'static,
340 {
341 let handler = move |req, params| {
342 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
343 async move {
344 match res {
345 Ok(f) => f.await.into_response(),
346 Err(e) => e.into_response(),
347 }
348 }
349 };
350
351 self.methods_map
352 .entry(method)
353 .or_default()
354 .add(path, Box::new(handler))
355 .unwrap();
356 }
357
358 pub fn get<F, Req, Resp>(&mut self, path: &str, handler: F)
360 where
361 F: Fn(Req, Params) -> Resp + 'static,
362 Req: TryFromRequest + 'static,
363 Req::Error: IntoResponse + 'static,
364 Resp: IntoResponse + 'static,
365 {
366 self.add(path, Method::Get, handler)
367 }
368
369 pub fn get_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
371 where
372 F: Fn(Req, Params) -> Fut + 'static,
373 Fut: Future<Output = Resp> + 'static,
374 Req: TryFromRequest + 'static,
375 Req::Error: IntoResponse + 'static,
376 Resp: IntoResponse + 'static,
377 {
378 self.add_async(path, Method::Get, handler)
379 }
380
381 pub fn head<F, Req, Resp>(&mut self, path: &str, handler: F)
383 where
384 F: Fn(Req, Params) -> Resp + 'static,
385 Req: TryFromRequest + 'static,
386 Req::Error: IntoResponse + 'static,
387 Resp: IntoResponse + 'static,
388 {
389 self.add(path, Method::Head, handler)
390 }
391
392 pub fn head_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
394 where
395 F: Fn(Req, Params) -> Fut + 'static,
396 Fut: Future<Output = Resp> + 'static,
397 Req: TryFromRequest + 'static,
398 Req::Error: IntoResponse + 'static,
399 Resp: IntoResponse + 'static,
400 {
401 self.add_async(path, Method::Head, handler)
402 }
403
404 pub fn post<F, Req, Resp>(&mut self, path: &str, handler: F)
406 where
407 F: Fn(Req, Params) -> Resp + 'static,
408 Req: TryFromRequest + 'static,
409 Req::Error: IntoResponse + 'static,
410 Resp: IntoResponse + 'static,
411 {
412 self.add(path, Method::Post, handler)
413 }
414
415 pub fn post_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
417 where
418 F: Fn(Req, Params) -> Fut + 'static,
419 Fut: Future<Output = Resp> + 'static,
420 Req: TryFromRequest + 'static,
421 Req::Error: IntoResponse + 'static,
422 Resp: IntoResponse + 'static,
423 {
424 self.add_async(path, Method::Post, handler)
425 }
426
427 pub fn delete<F, Req, Resp>(&mut self, path: &str, handler: F)
429 where
430 F: Fn(Req, Params) -> Resp + 'static,
431 Req: TryFromRequest + 'static,
432 Req::Error: IntoResponse + 'static,
433 Resp: IntoResponse + 'static,
434 {
435 self.add(path, Method::Delete, handler)
436 }
437
438 pub fn delete_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
440 where
441 F: Fn(Req, Params) -> Fut + 'static,
442 Fut: Future<Output = Resp> + 'static,
443 Req: TryFromRequest + 'static,
444 Req::Error: IntoResponse + 'static,
445 Resp: IntoResponse + 'static,
446 {
447 self.add_async(path, Method::Delete, handler)
448 }
449
450 pub fn put<F, Req, Resp>(&mut self, path: &str, handler: F)
452 where
453 F: Fn(Req, Params) -> Resp + 'static,
454 Req: TryFromRequest + 'static,
455 Req::Error: IntoResponse + 'static,
456 Resp: IntoResponse + 'static,
457 {
458 self.add(path, Method::Put, handler)
459 }
460
461 pub fn put_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
463 where
464 F: Fn(Req, Params) -> Fut + 'static,
465 Fut: Future<Output = Resp> + 'static,
466 Req: TryFromRequest + 'static,
467 Req::Error: IntoResponse + 'static,
468 Resp: IntoResponse + 'static,
469 {
470 self.add_async(path, Method::Put, handler)
471 }
472
473 pub fn patch<F, Req, Resp>(&mut self, path: &str, handler: F)
475 where
476 F: Fn(Req, Params) -> Resp + 'static,
477 Req: TryFromRequest + 'static,
478 Req::Error: IntoResponse + 'static,
479 Resp: IntoResponse + 'static,
480 {
481 self.add(path, Method::Patch, handler)
482 }
483
484 pub fn patch_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
486 where
487 F: Fn(Req, Params) -> Fut + 'static,
488 Fut: Future<Output = Resp> + 'static,
489 Req: TryFromRequest + 'static,
490 Req::Error: IntoResponse + 'static,
491 Resp: IntoResponse + 'static,
492 {
493 self.add_async(path, Method::Patch, handler)
494 }
495
496 pub fn options<F, Req, Resp>(&mut self, path: &str, handler: F)
498 where
499 F: Fn(Req, Params) -> Resp + 'static,
500 Req: TryFromRequest + 'static,
501 Req::Error: IntoResponse + 'static,
502 Resp: IntoResponse + 'static,
503 {
504 self.add(path, Method::Options, handler)
505 }
506
507 pub fn options_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
509 where
510 F: Fn(Req, Params) -> Fut + 'static,
511 Fut: Future<Output = Resp> + 'static,
512 Req: TryFromRequest + 'static,
513 Req::Error: IntoResponse + 'static,
514 Resp: IntoResponse + 'static,
515 {
516 self.add_async(path, Method::Options, handler)
517 }
518
519 pub fn new() -> Self {
521 Router {
522 methods_map: HashMap::default(),
523 any_methods: MethodRouter::new(),
524 }
525 }
526}
527
528async fn not_found(_req: Request, _params: Params) -> Response {
529 responses::not_found()
530}
531
532async fn method_not_allowed(_req: Request, _params: Params) -> Response {
533 responses::method_not_allowed()
534}
535
/// Builds a [`Router`](crate::http::Router) from `METHOD "path" => handler` entries.
///
/// `METHOD` is one of `HEAD`, `GET`, `PUT`, `POST`, `PATCH`, `DELETE`,
/// `OPTIONS`, or `_` to match any method. Each entry is registered with the
/// router's synchronous registration method of the same name (`_` uses `any`).
/// Entries are comma-separated, and a trailing comma is accepted.
///
/// ```ignore
/// let router = http_router! {
///     GET "/hello/:name" => hello,
///     _ "/*" => fallback,
/// };
/// ```
#[macro_export]
macro_rules! http_router {
    ($($method:tt $path:literal => $h:expr),* $(,)?) => {
        {
            let mut router = $crate::http::Router::new();
            $(
                $crate::http_router!(@build router $method $path => $h);
            )*
            router
        }
    };
    (@build $r:ident HEAD $path:literal => $h:expr) => {
        $r.head($path, $h);
    };
    (@build $r:ident GET $path:literal => $h:expr) => {
        $r.get($path, $h);
    };
    (@build $r:ident PUT $path:literal => $h:expr) => {
        $r.put($path, $h);
    };
    (@build $r:ident POST $path:literal => $h:expr) => {
        $r.post($path, $h);
    };
    (@build $r:ident PATCH $path:literal => $h:expr) => {
        $r.patch($path, $h);
    };
    (@build $r:ident DELETE $path:literal => $h:expr) => {
        $r.delete($path, $h);
    };
    (@build $r:ident OPTIONS $path:literal => $h:expr) => {
        $r.options($path, $h);
    };
    (@build $r:ident _ $path:literal => $h:expr) => {
        $r.any($path, $h);
    };
}
573
#[cfg(test)]
mod tests {
    use super::*;

    fn make_request(method: Method, path: &str) -> Request {
        Request::new(method, path)
    }

    fn echo_param(_req: Request, params: Params) -> Response {
        match params.get("x") {
            Some(path) => Response::new(200, path),
            None => responses::not_found(),
        }
    }

    #[test]
    fn test_method_not_allowed() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let req = make_request(Method::Post, "/foobar");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn test_not_found() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, ()))
        }

        let mut router = Router::default();
        router.get("/h1/:param", h1);

        let req = make_request(Method::Get, "/h1/");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_multi_param() {
        fn multiply(_req: Request, params: Params) -> anyhow::Result<Response> {
            let x: i64 = params.get("x").unwrap().parse()?;
            let y: i64 = params.get("y").unwrap().parse()?;
            Ok(Response::new(200, format!("{result}", result = x * y)))
        }

        let mut router = Router::default();
        router.get("/multiply/:x/:y", multiply);

        let req = make_request(Method::Get, "/multiply/2/4");
        let res = router.handle(req);

        assert_eq!(res.body, "8".to_owned().into_bytes());
    }

    #[test]
    fn test_param() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let req = make_request(Method::Get, "/y");
        let res = router.handle(req);

        assert_eq!(res.body, "y".to_owned().into_bytes());
    }

    #[test]
    fn test_wildcard() {
        fn echo_wildcard(_req: Request, params: Params) -> Response {
            match params.wildcard() {
                Some(path) => Response::new(200, path),
                None => responses::not_found(),
            }
        }

        let mut router = Router::default();
        router.get("/*", echo_wildcard);

        let req = make_request(Method::Get, "/foo/bar");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::OK);
        assert_eq!(res.body, "foo/bar".to_owned().into_bytes());
    }

    #[test]
    fn test_wildcard_last_segment() {
        let mut router = Router::default();
        router.get("/:x/*", echo_param);

        let req = make_request(Method::Get, "/foo/bar");
        let res = router.handle(req);
        assert_eq!(res.body, "foo".to_owned().into_bytes());
    }

    #[test]
    fn test_router_display() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let expected = "Registered routes:\n- GET: /:x\n";
        let actual = format!("{}", router);

        assert_eq!(actual.as_str(), expected);
    }

    #[test]
    fn test_ambiguous_wildcard_vs_star() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "one/two"))
        }

        fn h2(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "posts/*"))
        }

        let mut router = Router::default();
        router.get("/:one/:two", h1);
        router.get("/posts/*", h2);

        let req = make_request(Method::Get, "/posts/2");
        let res = router.handle(req);

        assert_eq!(res.body, "posts/*".to_owned().into_bytes());
    }

    // A HEAD request with no HEAD or any-method route is served by the GET route.
    #[test]
    fn test_head_falls_back_to_get() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let req = make_request(Method::Head, "/foo");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::OK);
        assert_eq!(res.body, "foo".to_owned().into_bytes());
    }

    // Routes registered with `any` match every method.
    #[test]
    fn test_any_method_route() {
        let mut router = Router::default();
        router.any("/:x", echo_param);

        for method in [Method::Get, Method::Post, Method::Put, Method::Delete] {
            let res = router.handle(make_request(method, "/bar"));
            assert_eq!(res.status, hyperium::StatusCode::OK);
            assert_eq!(res.body, "bar".to_owned().into_bytes());
        }
    }

    // A method-specific route wins over an `any` route for the same path.
    #[test]
    fn test_method_route_takes_precedence_over_any() {
        fn specific(_req: Request, _params: Params) -> Response {
            Response::new(200, "specific")
        }

        let mut router = Router::default();
        router.any("/:x", echo_param);
        router.post("/:x", specific);

        let res = router.handle(make_request(Method::Post, "/bar"));
        assert_eq!(res.body, "specific".to_owned().into_bytes());

        let res = router.handle(make_request(Method::Get, "/bar"));
        assert_eq!(res.body, "bar".to_owned().into_bytes());
    }
}