tower_async_http/decompression/body.rs

1#![allow(unused_imports)]
2
3use crate::compression_utils::CompressionLevel;
4use crate::{
5    compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody},
6    BoxError,
7};
8#[cfg(feature = "decompression-br")]
9use async_compression::tokio::bufread::BrotliDecoder;
10#[cfg(feature = "decompression-gzip")]
11use async_compression::tokio::bufread::GzipDecoder;
12#[cfg(feature = "decompression-deflate")]
13use async_compression::tokio::bufread::ZlibDecoder;
14#[cfg(feature = "decompression-zstd")]
15use async_compression::tokio::bufread::ZstdDecoder;
16use bytes::{Buf, Bytes};
17use futures_util::ready;
18use http::HeaderMap;
19use http_body::{Body, Frame};
20use pin_project_lite::pin_project;
21use std::task::Context;
22use std::{io, marker::PhantomData, pin::Pin, task::Poll};
23use tokio_util::io::StreamReader;
24
pin_project! {
    /// Body produced by [`RequestDecompression`] and [`Decompression`].
    ///
    /// Wraps an inner body `B` and, depending on the state held in `inner`,
    /// either decodes its bytes with one of the enabled decoders or passes
    /// them through unchanged. Data frames are yielded as [`Bytes`].
    ///
    /// [`RequestDecompression`]: super::RequestDecompression
    /// [`Decompression`]: super::Decompression
    pub struct DecompressionBody<B>
    where
        B: Body
    {
        // Which decoder (if any) wraps the underlying body; see `BodyInner`.
        // Structurally pinned because it is polled through `Pin<&mut _>`.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
38
39impl<B> Default for DecompressionBody<B>
40where
41    B: Body + Default,
42{
43    fn default() -> Self {
44        Self {
45            inner: BodyInner::Identity {
46                inner: B::default(),
47            },
48        }
49    }
50}
51
52impl<B> DecompressionBody<B>
53where
54    B: Body,
55{
56    pub(crate) fn new(inner: BodyInner<B>) -> Self {
57        Self { inner }
58    }
59}
60
/// Uninhabited placeholder type: it has no variants, so no value of it exists.
///
/// When a decompression feature is disabled, the matching per-codec body alias
/// is defined as `(Never, PhantomData<B>)`. The corresponding `BodyInner`
/// variant therefore can never be constructed, and `match inner.0 {}` proves
/// such an arm unreachable. Only compiled when at least one codec feature is off.
#[cfg(any(
    not(feature = "decompression-gzip"),
    not(feature = "decompression-deflate"),
    not(feature = "decompression-br"),
    not(feature = "decompression-zstd")
))]
pub(crate) enum Never {}
68
69#[cfg(feature = "decompression-gzip")]
70type GzipBody<B> = WrapBody<GzipDecoder<B>>;
71#[cfg(not(feature = "decompression-gzip"))]
72type GzipBody<B> = (Never, PhantomData<B>);
73
74#[cfg(feature = "decompression-deflate")]
75type DeflateBody<B> = WrapBody<ZlibDecoder<B>>;
76#[cfg(not(feature = "decompression-deflate"))]
77type DeflateBody<B> = (Never, PhantomData<B>);
78
79#[cfg(feature = "decompression-br")]
80type BrotliBody<B> = WrapBody<BrotliDecoder<B>>;
81#[cfg(not(feature = "decompression-br"))]
82type BrotliBody<B> = (Never, PhantomData<B>);
83
84#[cfg(feature = "decompression-zstd")]
85type ZstdBody<B> = WrapBody<ZstdDecoder<B>>;
86#[cfg(not(feature = "decompression-zstd"))]
87type ZstdBody<B> = (Never, PhantomData<B>);
88
pin_project! {
    // Internal state of `DecompressionBody`: which decoder (if any) wraps the
    // underlying body. A variant whose codec feature is disabled holds an
    // uninhabited `(Never, PhantomData<B>)` and so can never be constructed.
    // `BodyInnerProj` is the generated pinned projection used by `poll_frame`.
    #[project = BodyInnerProj]
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        // Body decoded with `GzipDecoder` (see `GzipBody`).
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        // Body decoded with `ZlibDecoder` (see `DeflateBody`).
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        // Body decoded with `BrotliDecoder` (see `BrotliBody`).
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        // Body decoded with `ZstdDecoder` (see `ZstdBody`).
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        // Body forwarded without any decoding.
        Identity {
            #[pin]
            inner: B,
        },
    }
}
117
118impl<B: Body> BodyInner<B> {
119    #[cfg(feature = "decompression-gzip")]
120    pub(crate) fn gzip(inner: WrapBody<GzipDecoder<B>>) -> Self {
121        Self::Gzip { inner }
122    }
123
124    #[cfg(feature = "decompression-deflate")]
125    pub(crate) fn deflate(inner: WrapBody<ZlibDecoder<B>>) -> Self {
126        Self::Deflate { inner }
127    }
128
129    #[cfg(feature = "decompression-br")]
130    pub(crate) fn brotli(inner: WrapBody<BrotliDecoder<B>>) -> Self {
131        Self::Brotli { inner }
132    }
133
134    #[cfg(feature = "decompression-zstd")]
135    pub(crate) fn zstd(inner: WrapBody<ZstdDecoder<B>>) -> Self {
136        Self::Zstd { inner }
137    }
138
139    pub(crate) fn identity(inner: B) -> Self {
140        Self::Identity { inner }
141    }
142}
143
144impl<B> Body for DecompressionBody<B>
145where
146    B: Body,
147    B::Error: Into<BoxError>,
148{
149    type Data = Bytes;
150    type Error = BoxError;
151
152    fn poll_frame(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
156        match self.project().inner.project() {
157            #[cfg(feature = "decompression-gzip")]
158            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
159            #[cfg(feature = "decompression-deflate")]
160            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
161            #[cfg(feature = "decompression-br")]
162            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
163            #[cfg(feature = "decompression-zstd")]
164            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
165            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
166                Some(Ok(frame)) => {
167                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
168                    Poll::Ready(Some(Ok(frame)))
169                }
170                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
171                None => Poll::Ready(None),
172            },
173
174            #[cfg(not(feature = "decompression-gzip"))]
175            BodyInnerProj::Gzip { inner } => match inner.0 {},
176            #[cfg(not(feature = "decompression-deflate"))]
177            BodyInnerProj::Deflate { inner } => match inner.0 {},
178            #[cfg(not(feature = "decompression-br"))]
179            BodyInnerProj::Brotli { inner } => match inner.0 {},
180            #[cfg(not(feature = "decompression-zstd"))]
181            BodyInnerProj::Zstd { inner } => match inner.0 {},
182        }
183    }
184}
185
#[cfg(feature = "decompression-gzip")]
impl<B: Body> DecorateAsyncRead for GzipDecoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = GzipDecoder<Self::Input>;

    /// Builds a gzip decoder over the body reader. The compression level has
    /// no meaning when decoding and is ignored.
    fn apply(input: Self::Input, _level: CompressionLevel) -> Self::Output {
        let mut gz = GzipDecoder::new(input);
        // Keep decoding after the end of the first gzip member so that
        // concatenated members are all emitted.
        gz.multiple_members(true);
        gz
    }

    /// Projects the pinned decoder onto its pinned inner reader.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }
}
204
#[cfg(feature = "decompression-deflate")]
impl<B: Body> DecorateAsyncRead for ZlibDecoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = ZlibDecoder<Self::Input>;

    /// Builds a zlib decoder (used for `deflate`) over the body reader. The
    /// compression level has no meaning when decoding and is ignored.
    fn apply(input: Self::Input, _level: CompressionLevel) -> Self::Output {
        ZlibDecoder::new(input)
    }

    /// Projects the pinned decoder onto its pinned inner reader.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }
}
221
#[cfg(feature = "decompression-br")]
impl<B: Body> DecorateAsyncRead for BrotliDecoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = BrotliDecoder<Self::Input>;

    /// Builds a brotli decoder over the body reader. The compression level
    /// has no meaning when decoding and is ignored.
    fn apply(input: Self::Input, _level: CompressionLevel) -> Self::Output {
        BrotliDecoder::new(input)
    }

    /// Projects the pinned decoder onto its pinned inner reader.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }
}
238
#[cfg(feature = "decompression-zstd")]
impl<B: Body> DecorateAsyncRead for ZstdDecoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = ZstdDecoder<Self::Input>;

    /// Builds a zstd decoder over the body reader. The compression level has
    /// no meaning when decoding and is ignored.
    fn apply(input: Self::Input, _level: CompressionLevel) -> Self::Output {
        ZstdDecoder::new(input)
    }

    /// Projects the pinned decoder onto its pinned inner reader.
    fn get_pin_mut(decoder: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        decoder.get_pin_mut()
    }
}