volo_http/server/
middleware.rs

1//! Server middleware utilities.
2
3use std::{convert::Infallible, marker::PhantomData, sync::Arc};
4
5use motore::{layer::Layer, service::Service};
6
7use super::{
8    IntoResponse,
9    handler::{MiddlewareHandlerFromFn, MiddlewareHandlerMapResponse},
10    route::Route,
11};
12use crate::{body::Body, context::ServerContext, request::Request, response::Response};
13
/// A [`Layer`] that turns an async function into a middleware.
///
/// Values of this type are built by [`from_fn`]; see that function for the accepted
/// function signatures and for usage examples.
pub struct FromFnLayer<F, T, B, B2, E2> {
    /// The user-supplied middleware function.
    f: F,
    // Storing the generic parameters as arguments of a `fn` pointer records them in the type
    // without owning values of those types, so they never affect `Send`/`Sync`.
    #[allow(clippy::type_complexity)]
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
22
23impl<F, T, B, B2, E2> Clone for FromFnLayer<F, T, B, B2, E2>
24where
25    F: Clone,
26{
27    fn clone(&self) -> Self {
28        Self {
29            f: self.f.clone(),
30            _marker: self._marker,
31        }
32    }
33}
34
35/// Create a middleware from an async function
36///
37/// The function must have the three params `&mut ServerContext`, `Request` and [`Next`],
38/// and the three params must be at the end of function.
39///
40/// There can also be some other types that implement
41/// [`FromContext`](crate::server::extract::FromContext) before the above three params.
42///
43/// # Examples
44///
45/// Without any extra params:
46///
47/// ```
48/// use volo_http::{
49///     context::ServerContext,
50///     request::Request,
51///     response::Response,
52///     server::{
53///         middleware::{Next, from_fn},
54///         response::IntoResponse,
55///         route::{Router, get},
56///     },
57/// };
58///
59/// /// Caculate cost of inner services and print it
60/// async fn tracer(cx: &mut ServerContext, req: Request, next: Next) -> Response {
61///     let start = std::time::Instant::now();
62///     let resp = next.run(cx, req).await.into_response();
63///     let elapsed = start.elapsed();
64///     println!("request cost: {elapsed:?}");
65///     resp
66/// }
67///
68/// async fn handler() -> &'static str {
69///     "Hello, World"
70/// }
71///
72/// let router: Router = Router::new()
73///     .route("/", get(handler))
74///     .layer(from_fn(tracer));
75/// ```
76///
77/// With params that implement `FromContext`:
78///
79/// ```
80/// use http::{header::HeaderMap, status::StatusCode, uri::Uri};
81/// use volo_http::{
82///     context::ServerContext,
83///     request::Request,
84///     response::Response,
85///     server::{
86///         middleware::{Next, from_fn},
87///         response::IntoResponse,
88///         route::{Router, get},
89///     },
90/// };
91///
92/// struct Session;
93///
94/// fn get_session(headers: &HeaderMap) -> Option<Session> {
95///     unimplemented!()
96/// }
97///
98/// async fn cookies_check(
99///     uri: Uri,
100///     cx: &mut ServerContext,
101///     req: Request,
102///     next: Next,
103/// ) -> Result<Response, StatusCode> {
104///     // User is not logged in, and not try to login.
105///     let session = get_session(req.headers());
106///     if uri.path() != "/api/v1/login" && session.is_none() {
107///         return Err(StatusCode::FORBIDDEN);
108///     }
109///     // do something
110///     Ok(next.run(cx, req).await.into_response())
111/// }
112///
113/// async fn handler() -> &'static str {
114///     "Hello, World"
115/// }
116///
117/// let router: Router = Router::new()
118///     .route("/", get(handler))
119///     .layer(from_fn(cookies_check));
120/// ```
121///
122/// There are some advanced uses of this function, for example, we can convert types of request and
123/// error, for example:
124///
125/// ```
126/// use std::convert::Infallible;
127///
128/// use motore::service::service_fn;
129/// use volo_http::{
130///     body::BodyConversion,
131///     context::ServerContext,
132///     request::Request,
133///     response::Response,
134///     server::{
135///         middleware::{Next, from_fn},
136///         response::IntoResponse,
137///         route::{Router, get, get_service},
138///     },
139/// };
140///
141/// async fn converter(cx: &mut ServerContext, req: Request, next: Next<String>) -> Response {
142///     let (parts, body) = req.into_parts();
143///     let s = body.into_string().await.unwrap();
144///     let req = Request::from_parts(parts, s);
145///     next.run(cx, req).await.into_response()
146/// }
147///
148/// async fn service(cx: &mut ServerContext, req: Request<String>) -> Result<Response, Infallible> {
149///     unimplemented!()
150/// }
151///
152/// let router: Router = Router::new()
153///     .route("/", get_service(service_fn(service)))
154///     .layer(from_fn(converter));
155/// ```
156pub fn from_fn<F, T, B, B2, E2>(f: F) -> FromFnLayer<F, T, B, B2, E2> {
157    FromFnLayer {
158        f,
159        _marker: PhantomData,
160    }
161}
162
163impl<S, F, T, B, B2, E2> Layer<S> for FromFnLayer<F, T, B, B2, E2>
164where
165    S: Service<ServerContext, Request<B2>, Response = Response, Error = E2> + Send + Sync + 'static,
166{
167    type Service = FromFn<Arc<S>, F, T, B, B2, E2>;
168
169    fn layer(self, service: S) -> Self::Service {
170        FromFn {
171            service: Arc::new(service),
172            f: self.f,
173            _marker: PhantomData,
174        }
175    }
176}
177
/// The [`Service`] produced when a [`FromFnLayer`] is applied to an inner service.
///
/// Each call runs the middleware function and gives it a [`Next`] that forwards to the
/// wrapped inner service.
pub struct FromFn<S, F, T, B, B2, E2> {
    /// The wrapped inner service.
    service: S,
    /// The middleware function.
    f: F,
    _marker: PhantomData<fn(T, B, B2, E2)>,
}
184
185impl<S, F, T, B, B2, E2> Clone for FromFn<S, F, T, B, B2, E2>
186where
187    S: Clone,
188    F: Clone,
189{
190    fn clone(&self) -> Self {
191        Self {
192            service: self.service.clone(),
193            f: self.f.clone(),
194            _marker: self._marker,
195        }
196    }
197}
198
199impl<S, F, T, B, B2, E2> Service<ServerContext, Request<B>> for FromFn<S, F, T, B, B2, E2>
200where
201    S: Service<ServerContext, Request<B2>, Response = Response, Error = E2>
202        + Clone
203        + Send
204        + Sync
205        + 'static,
206    F: for<'r> MiddlewareHandlerFromFn<'r, T, B, B2, E2> + Sync,
207    B: Send,
208    B2: 'static,
209{
210    type Response = Response;
211    type Error = Infallible;
212
213    async fn call(
214        &self,
215        cx: &mut ServerContext,
216        req: Request<B>,
217    ) -> Result<Self::Response, Self::Error> {
218        let next = Next {
219            service: Route::new(self.service.clone()),
220        };
221        Ok(self.f.handle(cx, req, next).await.into_response())
222    }
223}
224
225/// Wrapper for inner [`Service`]
226///
227/// Call [`Next::run`] with context and request for calling the inner [`Service`] and get the
228/// response.
229///
230/// See [`from_fn`] for more details.
231pub struct Next<B = Body, E = Infallible> {
232    service: Route<B, E>,
233}
234
235impl<B, E> Next<B, E> {
236    /// Call the inner [`Service`]
237    pub async fn run(self, cx: &mut ServerContext, req: Request<B>) -> Result<Response, E> {
238        self.service.call(cx, req).await
239    }
240}
241
/// A [`Layer`] that post-processes the response of the inner service.
///
/// Values of this type are built by [`map_response`]; see that function for the accepted
/// function shapes and an example.
pub struct MapResponseLayer<F, T, R1, R2> {
    /// The user-supplied mapping function.
    f: F,
    _marker: PhantomData<fn(T, R1, R2)>,
}
249
250impl<F, T, R1, R2> Clone for MapResponseLayer<F, T, R1, R2>
251where
252    F: Clone,
253{
254    fn clone(&self) -> Self {
255        Self {
256            f: self.f.clone(),
257            _marker: self._marker,
258        }
259    }
260}
261
262/// Create a middleware for mapping a response from an async function
263///
264/// The async function can be:
265///
266/// - `async fn func(resp: Response) -> impl IntoResponse`
267/// - `async fn func(cx: &mut ServerContext, resp: Response) -> impl IntoResponse`
268///
269/// # Examples
270///
271/// Append some headers:
272///
273/// ```
274/// use volo_http::{
275///     response::Response,
276///     server::{
277///         middleware::map_response,
278///         route::{Router, get},
279///     },
280/// };
281///
282/// async fn handler() -> &'static str {
283///     "Hello, World"
284/// }
285///
286/// async fn append_header(resp: Response) -> ((&'static str, &'static str), Response) {
287///     (("Server", "nginx"), resp)
288/// }
289///
290/// let router: Router = Router::new()
291///     .route("/", get(handler))
292///     .layer(map_response(append_header));
293/// ```
294pub fn map_response<F, T, R1, R2>(f: F) -> MapResponseLayer<F, T, R1, R2> {
295    MapResponseLayer {
296        f,
297        _marker: PhantomData,
298    }
299}
300
301impl<S, F, T, R1, R2> Layer<S> for MapResponseLayer<F, T, R1, R2> {
302    type Service = MapResponse<S, F, T, R1, R2>;
303
304    fn layer(self, service: S) -> Self::Service {
305        MapResponse {
306            service,
307            f: self.f,
308            _marker: self._marker,
309        }
310    }
311}
312
/// The [`Service`] produced when a [`MapResponseLayer`] is applied to an inner service.
///
/// It forwards the request to the inner service and passes a successful response through
/// the mapping function.
pub struct MapResponse<S, F, T, R1, R2> {
    /// The wrapped inner service.
    service: S,
    /// The response-mapping function.
    f: F,
    _marker: PhantomData<fn(T, R1, R2)>,
}
319
320impl<S, F, T, R1, R2> Clone for MapResponse<S, F, T, R1, R2>
321where
322    S: Clone,
323    F: Clone,
324{
325    fn clone(&self) -> Self {
326        Self {
327            service: self.service.clone(),
328            f: self.f.clone(),
329            _marker: self._marker,
330        }
331    }
332}
333
334impl<S, F, T, Req, R1, R2> Service<ServerContext, Req> for MapResponse<S, F, T, R1, R2>
335where
336    S: Service<ServerContext, Req, Response = R1> + Send + Sync,
337    F: for<'r> MiddlewareHandlerMapResponse<'r, T, R1, R2> + Sync,
338    Req: Send,
339{
340    type Response = R2;
341    type Error = S::Error;
342
343    async fn call(&self, cx: &mut ServerContext, req: Req) -> Result<Self::Response, Self::Error> {
344        let resp = self.service.call(cx, req).await?;
345
346        Ok(self.f.handle(cx, resp).await)
347    }
348}
349
#[cfg(test)]
mod middleware_tests {
    use faststr::FastStr;
    use http::{HeaderValue, Method, StatusCode, Uri};
    use motore::service::service_fn;

