tower_async_http/map_request_body.rs
1//! Apply a transformation to the request body.
2//!
3//! # Example
4//!
5//! ```
6//! use http_body_util::Full;
7//! use bytes::Bytes;
8//! use http::{Request, Response};
9//! use std::convert::Infallible;
10//! use std::{pin::Pin, task::{ready, Context, Poll}};
11//! use tower_async::{ServiceBuilder, service_fn, ServiceExt, Service};
12//! use tower_async_http::map_request_body::MapRequestBodyLayer;
13//!
14//! // A wrapper for a `Full<Bytes>`
15//! struct BodyWrapper {
16//! inner: Full<Bytes>,
17//! }
18//!
19//! impl BodyWrapper {
20//! fn new(inner: Full<Bytes>) -> Self {
21//! Self { inner }
22//! }
23//! }
24//!
25//! impl http_body::Body for BodyWrapper {
26//! // ...
27//! # type Data = Bytes;
28//! # type Error = tower::BoxError;
29//! # fn poll_frame(
30//! # self: Pin<&mut Self>,
31//! # cx: &mut Context<'_>
32//! # ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { unimplemented!() }
33//! # fn is_end_stream(&self) -> bool { unimplemented!() }
34//! # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() }
35//! }
36//!
37//! async fn handle<B>(_: Request<B>) -> Result<Response<Full<Bytes>>, Infallible> {
38//! // ...
39//! # Ok(Response::new(Full::default()))
40//! }
41//!
42//! # #[tokio::main]
43//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
44//! let mut svc = ServiceBuilder::new()
//!     // Wrap request bodies in `BodyWrapper`
46//! .layer(MapRequestBodyLayer::new(BodyWrapper::new))
47//! .service_fn(handle);
48//!
49//! // Call the service
50//! let request = Request::new(Full::default());
51//!
52//! svc.call(request).await?;
53//! # Ok(())
54//! # }
55//! ```
56
57use http::{Request, Response};
58use std::fmt;
59use tower_async_layer::Layer;
60use tower_async_service::Service;
61
/// Layer that rewrites every request body with a user-supplied function before
/// the request reaches the wrapped service.
///
/// Refer to the [module docs](crate::map_request_body) for a worked example.
#[derive(Clone)]
pub struct MapRequestBodyLayer<F> {
    // The body-mapping function supplied by the caller.
    f: F,
}

impl<F> MapRequestBodyLayer<F> {
    /// Builds a [`MapRequestBodyLayer`] around the body-mapping function `f`.
    ///
    /// `f` should take the incoming request body and return the body that the
    /// wrapped service will receive.
    pub fn new(f: F) -> Self {
        MapRequestBodyLayer { f }
    }
}
78
79impl<S, F> Layer<S> for MapRequestBodyLayer<F>
80where
81 F: Clone,
82{
83 type Service = MapRequestBody<S, F>;
84
85 fn layer(&self, inner: S) -> Self::Service {
86 MapRequestBody::new(inner, self.f.clone())
87 }
88}
89
90impl<F> fmt::Debug for MapRequestBodyLayer<F> {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_struct("MapRequestBodyLayer")
93 .field("f", &std::any::type_name::<F>())
94 .finish()
95 }
96}
97
/// Middleware service that rewrites the body of each incoming request with a
/// user-supplied function and then forwards the request to the wrapped service.
///
/// Refer to the [module docs](crate::map_request_body) for a worked example.
#[derive(Clone)]
pub struct MapRequestBody<S, F> {
    // The wrapped service that receives the rewritten request.
    inner: S,
    // The function applied to each request body.
    f: F,
}
106
107impl<S, F> MapRequestBody<S, F> {
108 /// Create a new [`MapRequestBody`].
109 ///
110 /// `F` is expected to be a function that takes a body and returns another body.
111 pub fn new(service: S, f: F) -> Self {
112 Self { inner: service, f }
113 }
114
115 /// Returns a new [`Layer`] that wraps services with a `MapRequestBodyLayer` middleware.
116 ///
117 /// [`Layer`]: tower_async_layer::Layer
118 pub fn layer(f: F) -> MapRequestBodyLayer<F> {
119 MapRequestBodyLayer::new(f)
120 }
121
122 define_inner_service_accessors!();
123}
124
125impl<F, S, ReqBody, ResBody, NewReqBody> Service<Request<ReqBody>> for MapRequestBody<S, F>
126where
127 S: Service<Request<NewReqBody>, Response = Response<ResBody>>,
128 F: Fn(ReqBody) -> NewReqBody,
129{
130 type Response = S::Response;
131 type Error = S::Error;
132
133 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
134 let req = req.map(&self.f);
135 self.inner.call(req).await
136 }
137}
138
139impl<S, F> fmt::Debug for MapRequestBody<S, F>
140where
141 S: fmt::Debug,
142{
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 f.debug_struct("MapRequestBody")
145 .field("inner", &self.inner)
146 .field("f", &std::any::type_name::<F>())
147 .finish()
148 }
149}