volo_http/server/route/
method_router.rs

1//! [`MethodRouter`] implementation for [`Server`].
2//!
//! [`Router`] routes a path to a [`MethodRouter`], and the [`MethodRouter`] then routes the
//! request by its HTTP method. If the request's method is not handled by the
//! [`MethodRouter`], the request falls back to another [`Route`].
6//!
//! You can use an HTTP method name as a function to create a [`MethodRouter`]; for example,
//! [`get`] creates a [`MethodRouter`] that routes requests with the GET method to the target
//! [`Route`].
10//!
11//! See [`MethodRouter`] and [`get`], [`post`], [`any`], [`get_service`]... for more details.
12//!
13//! [`Server`]: crate::server::Server
14//! [`Router`]: super::router::Router
15
16use std::convert::Infallible;
17
18use http::{method::Method, status::StatusCode};
19use motore::{ServiceExt, layer::Layer, service::Service};
20use paste::paste;
21
22use super::{Fallback, Route};
23use crate::{
24    body::Body,
25    context::ServerContext,
26    request::Request,
27    response::Response,
28    server::{IntoResponse, handler::Handler},
29};
30
31/// A method router that handle the request and dispatch it by its method.
32///
33/// There is no need to create [`MethodRouter`] directly, you can use specific method for creating
34/// it. What's more, the method router allows chaining additional handlers or services.
35///
36/// # Examples
37///
38/// ```
39/// use std::convert::Infallible;
40///
41/// use volo::service::service_fn;
42/// use volo_http::{
43///     context::ServerContext,
44///     request::Request,
45///     server::route::{MethodRouter, Router, any, get, post_service},
46/// };
47///
48/// async fn index() -> &'static str {
49///     "Hello, World"
50/// }
51///
52/// async fn index_fn(cx: &mut ServerContext, req: Request) -> Result<&'static str, Infallible> {
53///     Ok("Hello, World")
54/// }
55///
56/// let _: MethodRouter = get(index);
57/// let _: MethodRouter = any(index);
58/// let _: MethodRouter = post_service(service_fn(index_fn));
59///
60/// let _: MethodRouter = get(index).post(index).options_service(service_fn(index_fn));
61///
62/// let app: Router = Router::new().route("/", get(index));
63/// let app: Router = Router::new().route("/", get(index).post(index).head(index));
64/// ```
65pub struct MethodRouter<B = Body, E = Infallible> {
66    options: MethodEndpoint<B, E>,
67    get: MethodEndpoint<B, E>,
68    post: MethodEndpoint<B, E>,
69    put: MethodEndpoint<B, E>,
70    delete: MethodEndpoint<B, E>,
71    head: MethodEndpoint<B, E>,
72    trace: MethodEndpoint<B, E>,
73    connect: MethodEndpoint<B, E>,
74    patch: MethodEndpoint<B, E>,
75    fallback: Fallback<B, E>,
76}
77
78impl<B, E> Service<ServerContext, Request<B>> for MethodRouter<B, E>
79where
80    B: Send,
81{
82    type Response = Response;
83    type Error = E;
84
85    async fn call(
86        &self,
87        cx: &mut ServerContext,
88        req: Request<B>,
89    ) -> Result<Self::Response, Self::Error> {
90        let handler = match *req.method() {
91            Method::OPTIONS => Some(&self.options),
92            Method::GET => Some(&self.get),
93            Method::POST => Some(&self.post),
94            Method::PUT => Some(&self.put),
95            Method::DELETE => Some(&self.delete),
96            Method::HEAD => Some(&self.head),
97            Method::TRACE => Some(&self.trace),
98            Method::CONNECT => Some(&self.connect),
99            Method::PATCH => Some(&self.patch),
100            _ => None,
101        };
102
103        match handler {
104            Some(MethodEndpoint::Route(route)) => route.call(cx, req).await,
105            _ => self.fallback.call(cx, req).await,
106        }
107    }
108}
109
110impl<B, E> Default for MethodRouter<B, E>
111where
112    B: Send + 'static,
113    E: 'static,
114{
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120impl<B, E> MethodRouter<B, E>
121where
122    B: Send + 'static,
123    E: 'static,
124{
125    fn new() -> Self {
126        Self {
127            options: MethodEndpoint::None,
128            get: MethodEndpoint::None,
129            post: MethodEndpoint::None,
130            put: MethodEndpoint::None,
131            delete: MethodEndpoint::None,
132            head: MethodEndpoint::None,
133            trace: MethodEndpoint::None,
134            connect: MethodEndpoint::None,
135            patch: MethodEndpoint::None,
136            fallback: Fallback::from_status_code(StatusCode::METHOD_NOT_ALLOWED),
137        }
138    }
139
140    /// Add a new inner layer to all routes in this method router.
141    ///
142    /// The layer's `Service` should be `Clone + Send + Sync + 'static`.
143    pub fn layer<L, B2, E2>(self, l: L) -> MethodRouter<B2, E2>
144    where
145        L: Layer<Route<B, E>> + Clone + Send + Sync + 'static,
146        L::Service: Service<ServerContext, Request<B2>, Error = E2> + Send + Sync + 'static,
147        <L::Service as Service<ServerContext, Request<B2>>>::Response: IntoResponse,
148        B2: 'static,
149    {
150        let Self {
151            options,
152            get,
153            post,
154            put,
155            delete,
156            head,
157            trace,
158            connect,
159            patch,
160            fallback,
161        } = self;
162
163        let layer_fn = move |route: Route<B, E>| {
164            Route::new(
165                l.clone()
166                    .layer(route)
167                    .map_response(IntoResponse::into_response),
168            )
169        };
170
171        let options = options.map(layer_fn.clone());
172        let get = get.map(layer_fn.clone());
173        let post = post.map(layer_fn.clone());
174        let put = put.map(layer_fn.clone());
175        let delete = delete.map(layer_fn.clone());
176        let head = head.map(layer_fn.clone());
177        let trace = trace.map(layer_fn.clone());
178        let connect = connect.map(layer_fn.clone());
179        let patch = patch.map(layer_fn.clone());
180
181        let fallback = fallback.map(layer_fn);
182
183        MethodRouter {
184            options,
185            get,
186            post,
187            put,
188            delete,
189            head,
190            trace,
191            connect,
192            patch,
193            fallback,
194        }
195    }
196}
197
// Invokes the callback macro once, passing it the names of every HTTP method that has a
// dedicated slot in `MethodRouter`, in field-declaration order.
macro_rules! for_all_methods {
    ($callback:ident) => {
        $callback! { options, get, post, put, delete, head, trace, connect, patch }
    };
}
203
// Generates two chaining builder methods on `MethodRouter` for each method name passed in:
// `$method(handler)` and `$method_service(service)`. Each one overwrites the endpoint stored
// for that HTTP method and returns the router.
macro_rules! impl_method_register_for_builder {
    ($( $method:ident ),*) => {
        $(
        #[doc = concat!(
            "Route `", stringify!($method), "` requests to the given handler.\n\n",
            "Replaces any handler or service previously registered for this method."
        )]
        pub fn $method<H, T>(mut self, handler: H) -> Self
        where
            for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
            T: 'static,
        {
            self.$method = MethodEndpoint::from_handler(handler);
            self
        }

        paste! {
        #[doc = concat!(
            "Route `", stringify!($method), "` requests to the given service.\n\n",
            "Replaces any handler or service previously registered for this method."
        )]
        pub fn [<$method _service>]<S>(mut self, service: S) -> Self
        where
            for<'a> S: Service<ServerContext, Request<B>, Error = E>
                + Send
                + Sync
                + 'a,
            S::Response: IntoResponse,
        {
            self.$method = MethodEndpoint::from_service(service);
            self
        }
        }
        )*
    };
}
235
236impl<B, E> MethodRouter<B, E>
237where
238    B: Send + 'static,
239    E: IntoResponse + 'static,
240{
241    for_all_methods!(impl_method_register_for_builder);
242
243    /// Set a fallback handler for the route.
244    ///
245    /// If there is no method that the route can handle, method router will call the fallback
246    /// handler.
247    ///
248    /// Default is returning "405 Method Not Allowed".
249    pub fn fallback<H, T>(mut self, handler: H) -> Self
250    where
251        for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
252        T: 'static,
253    {
254        self.fallback = Fallback::from_handler(handler);
255        self
256    }
257
258    /// Set a fallback service for the route.
259    ///
260    /// If there is no method that the route can handle, method router will call the fallback
261    /// service.
262    ///
263    /// Default is returning "405 Method Not Allowed".
264    pub fn fallback_service<S>(mut self, service: S) -> Self
265    where
266        for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
267        S::Response: IntoResponse,
268    {
269        self.fallback = Fallback::from_service(service);
270        self
271    }
272}
273
274macro_rules! impl_method_register {
275    ($( $method:ident ),*) => {
276        $(
277        #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
278        pub fn $method<H, T, B, E>(handler: H) -> MethodRouter<B, E>
279        where
280            for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
281            T: 'static,
282            B: Send + 'static,
283            E: IntoResponse + 'static,
284        {
285            MethodRouter {
286                $method: MethodEndpoint::from_handler(handler),
287                ..Default::default()
288            }
289        }
290
291        paste! {
292        #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
293        pub fn [<$method _service>]<S, B, E>(service: S) -> MethodRouter<B, E>
294        where
295            for<'a> S: Service<ServerContext, Request<B>, Error = E>
296                + Send
297                + Sync
298                + 'a,
299            S::Response: IntoResponse,
300            B: Send + 'static,
301            E: IntoResponse + 'static,
302        {
303            MethodRouter {
304                $method: MethodEndpoint::from_service(service),
305                ..Default::default()
306            }
307        }
308        }
309        )+
310    };
311}
312
313for_all_methods!(impl_method_register);
314
315/// Route any method to the given handler.
316pub fn any<H, T, B, E>(handler: H) -> MethodRouter<B, E>
317where
318    for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
319    T: 'static,
320    B: Send + 'static,
321    E: IntoResponse + 'static,
322{
323    MethodRouter {
324        fallback: Fallback::from_handler(handler),
325        ..Default::default()
326    }
327}
328
329/// Route any method to the given service.
330pub fn any_service<S, B, E>(service: S) -> MethodRouter<B, E>
331where
332    for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
333    S::Response: IntoResponse,
334    B: Send + 'static,
335    E: IntoResponse + 'static,
336{
337    MethodRouter {
338        fallback: Fallback::from_service(service),
339        ..Default::default()
340    }
341}
342
343#[derive(Default)]
344enum MethodEndpoint<B = Body, E = Infallible> {
345    #[default]
346    None,
347    Route(Route<B, E>),
348}
349
350impl<B, E> MethodEndpoint<B, E>
351where
352    B: Send + 'static,
353{
354    fn from_handler<H, T>(handler: H) -> Self
355    where
356        for<'a> H: Handler<T, B, E> + Clone + Send + Sync + 'a,
357        T: 'static,
358        E: 'static,
359    {
360        Self::from_service(handler.into_service())
361    }
362
363    fn from_service<S>(service: S) -> Self
364    where
365        for<'a> S: Service<ServerContext, Request<B>, Error = E> + Send + Sync + 'a,
366        S::Response: IntoResponse,
367    {
368        Self::Route(Route::new(
369            service.map_response(IntoResponse::into_response),
370        ))
371    }
372
373    fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<B2, E2>
374    where
375        F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + 'static,
376    {
377        match self {
378            Self::None => MethodEndpoint::None,
379            Self::Route(route) => MethodEndpoint::Route(f(route)),
380        }
381    }
382}
383
#[cfg(test)]
mod method_router_tests {
    use http::{method::Method, status::StatusCode};

