1use std::{collections::HashMap, convert::Infallible};
11
12use http::status::StatusCode;
13use motore::{ServiceExt, layer::Layer, service::Service};
14
15use super::{
16 Fallback, Route,
17 method_router::MethodRouter,
18 utils::{Matcher, NEST_CATCH_PARAM, RouteId, StripPrefixLayer},
19};
20use crate::{
21 body::Body,
22 context::ServerContext,
23 request::Request,
24 response::Response,
25 server::{IntoResponse, handler::Handler},
26};
27
/// Path-based request router.
///
/// A request path is looked up in `matcher`; the resulting route id selects an endpoint from
/// `routes`. Requests that do not resolve to an endpoint are passed to `fallback`.
#[must_use]
pub struct Router<B = Body, E = Infallible> {
    /// Matches request paths and yields the route id used as the key into `routes`.
    matcher: Matcher,
    /// Endpoints keyed by the route id that `matcher` returned when the path was registered.
    routes: HashMap<RouteId, Endpoint<B, E>>,
    /// Service invoked when no endpoint is found for the request path.
    fallback: Fallback<B, E>,
    /// Initialised to `true` by `Router::new`; `Router::merge` reads it to decide which of the
    /// two routers' fallbacks is kept.
    is_default_fallback: bool,
}
36
37impl<B, E> Default for Router<B, E>
38where
39 B: Send + 'static,
40 E: 'static,
41{
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl<B, E> Router<B, E>
48where
49 B: Send + 'static,
50 E: 'static,
51{
52 pub fn new() -> Self
54 where
55 E: 'static,
56 {
57 Self {
58 matcher: Default::default(),
59 routes: Default::default(),
60 fallback: Fallback::from_status_code(StatusCode::NOT_FOUND),
61 is_default_fallback: true,
62 }
63 }
64
65 pub fn route<S>(mut self, uri: S, method_router: MethodRouter<B, E>) -> Self
164 where
165 S: AsRef<str>,
166 {
167 let route_id = self
168 .matcher
169 .insert(uri.as_ref())
170 .expect("Insert routing rule failed");
171
172 self.routes
173 .insert(route_id, Endpoint::MethodRouter(method_router));
174
175 self
176 }
177
178 pub fn nest<U>(self, uri: U, router: Router<B, E>) -> Self
219 where
220 U: AsRef<str>,
221 {
222 self.nest_route(uri.as_ref().to_owned(), Route::new(router))
223 }
224
225 pub fn nest_service<U, S>(self, uri: U, service: S) -> Self
230 where
231 U: AsRef<str>,
232 S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'static,
233 S::Response: IntoResponse,
234 {
235 self.nest_route(
236 uri.as_ref().to_owned(),
237 Route::new(service.map_response(IntoResponse::into_response)),
238 )
239 }
240
241 fn nest_route(mut self, prefix: String, route: Route<B, E>) -> Self {
242 let uri = if prefix.ends_with('/') {
243 format!("{prefix}{NEST_CATCH_PARAM}")
244 } else {
245 format!("{prefix}/{NEST_CATCH_PARAM}")
246 };
247
248 let route_id = self
251 .matcher
252 .insert(prefix.clone())
253 .expect("Insert routing rule failed");
254
255 if !prefix.ends_with('/') {
259 let prefix_with_slash = prefix + "/";
260 self.matcher
261 .insert_with_id(prefix_with_slash, route_id)
262 .expect("Insert routing rule failed");
263 }
264
265 self.matcher
266 .insert_with_id(uri, route_id)
267 .expect("Insert routing rule failed");
268
269 self.routes.insert(
270 route_id,
271 Endpoint::Service(Route::new(StripPrefixLayer.layer(route))),
272 );
273
274 self
275 }
276
277 pub fn fallback<H, T>(mut self, handler: H) -> Self
283 where
284 for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
285 T: 'static,
286 E: 'static,
287 {
288 self.fallback = Fallback::from_handler(handler);
289 self
290 }
291
292 pub fn fallback_service<S>(mut self, service: S) -> Self
298 where
299 for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
300 S::Response: IntoResponse,
301 {
302 self.fallback = Fallback::from_service(service);
303 self
304 }
305
306 pub fn merge(mut self, other: Self) -> Self {
345 let Router {
346 mut matcher,
347 mut routes,
348 fallback,
349 is_default_fallback,
350 } = other;
351
352 for (path, route_id) in matcher.drain() {
353 self.matcher
354 .insert_with_id(path, route_id)
355 .expect("Insert routing rule failed during merging router");
356 }
357 for (route_id, method_router) in routes.drain() {
358 if self.routes.insert(route_id, method_router).is_some() {
359 unreachable!()
360 }
361 }
362
363 match (self.is_default_fallback, is_default_fallback) {
364 (_, true) => {}
365 (true, false) => {
366 self.fallback = fallback;
367 self.is_default_fallback = false;
368 }
369 (false, false) => {
370 panic!("Merge `Router` failed because both `Router` have customized `fallback`")
371 }
372 }
373
374 self
375 }
376
377 pub fn layer<L, B2, E2>(self, l: L) -> Router<B2, E2>
381 where
382 L: Layer<Route<B, E>> + Clone + Send + Sync + 'static,
383 L::Service: Service<ServerContext, Request<B2>, Error = E2> + Send + Sync + 'static,
384 <L::Service as Service<ServerContext, Request<B2>>>::Response: IntoResponse,
385 B2: 'static,
386 {
387 let routes = self
388 .routes
389 .into_iter()
390 .map(|(id, route)| {
391 let route = route.layer(l.clone());
392 (id, route)
393 })
394 .collect();
395
396 let fallback = self.fallback.layer(l.clone());
397
398 Router {
399 matcher: self.matcher,
400 routes,
401 fallback,
402 is_default_fallback: self.is_default_fallback,
403 }
404 }
405}
406
407impl<B, E> Service<ServerContext, Request<B>> for Router<B, E>
408where
409 B: Send + 'static,
410 E: 'static,
411{
412 type Response = Response;
413 type Error = E;
414
415 async fn call(
416 &self,
417 cx: &mut ServerContext,
418 req: Request<B>,
419 ) -> Result<Self::Response, Self::Error> {
420 if let Ok(matched) = self.matcher.at(req.uri().clone().path()) {
421 if let Some(route) = self.routes.get(matched.value) {
422 if !matched.params.is_empty() {
423 cx.params_mut().extend(matched.params);
424 }
425 return route.call(cx, req).await;
426 }
427 }
428
429 self.fallback.call(cx, req).await
430 }
431}
432
/// Target stored in a `Router` for one route id.
// Silences clippy's `large_enum_variant` lint for this enum.
// NOTE(review): presumably the size difference between the variants was judged acceptable
// compared to boxing one of them — confirm.
#[allow(clippy::large_enum_variant)]
enum Endpoint<B = Body, E = Infallible> {
    /// Endpoint registered through `Router::route`.
    MethodRouter(MethodRouter<B, E>),
    /// Endpoint holding a `Route`, as stored by `Router::nest_route` and by layering.
    Service(Route<B, E>),
}
438
439impl<B, E> Service<ServerContext, Request<B>> for Endpoint<B, E>
440where
441 B: Send + 'static,
442 E: 'static,
443{
444 type Response = Response;
445 type Error = E;
446
447 async fn call(
448 &self,
449 cx: &mut ServerContext,
450 req: Request<B>,
451 ) -> Result<Self::Response, Self::Error> {
452 match self {
453 Self::MethodRouter(mr) => mr.call(cx, req).await,
454 Self::Service(service) => service.call(cx, req).await,
455 }
456 }
457}
458
459impl<B, E> Default for Endpoint<B, E>
460where
461 B: Send + 'static,
462 E: 'static,
463{
464 fn default() -> Self {
465 Self::MethodRouter(Default::default())
466 }
467}
468
469impl<B, E> Endpoint<B, E>
470where
471 B: Send + 'static,
472 E: 'static,
473{
474 fn layer<L, B2, E2>(self, l: L) -> Endpoint<B2, E2>
475 where
476 L: Layer<Route<B, E>> + Clone + Send + Sync + 'static,
477 L::Service: Service<ServerContext, Request<B2>, Error = E2> + Send + Sync,
478 <L::Service as Service<ServerContext, Request<B2>>>::Response: IntoResponse,
479 B2: 'static,
480 {
481 match self {
482 Self::MethodRouter(mr) => Endpoint::MethodRouter(mr.layer(l)),
483 Self::Service(s) => Endpoint::Service(Route::new(
484 l.layer(s).map_response(IntoResponse::into_response),
485 )),
486 }
487 }
488}
489
#[cfg(test)]
mod router_tests {
    use faststr::FastStr;
    use http::{method::Method, status::StatusCode, uri::Uri};

    use super::Router;
    use crate::{
        body::{Body, BodyConversion},
        server::{
            Server, param::PathParamsVec, route::method_router::any, test_helpers::TestServer,
        },
    };

    /// Router flavour exercised by every test in this module.
    type TestRouter = Router<Option<Body>>;
    /// Test server wrapping a [`TestRouter`].
    type RouterServer = TestServer<TestRouter, Option<Body>>;

    /// Handler returning `()`; the matching tests expect it to yield `200 OK`.
    async fn always_ok() {}

    /// Handler used as a recognisable fallback.
    async fn teapot() -> StatusCode {
        StatusCode::IM_A_TEAPOT
    }

    /// Handler echoing the URI it received followed by every path parameter value, one per
    /// line.
    async fn uri_and_params(uri: Uri, params: PathParamsVec) -> String {
        let mut lines = vec![FastStr::from_string(uri.to_string())];
        lines.extend(params.into_iter().map(|(_, value)| value));
        lines.join("\n")
    }

    /// Handler echoing the query string, or `400 Bad Request` when the URI has none.
    async fn get_query(uri: Uri) -> Result<String, StatusCode> {
        uri.query()
            .map(str::to_owned)
            .ok_or(StatusCode::BAD_REQUEST)
    }

    /// Leaf router shared by the nesting tests; every route answers with [`uri_and_params`].
    fn echo_router() -> TestRouter {
        Router::new()
            .route("/", any(uri_and_params))
            .route("/id/{id}", any(uri_and_params))
            .route("/catch/{*content}", any(uri_and_params))
    }

    /// Sends `GET uri` and returns the response status.
    async fn status_of(server: &RouterServer, uri: &str) -> StatusCode {
        server.call_route(Method::GET, uri, None).await.status()
    }

    /// Sends `GET uri` and returns the response body as a string, whatever the status.
    async fn body_of(server: &RouterServer, uri: &str) -> String {
        server
            .call_route(Method::GET, uri, None)
            .await
            .into_string()
            .await
            .unwrap()
    }

    /// Sends `GET uri`; returns the body for a successful status and the status otherwise.
    async fn body_or_status(server: &RouterServer, uri: &str) -> Result<String, StatusCode> {
        let resp = server.call_route(Method::GET, uri, None).await;
        if resp.status().is_success() {
            Ok(resp
                .into_string()
                .await
                .expect("response is not a valid string"))
        } else {
            Err(resp.status())
        }
    }

    /// Checks that the query string reaches the handler mounted at `path` unchanged, and that
    /// a request without any query is rejected by the handler.
    async fn assert_query_is_forwarded(server: &RouterServer, path: &str) {
        for query in ["foo=bar", "foo", ""] {
            let uri = format!("{path}?{query}");
            assert_eq!(
                body_or_status(server, &uri).await.unwrap(),
                query,
                "unexpected query seen for `{uri}`",
            );
        }
        assert!(body_or_status(server, path).await.is_err());
    }

    #[tokio::test]
    async fn url_match() {
        let router: TestRouter = Router::new()
            .route("/", any(always_ok))
            .route("/catch/{id}", any(always_ok))
            .route("/catch/{id}/another", any(always_ok))
            .route("/catch/{id}/another/{uid}", any(always_ok))
            .route("/catch/{id}/another/{uid}/again", any(always_ok))
            .route("/catch/{id}/another/{uid}/again/{tid}", any(always_ok))
            .route("/catch_all/{*all}", any(always_ok));
        let server = Server::new(router).into_test_server();

        for uri in [
            "/",
            "/catch/114",
            "/catch/514",
            "/catch/ll45l4",
            "/catch/ll45l4/another",
            "/catch/ll45l4/another/1919",
            "/catch/ll45l4/another/1919/again",
            "/catch/ll45l4/another/1919/again/810",
            "/catch_all/114",
            "/catch_all/114/514/1919/810",
        ] {
            assert_eq!(
                status_of(&server, uri).await,
                StatusCode::OK,
                "`{uri}` should be routed",
            );
        }

        for uri in [
            "/catch",
            "/catch/114/",
            "/catch/114/another/514/",
            "/catch/11/another/45/again/14/",
            "/catch_all",
            "/catch_all/",
        ] {
            assert_ne!(
                status_of(&server, uri).await,
                StatusCode::OK,
                "`{uri}` should not be routed",
            );
        }
    }

    #[tokio::test]
    async fn router_fallback() {
        let router: TestRouter = Router::new()
            .route("/", any(always_ok))
            .route("/catch/{id}", any(always_ok))
            .route("/catch_all/{*all}", any(always_ok))
            .fallback(teapot);
        let server = Server::new(router).into_test_server();

        for uri in ["//", "/catch/", "/catch_all/"] {
            assert_eq!(
                status_of(&server, uri).await,
                StatusCode::IM_A_TEAPOT,
                "`{uri}` should reach the fallback",
            );
        }

        for uri in [
            "/catch/114",
            "/catch_all/514",
            "/catch_all/114/514/1919/810",
        ] {
            assert_ne!(
                status_of(&server, uri).await,
                StatusCode::IM_A_TEAPOT,
                "`{uri}` should not reach the fallback",
            );
        }
    }

    #[tokio::test]
    async fn nest_router() {
        let router: TestRouter = Router::new()
            .nest("/test-1", echo_router())
            .nest("/test-2/", echo_router())
            .nest("/test-3/{catch}", echo_router())
            .nest("/test-4/{catch}/", echo_router());
        let server = Server::new(router).into_test_server();

        // An empty expected body means the request was not routed to the nested router.
        for (uri, expected) in [
            // Prefix without trailing slash: both `/test-1` and `/test-1/` reach the root.
            ("/test-1", "/"),
            ("/test-1/", "/"),
            ("/test-1/id/114", "/id/114\n114"),
            (
                "/test-1/catch/114/514/1919/810",
                "/catch/114/514/1919/810\n114/514/1919/810",
            ),
            // Prefix with trailing slash: the slash-less form is not routed.
            ("/test-2", ""),
            ("/test-2/", "/"),
            ("/test-2/id/114", "/id/114\n114"),
            (
                "/test-2/catch/114/514/1919/810",
                "/catch/114/514/1919/810\n114/514/1919/810",
            ),
            // Prefix ending in a path parameter.
            ("/test-3/114", "/\n114"),
            ("/test-3/114/", "/\n114"),
            ("/test-3/114/id/514", "/id/514\n114\n514"),
            (
                "/test-3/114/catch/514/1919/810",
                "/catch/514/1919/810\n114\n514/1919/810",
            ),
            // Prefix ending in a path parameter followed by a slash.
            ("/test-4/114", ""),
            ("/test-4/114/", "/\n114"),
            ("/test-4/114/id/514", "/id/514\n114\n514"),
            (
                "/test-4/114/catch/514/1919/810",
                "/catch/514/1919/810\n114\n514/1919/810",
            ),
        ] {
            assert_eq!(
                body_of(&server, uri).await,
                expected,
                "unexpected response for `{uri}`",
            );
        }
    }

    #[tokio::test]
    async fn deep_nest_router() {
        let router: TestRouter = Router::new().nest(
            "/test-1/{catch1}",
            Router::new().nest(
                "/test-2/{catch2}/",
                Router::new().nest("/test-3", echo_router()),
            ),
        );
        let server = Server::new(router).into_test_server();

        for (uri, expected) in [
            ("/test-1/114/test-2/514/test-3/", "/\n114\n514"),
            (
                "/test-1/114/test-2/514/test-3/id/1919",
                "/id/1919\n114\n514\n1919",
            ),
            (
                "/test-1/114/test-2/514/test-3/catch/1919/810",
                "/catch/1919/810\n114\n514\n1919/810",
            ),
        ] {
            assert_eq!(
                body_of(&server, uri).await,
                expected,
                "unexpected response for `{uri}`",
            );
        }
    }

    #[tokio::test]
    async fn nest_router_with_query() {
        let router: TestRouter =
            Router::new().nest("/nest", Router::new().route("/query", any(get_query)));
        let server = Server::new(router).into_test_server();

        assert_query_is_forwarded(&server, "/nest/query").await;
    }

    #[tokio::test]
    async fn deep_nest_router_with_query() {
        let router: TestRouter = Router::new().nest(
            "/nest-1",
            Router::new().nest(
                "/nest-2",
                Router::new().nest("/nest-3", Router::new().route("/query", any(get_query))),
            ),
        );
        let server = Server::new(router).into_test_server();

        assert_query_is_forwarded(&server, "/nest-1/nest-2/nest-3/query").await;
    }
}