tower_async_http/compression/
body.rs

1#![allow(unused_imports)]
2
3use crate::compression::CompressionLevel;
4use crate::{
5    compression_utils::{AsyncReadBody, BodyIntoStream, DecorateAsyncRead, WrapBody},
6    BoxError,
7};
8#[cfg(feature = "compression-br")]
9use async_compression::tokio::bufread::BrotliEncoder;
10#[cfg(feature = "compression-gzip")]
11use async_compression::tokio::bufread::GzipEncoder;
12#[cfg(feature = "compression-deflate")]
13use async_compression::tokio::bufread::ZlibEncoder;
14#[cfg(feature = "compression-zstd")]
15use async_compression::tokio::bufread::ZstdEncoder;
16
17use bytes::{Buf, Bytes};
18use futures_util::ready;
19use http::HeaderMap;
20use http_body::{Body, Frame};
21use pin_project_lite::pin_project;
22use std::{
23    io,
24    marker::PhantomData,
25    pin::Pin,
26    task::{Context, Poll},
27};
28use tokio_util::io::StreamReader;
29
30use super::pin_project_cfg::pin_project_cfg;
31
pin_project! {
    /// Response body of [`Compression`].
    ///
    /// Holds either the output of one of the feature-gated encoders (gzip,
    /// deflate, brotli, zstd) or the original body passed through unchanged
    /// (the identity case). Whatever the variant, frames are yielded as
    /// `Bytes` with errors boxed into `BoxError`.
    ///
    /// [`Compression`]: super::Compression
    pub struct CompressionBody<B>
    where
        B: Body,
    {
        // The selected encoding strategy. Pinned because the encoders and the
        // inner body are polled through `Pin<&mut _>` projections.
        #[pin]
        pub(crate) inner: BodyInner<B>,
    }
}
44
45impl<B> Default for CompressionBody<B>
46where
47    B: Body + Default,
48{
49    fn default() -> Self {
50        Self {
51            inner: BodyInner::Identity {
52                inner: B::default(),
53            },
54        }
55    }
56}
57
58impl<B> CompressionBody<B>
59where
60    B: Body,
61{
62    pub(crate) fn new(inner: BodyInner<B>) -> Self {
63        Self { inner }
64    }
65}
66
// Body types for each supported encoding: the matching `async_compression`
// encoder wrapped in `WrapBody` (from `compression_utils`).
//
// Note that the encoder is parameterized over the *body* type `B` here, not over
// a reader. The `DecorateAsyncRead` impls at the end of this file are written
// for `XxxEncoder<B>` and produce `XxxEncoder<AsyncReadBody<B>>` as their
// `Output`. Presumably `WrapBody` uses the encoder type only as that
// `DecorateAsyncRead` selector — confirm against `compression_utils::WrapBody`.

// gzip encoding.
#[cfg(feature = "compression-gzip")]
type GzipBody<B> = WrapBody<GzipEncoder<B>>;

// "deflate" encoding, backed by `ZlibEncoder`.
#[cfg(feature = "compression-deflate")]
type DeflateBody<B> = WrapBody<ZlibEncoder<B>>;

// brotli encoding.
#[cfg(feature = "compression-br")]
type BrotliBody<B> = WrapBody<BrotliEncoder<B>>;

