warp_request_body/
lib.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use bytes::Bytes;
7use pin_project_lite::pin_project;
8use warp::{
9    hyper::{Body as HyperBody, Request as HyperRequest},
10    Buf, Error as WarpError, Stream,
11};
12
13pub mod error;
14pub mod utils;
15
16use error::Error;
17
18//
pin_project! {
    /// A request body that is either held in memory (`Buf`, `Bytes`) or
    /// produced incrementally (`Stream`, `HyperBody`).
    ///
    /// In-memory variants can be turned into [`Bytes`] synchronously with
    /// [`Body::to_bytes`]; the `Stream` and `HyperBody` variants need
    /// [`Body::to_bytes_async`] (see [`Body::require_to_bytes_async`]).
    /// `Body` also implements [`Stream`], yielding `Bytes` chunks.
    #[project = BodyProj]
    pub enum Body {
        /// A boxed in-memory buffer, built by [`Body::with_buf`].
        Buf { inner: Box<dyn Buf + Send + 'static> },
        /// A contiguous in-memory buffer, built by [`Body::with_bytes`].
        Bytes { inner: Bytes },
        /// A heap-pinned stream of `Bytes` chunks, built by [`Body::with_stream`].
        Stream { #[pin] inner: Pin<Box<dyn Stream<Item = Result<Bytes, WarpError>> + Send + 'static>> },
        /// A hyper body, built by [`Body::with_hyper_body`] or `From<HyperBody>`.
        HyperBody { #[pin] inner: HyperBody }
    }
}
28
29impl core::fmt::Debug for Body {
30    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
31        match self {
32            Self::Buf { inner } => f.debug_tuple("Buf").field(&inner.chunk()).finish(),
33            Self::Bytes { inner } => f.debug_tuple("Bytes").field(&inner).finish(),
34            Self::Stream { inner: _ } => write!(f, "Stream"),
35            Self::HyperBody { inner: _ } => write!(f, "HyperBody"),
36        }
37    }
38}
39
40impl core::fmt::Display for Body {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        write!(f, "{self:?}")
43    }
44}
45
46impl Default for Body {
47    fn default() -> Self {
48        Self::Bytes {
49            inner: Bytes::default(),
50        }
51    }
52}
53
54//
55impl Body {
56    pub fn with_buf(buf: impl Buf + Send + 'static) -> Self {
57        Self::Buf {
58            inner: Box::new(buf),
59        }
60    }
61
62    pub fn with_bytes(bytes: Bytes) -> Self {
63        Self::Bytes { inner: bytes }
64    }
65
66    pub fn with_stream(
67        stream: impl Stream<Item = Result<impl Buf + 'static, WarpError>> + Send + 'static,
68    ) -> Self {
69        Self::Stream {
70            inner: Box::pin(utils::buf_stream_to_bytes_stream(stream)),
71        }
72    }
73
74    pub fn with_hyper_body(hyper_body: HyperBody) -> Self {
75        Self::HyperBody { inner: hyper_body }
76    }
77}
78
79impl From<HyperBody> for Body {
80    fn from(body: HyperBody) -> Self {
81        Self::with_hyper_body(body)
82    }
83}
84
85impl Body {
86    pub fn require_to_bytes_async(&self) -> bool {
87        matches!(
88            self,
89            Self::Stream { inner: _ } | Self::HyperBody { inner: _ }
90        )
91    }
92
93    pub fn to_bytes(self) -> Bytes {
94        match self {
95            Self::Buf { inner } => utils::buf_to_bytes(inner),
96            Self::Bytes { inner } => inner,
97            Self::Stream { inner: _ } => panic!("Please call require_to_bytes_async first"),
98            Self::HyperBody { inner: _ } => panic!("Please call require_to_bytes_async first"),
99        }
100    }
101
102    pub async fn to_bytes_async(self) -> Result<Bytes, Error> {
103        match self {
104            Self::Buf { inner } => Ok(utils::buf_to_bytes(inner)),
105            Self::Bytes { inner } => Ok(inner),
106            Self::Stream { inner } => utils::bytes_stream_to_bytes(inner)
107                .await
108                .map_err(Into::into),
109            Self::HyperBody { inner } => {
110                utils::hyper_body_to_bytes(inner).await.map_err(Into::into)
111            }
112        }
113    }
114}
115
116//
117
118//
119impl Stream for Body {
120    type Item = Result<Bytes, Error>;
121
122    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123        match self.project() {
124            BodyProj::Buf { inner: buf } => {
125                if buf.has_remaining() {
126                    let bytes = Bytes::copy_from_slice(buf.chunk());
127                    let cnt = buf.chunk().len();
128                    buf.advance(cnt);
129                    Poll::Ready(Some(Ok(bytes)))
130                } else {
131                    Poll::Ready(None)
132                }
133            }
134            BodyProj::Bytes { inner } => {
135                if !inner.is_empty() {
136                    let bytes = inner.clone();
137                    inner.clear();
138                    Poll::Ready(Some(Ok(bytes)))
139                } else {
140                    Poll::Ready(None)
141                }
142            }
143            BodyProj::Stream { inner } => inner.poll_next(cx).map_err(Into::into),
144            BodyProj::HyperBody { inner } => inner.poll_next(cx).map_err(Into::into),
145        }
146    }
147}
148
149//
150pub fn buf_request_to_body_request(
151    req: HyperRequest<impl Buf + Send + 'static>,
152) -> HyperRequest<Body> {
153    let (parts, body) = req.into_parts();
154    HyperRequest::from_parts(parts, Body::with_buf(body))
155}
156
157pub fn bytes_request_to_body_request(req: HyperRequest<Bytes>) -> HyperRequest<Body> {
158    let (parts, body) = req.into_parts();
159    HyperRequest::from_parts(parts, Body::with_bytes(body))
160}
161
162pub fn stream_request_to_body_request(
163    req: HyperRequest<impl Stream<Item = Result<impl Buf + 'static, WarpError>> + Send + 'static>,
164) -> HyperRequest<Body> {
165    let (parts, body) = req.into_parts();
166    HyperRequest::from_parts(parts, Body::with_stream(body))
167}
168
169pub fn hyper_body_request_to_body_request(req: HyperRequest<HyperBody>) -> HyperRequest<Body> {
170    let (parts, body) = req.into_parts();
171    HyperRequest::from_parts(parts, Body::with_hyper_body(body))
172}
173
// Tests for each `Body` variant. Each test covers synchronous conversion (where
// applicable), async conversion, use as a `Stream`, and the request-level
// conversion helper.
#[cfg(test)]
mod tests {
    use futures_util::{stream::BoxStream, StreamExt as _, TryStreamExt};

