// spin_sdk/http/router.rs
//
// This router implementation is heavily inspired by the `Endpoint` type in the https://github.com/http-rs/tide project.
2
3use super::conversions::{IntoResponse, TryFromRequest, TryIntoRequest};
4use super::{responses, Method, Request, Response};
5use async_trait::async_trait;
6use routefinder::{Captures, Router as MethodRouter};
7use std::future::Future;
8use std::{collections::HashMap, fmt::Display};
9
10/// An HTTP request handler.
11///  
12/// This trait is automatically implemented for `Fn` types, and so is rarely implemented
13/// directly by Spin users.
14#[async_trait(?Send)]
15pub trait Handler {
16    /// Invoke the handler.
17    async fn handle(&self, req: Request, params: Params) -> Response;
18}
19
20#[async_trait(?Send)]
21impl Handler for Box<dyn Handler> {
22    async fn handle(&self, req: Request, params: Params) -> Response {
23        self.as_ref().handle(req, params).await
24    }
25}
26
27#[async_trait(?Send)]
28impl<F, Fut> Handler for F
29where
30    F: Fn(Request, Params) -> Fut + 'static,
31    Fut: Future<Output = Response> + 'static,
32{
33    async fn handle(&self, req: Request, params: Params) -> Response {
34        let fut = (self)(req, params);
35        fut.await
36    }
37}
38
39/// Route parameters extracted from a URI that match a route pattern.
40pub type Params = Captures<'static, 'static>;
41
42/// Routes HTTP requests within a Spin component.
43///
44/// Routes may contain wildcards:
45///
46/// * `:name` is a single segment wildcard. The handler can retrieve it using
47///   [Params::get()].
48/// * `*` is a trailing wildcard (matches anything). The handler can retrieve it
49///   using [Params::wildcard()].
50///
51/// If a request matches more than one route, the match is selected according to the follow criteria:
52///
53/// * An exact route takes priority over any wildcard.
54/// * A single segment wildcard takes priority over a trailing wildcard.
55///
56/// (This is the same logic as overlapping routes in the Spin manifest.)
57///
58/// # Examples
59///
60/// Handle GET requests to a path with a wildcard, falling back to "not found":
61///
62/// ```no_run
63/// # use spin_sdk::http::{IntoResponse, Params, Request, Response, Router};
64/// fn handle_route(req: Request) -> Response {
65///     let mut router = Router::new();
66///     router.get("/hello/:planet", hello_planet);
67///     router.any("/*", not_found);
68///     router.handle(req)
69/// }
70///
71/// fn hello_planet(req: Request, params: Params) -> anyhow::Result<Response> {
72///     let planet = params.get("planet").unwrap_or("world");
73///     Ok(Response::new(200, format!("hello, {planet}")))
74/// }
75///
76/// fn not_found(req: Request, params: Params) -> anyhow::Result<Response> {
77///     Ok(Response::new(404, "not found"))
78/// }
79/// ```
80///
81/// Handle requests using a mix of synchronous and asynchronous handlers:
82///
83/// ```no_run
84/// # use spin_sdk::http::{IntoResponse, Params, Request, Response, Router};
85/// fn handle_route(req: Request) -> Response {
86///     let mut router = Router::new();
87///     router.get("/hello/:planet", hello_planet);
88///     router.get_async("/goodbye/:planet", goodbye_planet);
89///     router.handle(req)
90/// }
91///
92/// fn hello_planet(req: Request, params: Params) -> anyhow::Result<Response> {
93///     todo!()
94/// }
95///
96/// async fn goodbye_planet(req: Request, params: Params) -> anyhow::Result<Response> {
97///     todo!()
98/// }
99/// ```
100///
101/// Route differently according to HTTP method:
102///
103/// ```no_run
104/// # use spin_sdk::http::{IntoResponse, Params, Request, Response, Router};
105/// fn handle_route(req: Request) -> Response {
106///     let mut router = Router::new();
107///     router.get("/user", list_users);
108///     router.post("/user", create_user);
109///     router.get("/user/:id", get_user);
110///     router.put("/user/:id", update_user);
111///     router.delete("/user/:id", delete_user);
112///     router.any("/user", method_not_allowed);
113///     router.any("/user/:id", method_not_allowed);
114///     router.handle(req)
115/// }
116/// # fn list_users(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
117/// # fn create_user(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
118/// # fn get_user(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
119/// # fn update_user(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
120/// # fn delete_user(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
121/// # fn method_not_allowed(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
122/// ```
123///
124/// Run the handler asynchronously:
125///
126/// ```no_run
127/// # use spin_sdk::http::{IntoResponse, Params, Request, Response, Router};
128/// async fn handle_route(req: Request) -> Response {
129///     let mut router = Router::new();
130///     router.get_async("/user", list_users);
131///     router.post_async("/user", create_user);
132///     router.handle_async(req).await
133/// }
134/// # async fn list_users(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
135/// # async fn create_user(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
136/// ```
137///
138/// Priority when routes overlap:
139///
140/// ```no_run
141/// # use spin_sdk::http::{IntoResponse, Params, Request, Response, Router};
142/// fn handle_route(req: Request) -> Response {
143///     let mut router = Router::new();
144///     router.any("/*", handle_any);
145///     router.any("/:seg", handle_single_segment);
146///     router.any("/fie", handle_exact);
147///
148///     // '/fie' is routed to `handle_exact`
149///     // '/zounds' is routed to `handle_single_segment`
150///     // '/zounds/fie' is routed to `handle_any`
151///     router.handle(req)
152/// }
153/// # fn handle_any(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
154/// # fn handle_single_segment(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
155/// # fn handle_exact(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
156/// ```
157pub struct Router {
158    methods_map: HashMap<Method, MethodRouter<Box<dyn Handler>>>,
159    any_methods: MethodRouter<Box<dyn Handler>>,
160}
161
162impl Default for Router {
163    fn default() -> Router {
164        Router::new()
165    }
166}
167
168impl Display for Router {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        writeln!(f, "Registered routes:")?;
171        for (method, router) in &self.methods_map {
172            for route in router.iter() {
173                writeln!(f, "- {}: {}", method, route.0)?;
174            }
175        }
176        Ok(())
177    }
178}
179
180struct RouteMatch<'a> {
181    params: Captures<'static, 'static>,
182    handler: &'a dyn Handler,
183}
184
185impl Router {
186    /// Synchronously dispatches a request to the appropriate handler along with the URI parameters.
187    pub fn handle<R>(&self, request: R) -> Response
188    where
189        R: TryIntoRequest,
190        R::Error: IntoResponse,
191    {
192        crate::http::executor::run(self.handle_async(request))
193    }
194
195    /// Asynchronously dispatches a request to the appropriate handler along with the URI parameters.
196    pub async fn handle_async<R>(&self, request: R) -> Response
197    where
198        R: TryIntoRequest,
199        R::Error: IntoResponse,
200    {
201        let request = match R::try_into_request(request) {
202            Ok(r) => r,
203            Err(e) => return e.into_response(),
204        };
205        let method = request.method.clone();
206        let path = &request.path();
207        let RouteMatch { params, handler } = self.find(path, method);
208        handler.handle(request, params).await
209    }
210
211    fn find(&self, path: &str, method: Method) -> RouteMatch<'_> {
212        let best_match = self
213            .methods_map
214            .get(&method)
215            .and_then(|r| r.best_match(path));
216
217        if let Some(m) = best_match {
218            let params = m.captures().into_owned();
219            let handler = m.handler();
220            return RouteMatch { handler, params };
221        }
222
223        let best_match = self.any_methods.best_match(path);
224
225        match best_match {
226            Some(m) => {
227                let params = m.captures().into_owned();
228                let handler = m.handler();
229                RouteMatch { handler, params }
230            }
231            None if method == Method::Head => {
232                // If it is a HTTP HEAD request then check if there is a callback in the methods map
233                // if not then fallback to the behavior of HTTP GET else proceed as usual
234                self.find(path, Method::Get)
235            }
236            None => {
237                // Handle the failure case where no match could be resolved.
238                self.fail(path, method)
239            }
240        }
241    }
242
243    // Helper function to handle the case where a best match couldn't be resolved.
244    fn fail(&self, path: &str, method: Method) -> RouteMatch<'_> {
245        // First, filter all routers to determine if the path can match but the provided method is not allowed.
246        let is_method_not_allowed = self
247            .methods_map
248            .iter()
249            .filter(|(k, _)| **k != method)
250            .any(|(_, r)| r.best_match(path).is_some());
251
252        if is_method_not_allowed {
253            // If this `path` can be handled by a callback registered with a different HTTP method
254            // should return 405 Method Not Allowed
255            RouteMatch {
256                handler: &method_not_allowed,
257                params: Captures::default(),
258            }
259        } else {
260            // ... Otherwise, nothing matched so 404.
261            RouteMatch {
262                handler: &not_found,
263                params: Captures::default(),
264            }
265        }
266    }
267
268    /// Register a handler at the path for all methods.
269    pub fn any<F, Req, Resp>(&mut self, path: &str, handler: F)
270    where
271        F: Fn(Req, Params) -> Resp + 'static,
272        Req: TryFromRequest + 'static,
273        Req::Error: IntoResponse + 'static,
274        Resp: IntoResponse + 'static,
275    {
276        let handler = move |req, params| {
277            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
278            async move {
279                match res {
280                    Ok(res) => res.into_response(),
281                    Err(e) => e.into_response(),
282                }
283            }
284        };
285
286        self.any_async(path, handler)
287    }
288
289    /// Register an async handler at the path for all methods.
290    pub fn any_async<F, Fut, I, O>(&mut self, path: &str, handler: F)
291    where
292        F: Fn(I, Params) -> Fut + 'static,
293        Fut: Future<Output = O> + 'static,
294        I: TryFromRequest + 'static,
295        I::Error: IntoResponse + 'static,
296        O: IntoResponse + 'static,
297    {
298        let handler = move |req, params| {
299            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
300            async move {
301                match res {
302                    Ok(f) => f.await.into_response(),
303                    Err(e) => e.into_response(),
304                }
305            }
306        };
307
308        self.any_methods.add(path, Box::new(handler)).unwrap();
309    }
310
311    /// Register a handler at the path for the specified HTTP method.
312    pub fn add<F, Req, Resp>(&mut self, path: &str, method: Method, handler: F)
313    where
314        F: Fn(Req, Params) -> Resp + 'static,
315        Req: TryFromRequest + 'static,
316        Req::Error: IntoResponse + 'static,
317        Resp: IntoResponse + 'static,
318    {
319        let handler = move |req, params| {
320            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
321            async move {
322                match res {
323                    Ok(res) => res.into_response(),
324                    Err(e) => e.into_response(),
325                }
326            }
327        };
328
329        self.add_async(path, method, handler)
330    }
331
332    /// Register an async handler at the path for the specified HTTP method.
333    pub fn add_async<F, Fut, I, O>(&mut self, path: &str, method: Method, handler: F)
334    where
335        F: Fn(I, Params) -> Fut + 'static,
336        Fut: Future<Output = O> + 'static,
337        I: TryFromRequest + 'static,
338        I::Error: IntoResponse + 'static,
339        O: IntoResponse + 'static,
340    {
341        let handler = move |req, params| {
342            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
343            async move {
344                match res {
345                    Ok(f) => f.await.into_response(),
346                    Err(e) => e.into_response(),
347                }
348            }
349        };
350
351        self.methods_map
352            .entry(method)
353            .or_default()
354            .add(path, Box::new(handler))
355            .unwrap();
356    }
357
358    /// Register a handler at the path for the HTTP GET method.
359    pub fn get<F, Req, Resp>(&mut self, path: &str, handler: F)
360    where
361        F: Fn(Req, Params) -> Resp + 'static,
362        Req: TryFromRequest + 'static,
363        Req::Error: IntoResponse + 'static,
364        Resp: IntoResponse + 'static,
365    {
366        self.add(path, Method::Get, handler)
367    }
368
369    /// Register an async handler at the path for the HTTP GET method.
370    pub fn get_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
371    where
372        F: Fn(Req, Params) -> Fut + 'static,
373        Fut: Future<Output = Resp> + 'static,
374        Req: TryFromRequest + 'static,
375        Req::Error: IntoResponse + 'static,
376        Resp: IntoResponse + 'static,
377    {
378        self.add_async(path, Method::Get, handler)
379    }
380
381    /// Register a handler at the path for the HTTP HEAD method.
382    pub fn head<F, Req, Resp>(&mut self, path: &str, handler: F)
383    where
384        F: Fn(Req, Params) -> Resp + 'static,
385        Req: TryFromRequest + 'static,
386        Req::Error: IntoResponse + 'static,
387        Resp: IntoResponse + 'static,
388    {
389        self.add(path, Method::Head, handler)
390    }
391
392    /// Register an async handler at the path for the HTTP HEAD method.
393    pub fn head_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
394    where
395        F: Fn(Req, Params) -> Fut + 'static,
396        Fut: Future<Output = Resp> + 'static,
397        Req: TryFromRequest + 'static,
398        Req::Error: IntoResponse + 'static,
399        Resp: IntoResponse + 'static,
400    {
401        self.add_async(path, Method::Head, handler)
402    }
403
404    /// Register a handler at the path for the HTTP POST method.
405    pub fn post<F, Req, Resp>(&mut self, path: &str, handler: F)
406    where
407        F: Fn(Req, Params) -> Resp + 'static,
408        Req: TryFromRequest + 'static,
409        Req::Error: IntoResponse + 'static,
410        Resp: IntoResponse + 'static,
411    {
412        self.add(path, Method::Post, handler)
413    }
414
415    /// Register an async handler at the path for the HTTP POST method.
416    pub fn post_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
417    where
418        F: Fn(Req, Params) -> Fut + 'static,
419        Fut: Future<Output = Resp> + 'static,
420        Req: TryFromRequest + 'static,
421        Req::Error: IntoResponse + 'static,
422        Resp: IntoResponse + 'static,
423    {
424        self.add_async(path, Method::Post, handler)
425    }
426
427    /// Register a handler at the path for the HTTP DELETE method.
428    pub fn delete<F, Req, Resp>(&mut self, path: &str, handler: F)
429    where
430        F: Fn(Req, Params) -> Resp + 'static,
431        Req: TryFromRequest + 'static,
432        Req::Error: IntoResponse + 'static,
433        Resp: IntoResponse + 'static,
434    {
435        self.add(path, Method::Delete, handler)
436    }
437
438    /// Register an async handler at the path for the HTTP DELETE method.
439    pub fn delete_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
440    where
441        F: Fn(Req, Params) -> Fut + 'static,
442        Fut: Future<Output = Resp> + 'static,
443        Req: TryFromRequest + 'static,
444        Req::Error: IntoResponse + 'static,
445        Resp: IntoResponse + 'static,
446    {
447        self.add_async(path, Method::Delete, handler)
448    }
449
450    /// Register a handler at the path for the HTTP PUT method.
451    pub fn put<F, Req, Resp>(&mut self, path: &str, handler: F)
452    where
453        F: Fn(Req, Params) -> Resp + 'static,
454        Req: TryFromRequest + 'static,
455        Req::Error: IntoResponse + 'static,
456        Resp: IntoResponse + 'static,
457    {
458        self.add(path, Method::Put, handler)
459    }
460
461    /// Register an async handler at the path for the HTTP PUT method.
462    pub fn put_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
463    where
464        F: Fn(Req, Params) -> Fut + 'static,
465        Fut: Future<Output = Resp> + 'static,
466        Req: TryFromRequest + 'static,
467        Req::Error: IntoResponse + 'static,
468        Resp: IntoResponse + 'static,
469    {
470        self.add_async(path, Method::Put, handler)
471    }
472
473    /// Register a handler at the path for the HTTP PATCH method.
474    pub fn patch<F, Req, Resp>(&mut self, path: &str, handler: F)
475    where
476        F: Fn(Req, Params) -> Resp + 'static,
477        Req: TryFromRequest + 'static,
478        Req::Error: IntoResponse + 'static,
479        Resp: IntoResponse + 'static,
480    {
481        self.add(path, Method::Patch, handler)
482    }
483
484    /// Register an async handler at the path for the HTTP PATCH method.
485    pub fn patch_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
486    where
487        F: Fn(Req, Params) -> Fut + 'static,
488        Fut: Future<Output = Resp> + 'static,
489        Req: TryFromRequest + 'static,
490        Req::Error: IntoResponse + 'static,
491        Resp: IntoResponse + 'static,
492    {
493        self.add_async(path, Method::Patch, handler)
494    }
495
496    /// Register a handler at the path for the HTTP OPTIONS method.
497    pub fn options<F, Req, Resp>(&mut self, path: &str, handler: F)
498    where
499        F: Fn(Req, Params) -> Resp + 'static,
500        Req: TryFromRequest + 'static,
501        Req::Error: IntoResponse + 'static,
502        Resp: IntoResponse + 'static,
503    {
504        self.add(path, Method::Options, handler)
505    }
506
507    /// Register an async handler at the path for the HTTP OPTIONS method.
508    pub fn options_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
509    where
510        F: Fn(Req, Params) -> Fut + 'static,
511        Fut: Future<Output = Resp> + 'static,
512        Req: TryFromRequest + 'static,
513        Req::Error: IntoResponse + 'static,
514        Resp: IntoResponse + 'static,
515    {
516        self.add_async(path, Method::Options, handler)
517    }
518
519    /// Construct a new Router.
520    pub fn new() -> Self {
521        Router {
522            methods_map: HashMap::default(),
523            any_methods: MethodRouter::new(),
524        }
525    }
526}
527
528async fn not_found(_req: Request, _params: Params) -> Response {
529    responses::not_found()
530}
531
532async fn method_not_allowed(_req: Request, _params: Params) -> Response {
533    responses::method_not_allowed()
534}
535
/// Build a [`Router`](crate::http::Router) from a comma-separated list of
/// `METHOD "pattern" => handler` entries.
///
/// `METHOD` is one of `HEAD`, `GET`, `PUT`, `POST`, `PATCH`, `DELETE`, `OPTIONS`, or `_`
/// for all methods. Each entry registers a synchronous handler through the corresponding
/// `Router` method (`head`, `get`, `put`, `post`, `patch`, `delete`, `options`, or `any`).
/// A trailing comma after the last entry is accepted.
///
/// # Examples
///
/// ```no_run
/// # use spin_sdk::http::{Params, Request, Response};
/// fn handle_route(req: Request) -> Response {
///     let router = spin_sdk::http_router! {
///         GET "/hello/:planet" => hello_planet,
///         _   "/*"             => not_found,
///     };
///     router.handle(req)
/// }
/// # fn hello_planet(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
/// # fn not_found(req: Request, params: Params) -> anyhow::Result<Response> { todo!() }
/// ```
#[macro_export]
macro_rules! http_router {
    ($($method:tt $path:literal => $h:expr),* $(,)?) => {
        {
            let mut router = $crate::http::Router::new();
            $(
                $crate::http_router!(@build router $method $path => $h);
            )*
            router
        }
    };
    (@build $r:ident HEAD $path:literal => $h:expr) => {
        $r.head($path, $h);
    };
    (@build $r:ident GET $path:literal => $h:expr) => {
        $r.get($path, $h);
    };
    (@build $r:ident PUT $path:literal => $h:expr) => {
        $r.put($path, $h);
    };
    (@build $r:ident POST $path:literal => $h:expr) => {
        $r.post($path, $h);
    };
    (@build $r:ident PATCH $path:literal => $h:expr) => {
        $r.patch($path, $h);
    };
    (@build $r:ident DELETE $path:literal => $h:expr) => {
        $r.delete($path, $h);
    };
    (@build $r:ident OPTIONS $path:literal => $h:expr) => {
        $r.options($path, $h);
    };
    (@build $r:ident _ $path:literal => $h:expr) => {
        $r.any($path, $h);
    };
}
573
#[cfg(test)]
mod tests {
    use super::*;

