1use super::conversions::{IntoResponse, TryFromRequest, TryIntoRequest};
4use super::{responses, Method, Request, Response};
5use async_trait::async_trait;
6use routefinder::{Captures, Router as MethodRouter};
7use std::future::Future;
8use std::{collections::HashMap, fmt::Display};
9
10#[async_trait(?Send)]
15pub trait Handler {
16 async fn handle(&self, req: Request, params: Params) -> Response;
18}
19
20#[async_trait(?Send)]
21impl Handler for Box<dyn Handler> {
22 async fn handle(&self, req: Request, params: Params) -> Response {
23 self.as_ref().handle(req, params).await
24 }
25}
26
27#[async_trait(?Send)]
28impl<F, Fut> Handler for F
29where
30 F: Fn(Request, Params) -> Fut + 'static,
31 Fut: Future<Output = Response> + 'static,
32{
33 async fn handle(&self, req: Request, params: Params) -> Response {
34 let fut = (self)(req, params);
35 fut.await
36 }
37}
38
39pub type Params = Captures<'static, 'static>;
41
42pub struct Router {
181 methods_map: HashMap<Method, MethodRouter<Box<dyn Handler>>>,
182 any_methods: MethodRouter<Box<dyn Handler>>,
183 route_on: RouteOn,
184}
185
/// Selects the string a [`Router`] matches its routes against.
enum RouteOn {
    /// Match against the path of the request.
    FullPath,
    /// Match against the value of the request's `spin-path-info` header.
    Suffix,
}
194
195impl Default for Router {
196 fn default() -> Router {
197 Router::new()
198 }
199}
200
201impl Display for Router {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 writeln!(f, "Registered routes:")?;
204 for (method, router) in &self.methods_map {
205 for route in router.iter() {
206 writeln!(f, "- {}: {}", method, route.0)?;
207 }
208 }
209 Ok(())
210 }
211}
212
213struct RouteMatch<'a> {
214 params: Captures<'static, 'static>,
215 handler: &'a dyn Handler,
216}
217
218impl Router {
219 pub fn handle<R>(&self, request: R) -> Response
221 where
222 R: TryIntoRequest,
223 R::Error: IntoResponse,
224 {
225 crate::http::executor::run(self.handle_async(request))
226 }
227
228 pub async fn handle_async<R>(&self, request: R) -> Response
230 where
231 R: TryIntoRequest,
232 R::Error: IntoResponse,
233 {
234 let request = match R::try_into_request(request) {
235 Ok(r) => r,
236 Err(e) => return e.into_response(),
237 };
238 let method = request.method.clone();
239 let path = match self.route_on {
240 RouteOn::FullPath => request.path(),
241 RouteOn::Suffix => match trailing_suffix(&request) {
242 Some(path) => path,
243 None => {
244 eprintln!("Internal error: Router configured with suffix routing but trigger route has no trailing wildcard");
245 return responses::internal_server_error();
246 }
247 },
248 };
249 let RouteMatch { params, handler } = self.find(path, method);
250 handler.handle(request, params).await
251 }
252
253 fn find(&self, path: &str, method: Method) -> RouteMatch<'_> {
254 let best_match = self
255 .methods_map
256 .get(&method)
257 .and_then(|r| r.best_match(path));
258
259 if let Some(m) = best_match {
260 let params = m.captures().into_owned();
261 let handler = m.handler();
262 return RouteMatch { handler, params };
263 }
264
265 let best_match = self.any_methods.best_match(path);
266
267 match best_match {
268 Some(m) => {
269 let params = m.captures().into_owned();
270 let handler = m.handler();
271 RouteMatch { handler, params }
272 }
273 None if method == Method::Head => {
274 self.find(path, Method::Get)
277 }
278 None => {
279 self.fail(path, method)
281 }
282 }
283 }
284
285 fn fail(&self, path: &str, method: Method) -> RouteMatch<'_> {
287 let is_method_not_allowed = self
289 .methods_map
290 .iter()
291 .filter(|(k, _)| **k != method)
292 .any(|(_, r)| r.best_match(path).is_some());
293
294 if is_method_not_allowed {
295 RouteMatch {
298 handler: &method_not_allowed,
299 params: Captures::default(),
300 }
301 } else {
302 RouteMatch {
304 handler: ¬_found,
305 params: Captures::default(),
306 }
307 }
308 }
309
310 pub fn any<F, Req, Resp>(&mut self, path: &str, handler: F)
312 where
313 F: Fn(Req, Params) -> Resp + 'static,
314 Req: TryFromRequest + 'static,
315 Req::Error: IntoResponse + 'static,
316 Resp: IntoResponse + 'static,
317 {
318 let handler = move |req, params| {
319 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
320 async move {
321 match res {
322 Ok(res) => res.into_response(),
323 Err(e) => e.into_response(),
324 }
325 }
326 };
327
328 self.any_async(path, handler)
329 }
330
331 pub fn any_async<F, Fut, I, O>(&mut self, path: &str, handler: F)
333 where
334 F: Fn(I, Params) -> Fut + 'static,
335 Fut: Future<Output = O> + 'static,
336 I: TryFromRequest + 'static,
337 I::Error: IntoResponse + 'static,
338 O: IntoResponse + 'static,
339 {
340 let handler = move |req, params| {
341 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
342 async move {
343 match res {
344 Ok(f) => f.await.into_response(),
345 Err(e) => e.into_response(),
346 }
347 }
348 };
349
350 self.any_methods.add(path, Box::new(handler)).unwrap();
351 }
352
353 pub fn add<F, Req, Resp>(&mut self, path: &str, method: Method, handler: F)
355 where
356 F: Fn(Req, Params) -> Resp + 'static,
357 Req: TryFromRequest + 'static,
358 Req::Error: IntoResponse + 'static,
359 Resp: IntoResponse + 'static,
360 {
361 let handler = move |req, params| {
362 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
363 async move {
364 match res {
365 Ok(res) => res.into_response(),
366 Err(e) => e.into_response(),
367 }
368 }
369 };
370
371 self.add_async(path, method, handler)
372 }
373
374 pub fn add_async<F, Fut, I, O>(&mut self, path: &str, method: Method, handler: F)
376 where
377 F: Fn(I, Params) -> Fut + 'static,
378 Fut: Future<Output = O> + 'static,
379 I: TryFromRequest + 'static,
380 I::Error: IntoResponse + 'static,
381 O: IntoResponse + 'static,
382 {
383 let handler = move |req, params| {
384 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
385 async move {
386 match res {
387 Ok(f) => f.await.into_response(),
388 Err(e) => e.into_response(),
389 }
390 }
391 };
392
393 self.methods_map
394 .entry(method)
395 .or_default()
396 .add(path, Box::new(handler))
397 .unwrap();
398 }
399
400 pub fn get<F, Req, Resp>(&mut self, path: &str, handler: F)
402 where
403 F: Fn(Req, Params) -> Resp + 'static,
404 Req: TryFromRequest + 'static,
405 Req::Error: IntoResponse + 'static,
406 Resp: IntoResponse + 'static,
407 {
408 self.add(path, Method::Get, handler)
409 }
410
411 pub fn get_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
413 where
414 F: Fn(Req, Params) -> Fut + 'static,
415 Fut: Future<Output = Resp> + 'static,
416 Req: TryFromRequest + 'static,
417 Req::Error: IntoResponse + 'static,
418 Resp: IntoResponse + 'static,
419 {
420 self.add_async(path, Method::Get, handler)
421 }
422
423 pub fn head<F, Req, Resp>(&mut self, path: &str, handler: F)
425 where
426 F: Fn(Req, Params) -> Resp + 'static,
427 Req: TryFromRequest + 'static,
428 Req::Error: IntoResponse + 'static,
429 Resp: IntoResponse + 'static,
430 {
431 self.add(path, Method::Head, handler)
432 }
433
434 pub fn head_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
436 where
437 F: Fn(Req, Params) -> Fut + 'static,
438 Fut: Future<Output = Resp> + 'static,
439 Req: TryFromRequest + 'static,
440 Req::Error: IntoResponse + 'static,
441 Resp: IntoResponse + 'static,
442 {
443 self.add_async(path, Method::Head, handler)
444 }
445
446 pub fn post<F, Req, Resp>(&mut self, path: &str, handler: F)
448 where
449 F: Fn(Req, Params) -> Resp + 'static,
450 Req: TryFromRequest + 'static,
451 Req::Error: IntoResponse + 'static,
452 Resp: IntoResponse + 'static,
453 {
454 self.add(path, Method::Post, handler)
455 }
456
457 pub fn post_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
459 where
460 F: Fn(Req, Params) -> Fut + 'static,
461 Fut: Future<Output = Resp> + 'static,
462 Req: TryFromRequest + 'static,
463 Req::Error: IntoResponse + 'static,
464 Resp: IntoResponse + 'static,
465 {
466 self.add_async(path, Method::Post, handler)
467 }
468
469 pub fn delete<F, Req, Resp>(&mut self, path: &str, handler: F)
471 where
472 F: Fn(Req, Params) -> Resp + 'static,
473 Req: TryFromRequest + 'static,
474 Req::Error: IntoResponse + 'static,
475 Resp: IntoResponse + 'static,
476 {
477 self.add(path, Method::Delete, handler)
478 }
479
480 pub fn delete_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
482 where
483 F: Fn(Req, Params) -> Fut + 'static,
484 Fut: Future<Output = Resp> + 'static,
485 Req: TryFromRequest + 'static,
486 Req::Error: IntoResponse + 'static,
487 Resp: IntoResponse + 'static,
488 {
489 self.add_async(path, Method::Delete, handler)
490 }
491
492 pub fn put<F, Req, Resp>(&mut self, path: &str, handler: F)
494 where
495 F: Fn(Req, Params) -> Resp + 'static,
496 Req: TryFromRequest + 'static,
497 Req::Error: IntoResponse + 'static,
498 Resp: IntoResponse + 'static,
499 {
500 self.add(path, Method::Put, handler)
501 }
502
503 pub fn put_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
505 where
506 F: Fn(Req, Params) -> Fut + 'static,
507 Fut: Future<Output = Resp> + 'static,
508 Req: TryFromRequest + 'static,
509 Req::Error: IntoResponse + 'static,
510 Resp: IntoResponse + 'static,
511 {
512 self.add_async(path, Method::Put, handler)
513 }
514
515 pub fn patch<F, Req, Resp>(&mut self, path: &str, handler: F)
517 where
518 F: Fn(Req, Params) -> Resp + 'static,
519 Req: TryFromRequest + 'static,
520 Req::Error: IntoResponse + 'static,
521 Resp: IntoResponse + 'static,
522 {
523 self.add(path, Method::Patch, handler)
524 }
525
526 pub fn patch_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
528 where
529 F: Fn(Req, Params) -> Fut + 'static,
530 Fut: Future<Output = Resp> + 'static,
531 Req: TryFromRequest + 'static,
532 Req::Error: IntoResponse + 'static,
533 Resp: IntoResponse + 'static,
534 {
535 self.add_async(path, Method::Patch, handler)
536 }
537
538 pub fn options<F, Req, Resp>(&mut self, path: &str, handler: F)
540 where
541 F: Fn(Req, Params) -> Resp + 'static,
542 Req: TryFromRequest + 'static,
543 Req::Error: IntoResponse + 'static,
544 Resp: IntoResponse + 'static,
545 {
546 self.add(path, Method::Options, handler)
547 }
548
549 pub fn options_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
551 where
552 F: Fn(Req, Params) -> Fut + 'static,
553 Fut: Future<Output = Resp> + 'static,
554 Req: TryFromRequest + 'static,
555 Req::Error: IntoResponse + 'static,
556 Resp: IntoResponse + 'static,
557 {
558 self.add_async(path, Method::Options, handler)
559 }
560
561 pub fn new() -> Self {
563 Router {
564 methods_map: HashMap::default(),
565 any_methods: MethodRouter::new(),
566 route_on: RouteOn::FullPath,
567 }
568 }
569
570 pub fn suffix() -> Self {
573 Router {
574 methods_map: HashMap::default(),
575 any_methods: MethodRouter::new(),
576 route_on: RouteOn::Suffix,
577 }
578 }
579}
580
581async fn not_found(_req: Request, _params: Params) -> Response {
582 responses::not_found()
583}
584
585async fn method_not_allowed(_req: Request, _params: Params) -> Response {
586 responses::method_not_allowed()
587}
588
589fn trailing_suffix(req: &Request) -> Option<&str> {
590 req.header("spin-path-info")
591 .and_then(|path_info| path_info.as_str())
592}
593
/// Builds a `$crate::http::Router` from a comma-separated list of
/// `METHOD "path" => handler` entries and evaluates to the populated router.
///
/// `METHOD` is one of `HEAD`, `GET`, `PUT`, `POST`, `PATCH`, `DELETE`, `OPTIONS`, or
/// `_`; each entry expands to a call of the corresponding router method (`_` maps to
/// `any`). A trailing comma after the last entry is accepted.
#[macro_export]
macro_rules! http_router {
    // Public entry point: create the router, then register every entry in order.
    ($($method:tt $path:literal => $handler:expr),* $(,)?) => {
        {
            let mut router = $crate::http::Router::new();
            $(
                $crate::http_router!(@build router $method $path => $handler);
            )*
            router
        }
    };
    // Internal rules: one per supported method token.
    (@build $router:ident HEAD $path:literal => $handler:expr) => {
        $router.head($path, $handler);
    };
    (@build $router:ident GET $path:literal => $handler:expr) => {
        $router.get($path, $handler);
    };
    (@build $router:ident PUT $path:literal => $handler:expr) => {
        $router.put($path, $handler);
    };
    (@build $router:ident POST $path:literal => $handler:expr) => {
        $router.post($path, $handler);
    };
    (@build $router:ident PATCH $path:literal => $handler:expr) => {
        $router.patch($path, $handler);
    };
    (@build $router:ident DELETE $path:literal => $handler:expr) => {
        $router.delete($path, $handler);
    };
    (@build $router:ident OPTIONS $path:literal => $handler:expr) => {
        $router.options($path, $handler);
    };
    (@build $router:ident _ $path:literal => $handler:expr) => {
        $router.any($path, $handler);
    };
}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with the given method and path.
    fn make_request(method: Method, path: &str) -> Request {
        Request::new(method, path)
    }

    /// Builds a request whose `spin-path-info` header is set to `trailing`.
    fn make_wildcard_request(method: Method, path: &str, trailing: &str) -> Request {
        let mut request = make_request(method, path);
        request.set_header("spin-path-info", trailing);
        request
    }

    /// Responds 200 with the captured `x` parameter, or with the not-found response
    /// when `x` was not captured.
    fn echo_param(_req: Request, params: Params) -> Response {
        if let Some(value) = params.get("x") {
            Response::new(200, value)
        } else {
            responses::not_found()
        }
    }

    /// A POST to a path that is only registered for GET is reported as method-not-allowed.
    #[test]
    fn test_method_not_allowed() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let response = router.handle(make_request(Method::Post, "/foobar"));
        assert_eq!(response.status, hyperium::StatusCode::METHOD_NOT_ALLOWED);
    }

    /// Expects a not-found status for `/h1/` when only `/h1/:param` is registered.
    #[test]
    fn test_not_found() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, ()))
        }

        let mut router = Router::default();
        router.get("/h1/:param", h1);

        let response = router.handle(make_request(Method::Get, "/h1/"));
        assert_eq!(response.status, hyperium::StatusCode::NOT_FOUND);
    }

    /// Two named parameters in one route are both captured.
    #[test]
    fn test_multi_param() {
        fn multiply(_req: Request, params: Params) -> anyhow::Result<Response> {
            let x: i64 = params.get("x").unwrap().parse()?;
            let y: i64 = params.get("y").unwrap().parse()?;
            Ok(Response::new(200, (x * y).to_string()))
        }

        let mut router = Router::default();
        router.get("/multiply/:x/:y", multiply);

        let response = router.handle(make_request(Method::Get, "/multiply/2/4"));

        assert_eq!(response.body, b"8".to_vec());
    }

    /// A single named parameter is captured and echoed back.
    #[test]
    fn test_param() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let response = router.handle(make_request(Method::Get, "/y"));

        assert_eq!(response.body, b"y".to_vec());
    }

    /// A `/*` route exposes the remainder of the path through `Params::wildcard`.
    #[test]
    fn test_wildcard() {
        fn echo_wildcard(_req: Request, params: Params) -> Response {
            if let Some(rest) = params.wildcard() {
                Response::new(200, rest)
            } else {
                responses::not_found()
            }
        }

        let mut router = Router::default();
        router.get("/*", echo_wildcard);

        let response = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(response.status, hyperium::StatusCode::OK);
        assert_eq!(response.body, b"foo/bar".to_vec());
    }

    /// A named parameter followed by a trailing wildcard still captures the parameter.
    #[test]
    fn test_wildcard_last_segment() {
        let mut router = Router::default();
        router.get("/:x/*", echo_param);

        let response = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(response.body, b"foo".to_vec());
    }

    /// A suffix router matches against the `spin-path-info` header, not the full path.
    #[test]
    fn test_spin_trailing_wildcard() {
        let mut router = Router::suffix();
        router.get("/:x/*", echo_param);

        let request = make_wildcard_request(Method::Get, "/base/baser/foo/bar", "/foo/bar");
        let response = router.handle(request);
        assert_eq!(response.body, b"foo".to_vec());
    }

    /// The `Display` output lists the registered method-specific route.
    #[test]
    fn test_router_display() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let expected = "Registered routes:\n- GET: /:x\n";
        let rendered = router.to_string();

        assert_eq!(rendered.as_str(), expected);
    }

    /// Expects `/posts/*` to win over `/:one/:two` for the path `/posts/2`.
    #[test]
    fn test_ambiguous_wildcard_vs_star() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "one/two"))
        }

        fn h2(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "posts/*"))
        }

        let mut router = Router::default();
        router.get("/:one/:two", h1);
        router.get("/posts/*", h2);

        let response = router.handle(make_request(Method::Get, "/posts/2"));

        assert_eq!(response.body, b"posts/*".to_vec());
    }
}