tower_async_http/
compression_utils.rs

1//! Types used by compression and decompression middleware.
2
3use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use futures_util::ready;
7use http::HeaderValue;
8use http_body::{Body, Frame};
9use pin_project_lite::pin_project;
10use std::{
11    io,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use tokio::io::AsyncRead;
16use tokio_util::io::StreamReader;
17
/// One on/off switch per content coding handled by the compression and
/// decompression middleware.
#[derive(Clone, Copy, Debug)]
pub(crate) struct AcceptEncoding {
    /// `gzip` coding switch.
    pub(crate) gzip: bool,
    /// `deflate` coding switch.
    pub(crate) deflate: bool,
    /// `br` coding switch.
    pub(crate) br: bool,
    /// `zstd` coding switch.
    pub(crate) zstd: bool,
}
25
26impl AcceptEncoding {
27    #[allow(dead_code)]
28    pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
29        let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
30            (true, true, true, false) => "gzip,deflate,br",
31            (true, true, false, false) => "gzip,deflate",
32            (true, false, true, false) => "gzip,br",
33            (true, false, false, false) => "gzip",
34            (false, true, true, false) => "deflate,br",
35            (false, true, false, false) => "deflate",
36            (false, false, true, false) => "br",
37            (true, true, true, true) => "zstd,gzip,deflate,br",
38            (true, true, false, true) => "zstd,gzip,deflate",
39            (true, false, true, true) => "zstd,gzip,br",
40            (true, false, false, true) => "zstd,gzip",
41            (false, true, true, true) => "zstd,deflate,br",
42            (false, true, false, true) => "zstd,deflate",
43            (false, false, true, true) => "zstd,br",
44            (false, false, false, true) => "zstd",
45            (false, false, false, false) => return None,
46        };
47        Some(HeaderValue::from_static(accept))
48    }
49
50    #[allow(dead_code)]
51    pub(crate) fn set_gzip(&mut self, enable: bool) {
52        self.gzip = enable;
53    }
54
55    #[allow(dead_code)]
56    pub(crate) fn set_deflate(&mut self, enable: bool) {
57        self.deflate = enable;
58    }
59
60    #[allow(dead_code)]
61    pub(crate) fn set_br(&mut self, enable: bool) {
62        self.br = enable;
63    }
64
65    #[allow(dead_code)]
66    pub(crate) fn set_zstd(&mut self, enable: bool) {
67        self.zstd = enable;
68    }
69}
70
71impl SupportedEncodings for AcceptEncoding {
72    #[allow(dead_code)]
73    fn gzip(&self) -> bool {
74        #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
75        {
76            self.gzip
77        }
78        #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
79        {
80            false
81        }
82    }
83
84    #[allow(dead_code)]
85    fn deflate(&self) -> bool {
86        #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
87        {
88            self.deflate
89        }
90        #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
91        {
92            false
93        }
94    }
95
96    #[allow(dead_code)]
97    fn br(&self) -> bool {
98        #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
99        {
100            self.br
101        }
102        #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
103        {
104            false
105        }
106    }
107
108    #[allow(dead_code)]
109    fn zstd(&self) -> bool {
110        #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
111        {
112            self.zstd
113        }
114        #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
115        {
116            false
117        }
118    }
119}
120
121impl Default for AcceptEncoding {
122    fn default() -> Self {
123        AcceptEncoding {
124            gzip: true,
125            deflate: true,
126            br: true,
127            zstd: true,
128        }
129    }
130}
131
132/// A `Body` that has been converted into an `AsyncRead`.
133pub(crate) type AsyncReadBody<B> =
134    StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
135
136/// Trait for applying some decorator to an `AsyncRead`
137pub(crate) trait DecorateAsyncRead {
138    type Input: AsyncRead;
139    type Output: AsyncRead;
140
141    /// Apply the decorator
142    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
143
144    /// Get a pinned mutable reference to the original input.
145    ///
146    /// This is necessary to implement `Body::poll_trailers`.
147    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
148}
149
150pin_project! {
151    /// `Body` that has been decorated by an `AsyncRead`
152    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
153        #[pin]
154        // rust-analyer thinks this field is private if its `pub(crate)` but works fine when its
155        // `pub`
156        pub read: M::Output,
157        read_all_data: bool,
158    }
159}
160
161impl<M: DecorateAsyncRead> WrapBody<M> {
162    #[allow(dead_code)]
163    pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
164    where
165        B: Body,
166        M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
167    {
168        // convert `Body` into a `Stream`
169        let stream = BodyIntoStream::new(body);
170
171        // an adapter that converts the error type into `io::Error` while storing the actual error
172        // `StreamReader` requires the error type is `io::Error`
173        let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
174
175        // convert `Stream` into an `AsyncRead`
176        let read = StreamReader::new(stream);
177
178        // apply decorator to `AsyncRead` yielding another `AsyncRead`
179        let read = M::apply(read, quality);
180
181        Self {
182            read,
183            read_all_data: false,
184        }
185    }
186}
187
188impl<B, M> Body for WrapBody<M>
189where
190    B: Body,
191    B::Error: Into<BoxError>,
192    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
193{
194    type Data = Bytes;
195    type Error = BoxError;
196
197    fn poll_frame(
198        self: Pin<&mut Self>,
199        cx: &mut Context<'_>,
200    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
201        let mut this = self.project();
202        let mut buf = BytesMut::new();
203        if !*this.read_all_data {
204            match tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut buf) {
205                Poll::Ready(result) => {
206                    match result {
207                        Ok(read) => {
208                            if read == 0 {
209                                *this.read_all_data = true;
210                            } else {
211                                return Poll::Ready(Some(Ok(Frame::data(buf.freeze()))));
212                            }
213                        }
214                        Err(err) => {
215                            let body_error: Option<B::Error> = M::get_pin_mut(this.read)
216                                .get_pin_mut()
217                                .project()
218                                .error
219                                .take();
220
221                            if let Some(body_error) = body_error {
222                                return Poll::Ready(Some(Err(body_error.into())));
223                            } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
224                                // SENTINEL_ERROR_CODE only gets used when storing an underlying body error
225                                unreachable!()
226                            } else {
227                                return Poll::Ready(Some(Err(err.into())));
228                            }
229                        }
230                    }
231                }
232                Poll::Pending => return Poll::Pending,
233            }
234        }
235        // poll any remaining frames, such as trailers
236        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
237        body.poll_frame(cx).map(|option| {
238            option.map(|result| {
239                result
240                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
241                    .map_err(|err| err.into())
242            })
243        })
244    }
245}
246
247pin_project! {
248    pub(crate) struct BodyIntoStream<B>
249    where
250        B: Body,
251    {
252        #[pin]
253        body: B,
254        yielded_all_data: bool,
255        non_data_frame: Option<Frame<B::Data>>,
256    }
257}
258
259#[allow(dead_code)]
260impl<B> BodyIntoStream<B>
261where
262    B: Body,
263{
264    pub(crate) fn new(body: B) -> Self {
265        Self {
266            body,
267            yielded_all_data: false,
268            non_data_frame: None,
269        }
270    }
271
272    /// Get a reference to the inner body
273    pub(crate) fn get_ref(&self) -> &B {
274        &self.body
275    }
276
277    /// Get a mutable reference to the inner body
278    pub(crate) fn get_mut(&mut self) -> &mut B {
279        &mut self.body
280    }
281
282    /// Get a pinned mutable reference to the inner body
283    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
284        self.project().body
285    }
286
287    /// Consume `self`, returning the inner body
288    pub(crate) fn into_inner(self) -> B {
289        self.body
290    }
291}
292
293impl<B> Stream for BodyIntoStream<B>
294where
295    B: Body,
296{
297    type Item = Result<B::Data, B::Error>;
298
299    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
300        loop {
301            let this = self.as_mut().project();
302
303            if *this.yielded_all_data {
304                return Poll::Ready(None);
305            }
306
307            match std::task::ready!(this.body.poll_frame(cx)) {
308                Some(Ok(frame)) => match frame.into_data() {
309                    Ok(data) => return Poll::Ready(Some(Ok(data))),
310                    Err(frame) => {
311                        *this.yielded_all_data = true;
312                        *this.non_data_frame = Some(frame);
313                    }
314                },
315                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
316                None => {
317                    *this.yielded_all_data = true;
318                }
319            }
320        }
321    }
322}
323
324impl<B> Body for BodyIntoStream<B>
325where
326    B: Body,
327{
328    type Data = B::Data;
329    type Error = B::Error;
330
331    fn poll_frame(
332        mut self: Pin<&mut Self>,
333        cx: &mut Context<'_>,
334    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
335        // First drive the stream impl. This consumes all data frames and buffer at most one
336        // trailers frame.
337        if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
338            return Poll::Ready(Some(frame.map(Frame::data)));
339        }
340
341        let this = self.project();
342
343        // Yield the trailers frame `poll_next` hit.
344        if let Some(frame) = this.non_data_frame.take() {
345            return Poll::Ready(Some(Ok(frame)));
346        }
347
348        // Yield any remaining frames in the body. There shouldn't be any after the trailers but
349        // you never know.
350        this.body.poll_frame(cx)
351    }
352
353    #[inline]
354    fn size_hint(&self) -> http_body::SizeHint {
355        self.body.size_hint()
356    }
357}
358
359pin_project! {
360    pub(crate) struct StreamErrorIntoIoError<S, E> {
361        #[pin]
362        inner: S,
363        error: Option<E>,
364    }
365}
366
367impl<S, E> StreamErrorIntoIoError<S, E> {
368    pub(crate) fn new(inner: S) -> Self {
369        Self { inner, error: None }
370    }
371
372    /// Get a pinned mutable reference to the inner inner
373    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
374        self.project().inner
375    }
376}
377
378impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
379where
380    S: Stream<Item = Result<T, E>>,
381{
382    type Item = Result<T, io::Error>;
383
384    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
385        let this = self.project();
386        match ready!(this.inner.poll_next(cx)) {
387            None => Poll::Ready(None),
388            Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
389            Some(Err(err)) => {
390                *this.error = Some(err);
391                Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
392            }
393        }
394    }
395}
396
/// Raw OS error code carried by the placeholder `io::Error` that
/// `StreamErrorIntoIoError` yields after storing the real stream error.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837_459_418;
398
/// How hard the compressor should work on the data.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Favor speed; the output is usually larger.
    Fastest,
    /// Favor size; the output is usually the smallest.
    Best,
    /// Use whatever the selected compression algorithm defines as its default.
    #[default]
    Default,
    /// An exact quality value in the scale of the underlying compression
    /// algorithm. Its meaning depends on the chosen algorithm and on the
    /// implementation backing it. Values above the algorithm's maximum are
    /// implicitly clamped to that maximum.
    Precise(u32),
}
416
417#[cfg(any(
418    feature = "compression-br",
419    feature = "compression-gzip",
420    feature = "compression-deflate",
421    feature = "compression-zstd"
422))]
423use async_compression::Level as AsyncCompressionLevel;
424
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Translate into the level type of `async_compression`.
    ///
    /// `Precise` values that do not fit in an `i32` become `i32::MAX`.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Fastest => AsyncCompressionLevel::Fastest,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Default => AsyncCompressionLevel::Default,
            Self::Precise(quality) => {
                let quality = i32::try_from(quality).unwrap_or(i32::MAX);
                AsyncCompressionLevel::Precise(quality)
            }
        }
    }
}