    use super::*;
    use crate::{
        body::{Body, BodyConversion},
        context::ServerContext,
        request::Request,
        response::Response,
        server::{
            response::IntoResponse,
            route::{any, get_service},
            test_helpers::empty_cx,
        },
        utils::test_helpers::simple_req,
    };

    /// Inner service that echoes the request body back as the response body.
    async fn print_body_handler(
        _: &mut ServerContext,
        req: Request<String>,
    ) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(req.into_body().into()))
    }

    /// Middleware that appends `"test"` to the request body before calling the inner service.
    async fn append_body_mw(
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let (parts, mut body) = req.into_parts();
        body += "test";
        let req = Request::from_parts(parts, body);
        next.run(cx, req).await.into_response()
    }

    /// Middleware with extra `FromContext` params (`Method`, `Uri`) that are copied into
    /// CORS-style response headers.
    async fn cors_mw(
        method: Method,
        url: Uri,
        cx: &mut ServerContext,
        req: Request<String>,
        next: Next<String>,
    ) -> Response {
        let mut resp = next.run(cx, req).await.into_response();
        resp.headers_mut().insert(
            "Access-Control-Allow-Methods",
            HeaderValue::from_str(method.as_str()).unwrap(),
        );
        resp.headers_mut().insert(
            "Access-Control-Allow-Origin",
            HeaderValue::from_str(url.to_string().as_str()).unwrap(),
        );
        resp.headers_mut().insert(
            "Access-Control-Allow-Headers",
            HeaderValue::from_str("*").unwrap(),
        );
        resp
    }

    #[tokio::test]
    async fn test_from_fn_with_necessary_params() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        let service = from_fn(append_body_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(resp.into_body().into_string().await.unwrap(), "test");

        // A middleware returning `Result<_, StatusCode>`: the `Err` becomes a response with
        // that status and an empty body.
        async fn error_mw(
            _: &mut ServerContext,
            _: Request<String>,
            _: Next<String>,
        ) -> Result<Response, StatusCode> {
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
        let service = from_fn(error_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from("test"));
        let resp = service.call(&mut cx, req).await.unwrap();
        let status = resp.status();
        let (_, body) = resp.into_parts();
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(body.into_string().await.unwrap(), "");
    }

    #[tokio::test]
    async fn test_from_fn_with_optional_params() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        let service = from_fn(cors_mw).layer(handler);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Methods").unwrap(),
            "GET"
        );
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Origin").unwrap(),
            "/"
        );
        assert_eq!(
            resp.headers().get("Access-Control-Allow-Headers").unwrap(),
            "*"
        );
    }

    #[tokio::test]
    async fn test_from_fn_with_multiple_mws() {
        let handler = service_fn(print_body_handler);
        let mut cx = empty_cx();

        let service = from_fn(cors_mw).layer(handler);
        let service = from_fn(append_body_mw).layer(service);
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        let (parts, body) = resp.into_parts();
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Methods").unwrap(),
            "GET"
        );
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Origin").unwrap(),
            "/"
        );
        assert_eq!(
            parts.headers.get("Access-Control-Allow-Headers").unwrap(),
            "*"
        );
        assert_eq!(body.into_string().await.unwrap(), "test");
    }

    #[tokio::test]
    async fn test_from_fn_converts() {
        async fn converter(
            cx: &mut ServerContext,
            req: Request<String>,
            next: Next<FastStr>,
        ) -> Response {
            let (parts, body) = req.into_parts();
            let s = body.into_faststr().await.unwrap();
            let req = Request::from_parts(parts, s);
            // Type assertion only; `let _ = place` does not move `req`.
            let _: Request<FastStr> = req;
            next.run(cx, req).await.into_response()
        }

        async fn service(
            _: &mut ServerContext,
            _: Request<FastStr>,
        ) -> Result<Response, Infallible> {
            Ok(Response::new(String::from("Hello, World").into()))
        }

        let route = Route::new(get_service(service_fn(service)));
        let service = from_fn(converter).layer(route);

        let result: Result<Response, Infallible> = service
            .call(
                &mut empty_cx(),
                simple_req(Method::GET, "/", String::from("")),
            )
            .await;
        // Check that the converted request actually reached the inner service.
        let resp = result.unwrap();
        assert_eq!(
            resp.into_body().into_string().await.unwrap(),
            "Hello, World"
        );
    }

    async fn index_handler() -> &'static str {
        "Hello, World"
    }

    #[tokio::test]
    async fn test_map_response() {
        async fn append_header(resp: Response) -> ((&'static str, &'static str), Response) {
            (("Server", "nginx"), resp)
        }

        let route: Route<String> = Route::new(any(index_handler));
        let service = map_response(append_header).layer(route);

        let mut cx = empty_cx();
        let req = simple_req(Method::GET, "/", String::from(""));
        let resp = service.call(&mut cx, req).await.unwrap();
        let (parts, body) = resp.into_response().into_parts();
        assert_eq!(parts.headers.get("Server").unwrap(), "nginx");
        // The mapper must keep the original body intact.
        assert_eq!(body.into_string().await.unwrap(), "Hello, World");
    }
}