    fn make_request(method: Method, path: &str) -> Request {
        Request::new(method, path)
    }

    /// Responds 200 with the `x` route parameter as the body, or 404 if it was not captured.
    fn echo_param(_req: Request, params: Params) -> Response {
        match params.get("x") {
            Some(path) => Response::new(200, path),
            None => responses::not_found(),
        }
    }

    /// Builds a handler that responds 200 with a fixed body, used to tell which route won.
    fn respond_with(body: &'static str) -> impl Fn(Request, Params) -> Response {
        move |_req, _params| Response::new(200, body)
    }

    #[test]
    fn test_method_not_allowed() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let req = make_request(Method::Post, "/foobar");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::METHOD_NOT_ALLOWED);
    }

    #[test]
    fn test_not_found() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, ()))
        }

        let mut router = Router::default();
        router.get("/h1/:param", h1);

        let req = make_request(Method::Get, "/h1/");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_multi_param() {
        fn multiply(_req: Request, params: Params) -> anyhow::Result<Response> {
            let x: i64 = params.get("x").unwrap().parse()?;
            let y: i64 = params.get("y").unwrap().parse()?;
            Ok(Response::new(200, format!("{result}", result = x * y)))
        }

        let mut router = Router::default();
        router.get("/multiply/:x/:y", multiply);

        let req = make_request(Method::Get, "/multiply/2/4");
        let res = router.handle(req);

        assert_eq!(res.body, "8".to_owned().into_bytes());
    }

    #[test]
    fn test_param() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let req = make_request(Method::Get, "/y");
        let res = router.handle(req);

        assert_eq!(res.body, "y".to_owned().into_bytes());
    }

    #[test]
    fn test_wildcard() {
        fn echo_wildcard(_req: Request, params: Params) -> Response {
            match params.wildcard() {
                Some(path) => Response::new(200, path),
                None => responses::not_found(),
            }
        }

        let mut router = Router::default();
        router.get("/*", echo_wildcard);

        let req = make_request(Method::Get, "/foo/bar");
        let res = router.handle(req);
        assert_eq!(res.status, hyperium::StatusCode::OK);
        assert_eq!(res.body, "foo/bar".to_owned().into_bytes());
    }

    #[test]
    fn test_wildcard_last_segment() {
        let mut router = Router::default();
        router.get("/:x/*", echo_param);

        let req = make_request(Method::Get, "/foo/bar");
        let res = router.handle(req);
        assert_eq!(res.body, "foo".to_owned().into_bytes());
    }

    #[test]
    fn test_router_display() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let expected = "Registered routes:\n- GET: /:x\n";
        let actual = format!("{}", router);

        assert_eq!(actual.as_str(), expected);
    }

    #[test]
    fn test_ambiguous_wildcard_vs_star() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "one/two"))
        }

        fn h2(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "posts/*"))
        }

        let mut router = Router::default();
        router.get("/:one/:two", h1);
        router.get("/posts/*", h2);

        let req = make_request(Method::Get, "/posts/2");
        let res = router.handle(req);

        assert_eq!(res.body, "posts/*".to_owned().into_bytes());
    }

    #[test]
    fn test_head_falls_back_to_get() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let res = router.handle(make_request(Method::Head, "/foo"));
        assert_eq!(res.status, hyperium::StatusCode::OK);
        assert_eq!(res.body, "foo".to_owned().into_bytes());
    }

    #[test]
    fn test_head_route_preferred_over_get() {
        let mut router = Router::default();
        router.get("/page", respond_with("get"));
        router.head("/page", respond_with("head"));

        let res = router.handle(make_request(Method::Head, "/page"));
        assert_eq!(res.body, "head".to_owned().into_bytes());
    }

    #[test]
    fn test_method_route_preferred_over_any() {
        let mut router = Router::default();
        router.get("/page", respond_with("get"));
        router.any("/page", respond_with("any"));

        let res = router.handle(make_request(Method::Get, "/page"));
        assert_eq!(res.body, "get".to_owned().into_bytes());

        // No POST route exists, so the any-method route handles it.
        let res = router.handle(make_request(Method::Post, "/page"));
        assert_eq!(res.body, "any".to_owned().into_bytes());
    }

    #[test]
    fn test_overlap_priority() {
        let mut router = Router::default();
        router.any("/*", respond_with("any"));
        router.any("/:seg", respond_with("single_segment"));
        router.any("/fie", respond_with("exact"));

        let cases = [
            ("/fie", "exact"),
            ("/zounds", "single_segment"),
            ("/zounds/fie", "any"),
        ];
        for &(path, expected) in cases.iter() {
            let res = router.handle(make_request(Method::Get, path));
            assert_eq!(res.body, expected.to_owned().into_bytes(), "path {}", path);
        }
    }

    #[test]
    fn test_http_router_macro() {
        let router = http_router!(
            GET "/:x" => echo_param,
            _ "/*" => respond_with("fallback")
        );

        let res = router.handle(make_request(Method::Get, "/foo"));
        assert_eq!(res.body, "foo".to_owned().into_bytes());

        let res = router.handle(make_request(Method::Post, "/foo/bar"));
        assert_eq!(res.body, "fallback".to_owned().into_bytes());
    }
}