// zstd encoding.
#[cfg(feature = "compression-zstd")]
type ZstdBody<B> = WrapBody<ZstdEncoder<B>>;
78
// Internal state of a `CompressionBody`. There is one variant per compression
// encoding enabled through Cargo features, plus `Identity` for bodies passed
// through without compression.
//
// `pin_project_cfg!` is a local helper macro that (per `#[project = ...]`)
// generates the `BodyInnerProj` projection matched in
// `CompressionBody::poll_frame`. Presumably it exists so the per-variant
// `#[cfg]` attributes carry over to the projection — TODO confirm against the
// macro definition.
pin_project_cfg! {
    #[project = BodyInnerProj]
    pub(crate) enum BodyInner<B>
    where
        B: Body,
    {
        // Output of the gzip encoder.
        #[cfg(feature = "compression-gzip")]
        Gzip {
            #[pin]
            inner: GzipBody<B>,
        },
        // Output of the "deflate" encoding, backed by `ZlibEncoder` (see `DeflateBody`).
        #[cfg(feature = "compression-deflate")]
        Deflate {
            #[pin]
            inner: DeflateBody<B>,
        },
        // Output of the brotli encoder.
        #[cfg(feature = "compression-br")]
        Brotli {
            #[pin]
            inner: BrotliBody<B>,
        },
        // Output of the zstd encoder.
        #[cfg(feature = "compression-zstd")]
        Zstd {
            #[pin]
            inner: ZstdBody<B>,
        },
        // The original body, forwarded without compression. Always available.
        Identity {
            #[pin]
            inner: B,
        },
    }
}
111
112impl<B: Body> BodyInner<B> {
113    #[cfg(feature = "compression-gzip")]
114    pub(crate) fn gzip(inner: WrapBody<GzipEncoder<B>>) -> Self {
115        Self::Gzip { inner }
116    }
117
118    #[cfg(feature = "compression-deflate")]
119    pub(crate) fn deflate(inner: WrapBody<ZlibEncoder<B>>) -> Self {
120        Self::Deflate { inner }
121    }
122
123    #[cfg(feature = "compression-br")]
124    pub(crate) fn brotli(inner: WrapBody<BrotliEncoder<B>>) -> Self {
125        Self::Brotli { inner }
126    }
127
128    #[cfg(feature = "compression-zstd")]
129    pub(crate) fn zstd(inner: WrapBody<ZstdEncoder<B>>) -> Self {
130        Self::Zstd { inner }
131    }
132
133    pub(crate) fn identity(inner: B) -> Self {
134        Self::Identity { inner }
135    }
136}
137
138impl<B> Body for CompressionBody<B>
139where
140    B: Body,
141    B::Error: Into<BoxError>,
142{
143    type Data = Bytes;
144    type Error = BoxError;
145
146    fn poll_frame(
147        self: Pin<&mut Self>,
148        cx: &mut Context<'_>,
149    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
150        match self.project().inner.project() {
151            #[cfg(feature = "compression-gzip")]
152            BodyInnerProj::Gzip { inner } => inner.poll_frame(cx),
153            #[cfg(feature = "compression-deflate")]
154            BodyInnerProj::Deflate { inner } => inner.poll_frame(cx),
155            #[cfg(feature = "compression-br")]
156            BodyInnerProj::Brotli { inner } => inner.poll_frame(cx),
157            #[cfg(feature = "compression-zstd")]
158            BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
159            BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
160                Some(Ok(frame)) => {
161                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
162                    Poll::Ready(Some(Ok(frame)))
163                }
164                Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
165                None => Poll::Ready(None),
166            },
167        }
168    }
169}
170
#[cfg(feature = "compression-gzip")]
impl<B: Body> DecorateAsyncRead for GzipEncoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = GzipEncoder<AsyncReadBody<B>>;

    /// Wraps the body reader in a gzip encoder at the requested level.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = quality.into_async_compression();
        GzipEncoder::with_quality(input, level)
    }

    /// Projects the pinned encoder down to the pinned reader it wraps.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
187
#[cfg(feature = "compression-deflate")]
impl<B: Body> DecorateAsyncRead for ZlibEncoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = ZlibEncoder<AsyncReadBody<B>>;

    /// Wraps the body reader in a zlib encoder (used for the "deflate"
    /// encoding) at the requested level.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = quality.into_async_compression();
        ZlibEncoder::with_quality(input, level)
    }

    /// Projects the pinned encoder down to the pinned reader it wraps.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
204
#[cfg(feature = "compression-br")]
impl<B: Body> DecorateAsyncRead for BrotliEncoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = BrotliEncoder<AsyncReadBody<B>>;

    /// Wraps the body reader in a brotli encoder.
    ///
    /// `CompressionLevel::Default` is deliberately not forwarded as-is. The
    /// underlying brotli crate defaults to level 11 (brotli's maximum), which
    /// makes on-the-fly compression extremely slow. Level 4 is used instead,
    /// matching NGINX's default for on-the-fly brotli compression. Every other
    /// level is passed through unchanged.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = if matches!(quality, CompressionLevel::Default) {
            async_compression::Level::Precise(4)
        } else {
            quality.into_async_compression()
        };
        BrotliEncoder::with_quality(input, level)
    }

    /// Projects the pinned encoder down to the pinned reader it wraps.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}
230
#[cfg(feature = "compression-zstd")]
impl<B: Body> DecorateAsyncRead for ZstdEncoder<B> {
    type Input = AsyncReadBody<B>;
    type Output = ZstdEncoder<AsyncReadBody<B>>;

    /// Wraps the body reader in a zstd encoder at the requested level.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output {
        let level = quality.into_async_compression();
        ZstdEncoder::with_quality(input, level)
    }

    /// Projects the pinned encoder down to the pinned reader it wraps.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
        pinned.get_pin_mut()
    }
}