    use super::*;

    #[tokio::test]
    async fn test_with_buf() {
        // An aggregated `Buf` body converts synchronously with `to_bytes`.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        assert!(matches!(body, Body::Buf { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));

        // `to_bytes_async` also works for a `Buf` body.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Used as a stream, a `Buf` body yields its content and then ends.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let mut body = Body::with_buf(buf);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level conversion via `buf_request_to_body_request`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_aggregate())
            .await
            .unwrap();
        let (_, body) = buf_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Buf { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));
    }

    #[tokio::test]
    async fn test_with_bytes() {
        // A `Bytes` body converts synchronously with `to_bytes`.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let body = Body::with_bytes(bytes);
        assert!(matches!(body, Body::Bytes { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));

        // `to_bytes_async` also works for a `Bytes` body.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let body = Body::with_bytes(bytes);
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Used as a stream, a `Bytes` body yields its payload once and then ends.
        let bytes = warp::test::request()
            .body("foo")
            .filter(&warp::body::bytes())
            .await
            .unwrap();
        let mut body = Body::with_bytes(bytes);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level conversion via `bytes_request_to_body_request`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_bytes())
            .await
            .unwrap();
        let (_, body) = bytes_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Bytes { inner: _ }));
        assert!(!body.require_to_bytes_async());
        assert_eq!(body.to_bytes(), Bytes::copy_from_slice(b"foo"));
    }

    #[tokio::test]
    async fn test_with_stream() {
        // A `Stream` body requires async conversion.
        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let body = Body::with_stream(stream);
        assert!(matches!(body, Body::Stream { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Used as a stream, polling is delegated to the inner stream.
        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let mut body = Body::with_stream(stream);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level conversion via `stream_request_to_body_request`.
        let req = warp::test::request()
            .body("foo")
            .filter(&warp_filter_request::with_body_stream())
            .await
            .unwrap();
        let (_, body) = stream_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::Stream { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
    }

    #[tokio::test]
    async fn test_with_hyper_body() {
        // A `HyperBody` body requires async conversion.
        let hyper_body = HyperBody::from("foo");
        let body = Body::with_hyper_body(hyper_body);
        assert!(matches!(body, Body::HyperBody { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );

        // Used as a stream, polling is delegated to the hyper body.
        let hyper_body = HyperBody::from("foo");
        let mut body = Body::with_hyper_body(hyper_body);
        assert_eq!(
            body.next().await.unwrap().unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
        assert!(body.next().await.is_none());

        // Request-level conversion via `hyper_body_request_to_body_request`.
        let req = HyperRequest::new(HyperBody::from("foo"));
        let (_, body) = hyper_body_request_to_body_request(req).into_parts();
        assert!(matches!(body, Body::HyperBody { inner: _ }));
        assert!(body.require_to_bytes_async());
        assert_eq!(
            body.to_bytes_async().await.unwrap(),
            Bytes::copy_from_slice(b"foo")
        );
    }

    // Holds a boxed `'static` stream of `Bytes` with a boxed error type;
    // `test_wrapper` checks that a `Body` can be converted into this shape.
    pin_project! {
        pub struct BodyWrapper {
            #[pin]
            inner: BoxStream<'static, Result<Bytes, Box<dyn std::error::Error + Send + Sync + 'static>>>
        }
    }
    #[tokio::test]
    async fn test_wrapper() {
        // An in-memory (`Buf`) body fits into the boxed wrapper.
        let buf = warp::test::request()
            .body("foo")
            .filter(&warp::body::aggregate())
            .await
            .unwrap();
        let body = Body::with_buf(buf);
        let _ = BodyWrapper {
            inner: body.err_into().boxed(),
        };

        // A streaming (`Stream`) body fits into the boxed wrapper as well.
        let stream = warp::test::request()
            .body("foo")
            .filter(&warp::body::stream())
            .await
            .unwrap();
        let body = Body::with_stream(stream);
        let _ = BodyWrapper {
            inner: body.err_into().boxed(),
        };
    }
}