    use super::{MethodRouter, any, get, head, options};
    use crate::body::Body;

    /// Every method that has a dedicated slot in [`MethodRouter`].
    const ALL_METHODS: [Method; 9] = [
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::DELETE,
        Method::HEAD,
        Method::OPTIONS,
        Method::CONNECT,
        Method::PATCH,
        Method::TRACE,
    ];

    async fn always_ok() {}
    async fn teapot() -> StatusCode {
        StatusCode::IM_A_TEAPOT
    }

    /// Sends one request per method in [`ALL_METHODS`] through `router` and asserts that
    /// `check(status)` equals `expected(method)`. On failure, the message names the offending
    /// method and status.
    async fn check_all_methods<C, F>(router: MethodRouter<Option<Body>>, check: C, expected: F)
    where
        C: Fn(StatusCode) -> bool,
        F: Fn(Method) -> bool,
    {
        for method in ALL_METHODS {
            let status = router.call_route(method.clone(), None).await.status();
            assert_eq!(
                check(status),
                expected(method.clone()),
                "unexpected status `{status}` for method `{method}`",
            );
        }
    }

    #[tokio::test]
    async fn method_router() {
        let is_success = |status: StatusCode| status.is_success();

        check_all_methods(get(always_ok), is_success, |m| m == Method::GET).await;
        check_all_methods(head(always_ok), is_success, |m| m == Method::HEAD).await;
        check_all_methods(get(always_ok).post(always_ok), is_success, |m| {
            m == Method::GET || m == Method::POST
        })
        .await;
        check_all_methods(any(always_ok), is_success, |_| true).await;
    }

    #[tokio::test]
    async fn method_fallback() {
        let is_teapot = |status: StatusCode| status == StatusCode::IM_A_TEAPOT;

        check_all_methods(get(always_ok).fallback(teapot), is_teapot, |m| m != Method::GET).await;
        check_all_methods(options(always_ok).fallback(teapot), is_teapot, |m| {
            m != Method::OPTIONS
        })
        .await;
        check_all_methods(any(teapot), is_teapot, |_| true).await;
    }
}