tower_async_http/
map_request_body.rs

1//! Apply a transformation to the request body.
2//!
3//! # Example
4//!
5//! ```
6//! use http_body_util::Full;
7//! use bytes::Bytes;
8//! use http::{Request, Response};
9//! use std::convert::Infallible;
10//! use std::{pin::Pin, task::{ready, Context, Poll}};
11//! use tower_async::{ServiceBuilder, service_fn, ServiceExt, Service};
12//! use tower_async_http::map_request_body::MapRequestBodyLayer;
13//!
14//! // A wrapper for a `Full<Bytes>`
15//! struct BodyWrapper {
16//!     inner: Full<Bytes>,
17//! }
18//!
19//! impl BodyWrapper {
20//!     fn new(inner: Full<Bytes>) -> Self {
21//!         Self { inner }
22//!     }
23//! }
24//!
25//! impl http_body::Body for BodyWrapper {
26//!     // ...
27//!     # type Data = Bytes;
28//!     # type Error = tower::BoxError;
29//!     # fn poll_frame(
30//!     #     self: Pin<&mut Self>,
31//!     #     cx: &mut Context<'_>
32//!     # ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { unimplemented!() }
33//!     # fn is_end_stream(&self) -> bool { unimplemented!() }
34//!     # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() }
35//! }
36//!
37//! async fn handle<B>(_: Request<B>) -> Result<Response<Full<Bytes>>, Infallible> {
38//!     // ...
39//!     # Ok(Response::new(Full::default()))
40//! }
41//!
42//! # #[tokio::main]
43//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
44//! let mut svc = ServiceBuilder::new()
//!     // Wrap request bodies in `BodyWrapper`
46//!     .layer(MapRequestBodyLayer::new(BodyWrapper::new))
47//!     .service_fn(handle);
48//!
49//! // Call the service
50//! let request = Request::new(Full::default());
51//!
52//! svc.call(request).await?;
53//! # Ok(())
54//! # }
55//! ```
56
57use http::{Request, Response};
58use std::fmt;
59use tower_async_layer::Layer;
60use tower_async_service::Service;
61
/// A layer that stores a body-mapping function `F` and hands a clone of it to
/// every service it wraps.
///
/// See the [module docs](crate::map_request_body) for an example.
#[derive(Clone)]
pub struct MapRequestBodyLayer<F> {
    /// The function used to turn one request body into another.
    f: F,
}

impl<F> MapRequestBodyLayer<F> {
    /// Build a [`MapRequestBodyLayer`] around `f`.
    ///
    /// No bound is placed on `F` here; it is expected to be a function that
    /// accepts a body and produces another body.
    pub fn new(f: F) -> Self {
        MapRequestBodyLayer { f }
    }
}
78
79impl<S, F> Layer<S> for MapRequestBodyLayer<F>
80where
81    F: Clone,
82{
83    type Service = MapRequestBody<S, F>;
84
85    fn layer(&self, inner: S) -> Self::Service {
86        MapRequestBody::new(inner, self.f.clone())
87    }
88}
89
90impl<F> fmt::Debug for MapRequestBodyLayer<F> {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        f.debug_struct("MapRequestBodyLayer")
93            .field("f", &std::any::type_name::<F>())
94            .finish()
95    }
96}
97
98/// Apply a transformation to the request body.
99///
100/// See the [module docs](crate::map_request_body) for an example.
101#[derive(Clone)]
102pub struct MapRequestBody<S, F> {
103    inner: S,
104    f: F,
105}
106
107impl<S, F> MapRequestBody<S, F> {
108    /// Create a new [`MapRequestBody`].
109    ///
110    /// `F` is expected to be a function that takes a body and returns another body.
111    pub fn new(service: S, f: F) -> Self {
112        Self { inner: service, f }
113    }
114
115    /// Returns a new [`Layer`] that wraps services with a `MapRequestBodyLayer` middleware.
116    ///
117    /// [`Layer`]: tower_async_layer::Layer
118    pub fn layer(f: F) -> MapRequestBodyLayer<F> {
119        MapRequestBodyLayer::new(f)
120    }
121
122    define_inner_service_accessors!();
123}
124
125impl<F, S, ReqBody, ResBody, NewReqBody> Service<Request<ReqBody>> for MapRequestBody<S, F>
126where
127    S: Service<Request<NewReqBody>, Response = Response<ResBody>>,
128    F: Fn(ReqBody) -> NewReqBody,
129{
130    type Response = S::Response;
131    type Error = S::Error;
132
133    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
134        let req = req.map(&self.f);
135        self.inner.call(req).await
136    }
137}
138
139impl<S, F> fmt::Debug for MapRequestBody<S, F>
140where
141    S: fmt::Debug,
142{
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        f.debug_struct("MapRequestBody")
145            .field("inner", &self.inner)
146            .field("f", &std::any::type_name::<F>())
147            .finish()
148    }
149}