tower_http/
compression_utils.rs

1//! Types used by compression and decompression middleware.
2
3use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use http::HeaderValue;
7use http_body::{Body, Frame};
8use pin_project_lite::pin_project;
9use std::{
10    io,
11    pin::Pin,
12    task::{ready, Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
/// Which content encodings are enabled, used to build a comma-separated encoding header value
/// (see `to_header_value`).
///
/// The `SupportedEncodings` impl for this type reports an encoding as enabled only when its flag
/// is set *and* a matching crate feature is compiled in.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AcceptEncoding {
    /// Whether `gzip` is enabled.
    pub(crate) gzip: bool,
    /// Whether `deflate` is enabled.
    pub(crate) deflate: bool,
    /// Whether `br` (Brotli) is enabled.
    pub(crate) br: bool,
    /// Whether `zstd` is enabled.
    pub(crate) zstd: bool,
}
24
25impl AcceptEncoding {
26    #[allow(dead_code)]
27    pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
28        let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
29            (true, true, true, false) => "gzip,deflate,br",
30            (true, true, false, false) => "gzip,deflate",
31            (true, false, true, false) => "gzip,br",
32            (true, false, false, false) => "gzip",
33            (false, true, true, false) => "deflate,br",
34            (false, true, false, false) => "deflate",
35            (false, false, true, false) => "br",
36            (true, true, true, true) => "zstd,gzip,deflate,br",
37            (true, true, false, true) => "zstd,gzip,deflate",
38            (true, false, true, true) => "zstd,gzip,br",
39            (true, false, false, true) => "zstd,gzip",
40            (false, true, true, true) => "zstd,deflate,br",
41            (false, true, false, true) => "zstd,deflate",
42            (false, false, true, true) => "zstd,br",
43            (false, false, false, true) => "zstd",
44            (false, false, false, false) => return None,
45        };
46        Some(HeaderValue::from_static(accept))
47    }
48
49    #[allow(dead_code)]
50    pub(crate) fn set_gzip(&mut self, enable: bool) {
51        self.gzip = enable;
52    }
53
54    #[allow(dead_code)]
55    pub(crate) fn set_deflate(&mut self, enable: bool) {
56        self.deflate = enable;
57    }
58
59    #[allow(dead_code)]
60    pub(crate) fn set_br(&mut self, enable: bool) {
61        self.br = enable;
62    }
63
64    #[allow(dead_code)]
65    pub(crate) fn set_zstd(&mut self, enable: bool) {
66        self.zstd = enable;
67    }
68}
69
70impl SupportedEncodings for AcceptEncoding {
71    #[allow(dead_code)]
72    fn gzip(&self) -> bool {
73        #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
74        return self.gzip;
75
76        #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
77        return false;
78    }
79
80    #[allow(dead_code)]
81    fn deflate(&self) -> bool {
82        #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
83        return self.deflate;
84
85        #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
86        return false;
87    }
88
89    #[allow(dead_code)]
90    fn br(&self) -> bool {
91        #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
92        return self.br;
93
94        #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
95        return false;
96    }
97
98    #[allow(dead_code)]
99    fn zstd(&self) -> bool {
100        #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
101        return self.zstd;
102
103        #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
104        return false;
105    }
106}
107
108impl Default for AcceptEncoding {
109    fn default() -> Self {
110        AcceptEncoding {
111            gzip: true,
112            deflate: true,
113            br: true,
114            zstd: true,
115        }
116    }
117}
118
/// A `Body` that has been converted into an `AsyncRead`.
///
/// Layers, innermost first: `BodyIntoStream` exposes the body's data chunks as a `Stream`,
/// `StreamErrorIntoIoError` maps the body's error type to `io::Error`, and `StreamReader` reads
/// the resulting `<B as Body>::Data` chunks as bytes.
pub(crate) type AsyncReadBody<B> =
    StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
122
/// Trait for applying some decorator to an `AsyncRead`.
///
/// `WrapBody` uses an implementor to turn the reader built from a body (`Input`) into the reader
/// whose bytes are emitted as data frames (`Output`).
pub(crate) trait DecorateAsyncRead {
    /// The reader being decorated.
    type Input: AsyncRead;
    /// The decorated reader.
    type Output: AsyncRead;

    /// Apply the decorator to `input`; `quality` is the requested [`CompressionLevel`].
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;

    /// Get a pinned mutable reference to the original input.
    ///
    /// This is necessary to implement `Body::poll_frame`: `WrapBody` uses it to reach the
    /// underlying body, both to recover a stashed body error and to poll remaining frames such
    /// as trailers.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
}
136
pin_project! {
    /// `Body` that has been decorated by an `AsyncRead`.
    ///
    /// Emits the decorated reader's output as data frames, then forwards any remaining frames
    /// (such as trailers) of the inner body.
    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
        #[pin]
        // rust-analyzer thinks this field is private if it's `pub(crate)` but works fine when
        // it's `pub`
        pub read: M::Output,
        // A buffer to temporarily store the data read from the decorated reader.
        // Reused as much as possible to optimize allocations.
        buf: BytesMut,
        // Set once the decorated reader reports EOF (`Ok(0)`); afterwards only the inner body's
        // remaining frames are polled.
        read_all_data: bool,
    }
}
150
impl<M: DecorateAsyncRead> WrapBody<M> {
    /// Capacity, in bytes, that `buf` is created with. The same amount is reserved again
    /// whenever `buf`'s capacity reaches zero.
    const INTERNAL_BUF_CAPACITY: usize = 4096;
}
154
155impl<M: DecorateAsyncRead> WrapBody<M> {
156    #[allow(dead_code)]
157    pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
158    where
159        B: Body,
160        M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
161    {
162        // convert `Body` into a `Stream`
163        let stream = BodyIntoStream::new(body);
164
165        // an adapter that converts the error type into `io::Error` while storing the actual error
166        // `StreamReader` requires the error type is `io::Error`
167        let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
168
169        // convert `Stream` into an `AsyncRead`
170        let read = StreamReader::new(stream);
171
172        // apply decorator to `AsyncRead` yielding another `AsyncRead`
173        let read = M::apply(read, quality);
174
175        Self {
176            read,
177            buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
178            read_all_data: false,
179        }
180    }
181}
182
impl<B, M> Body for WrapBody<M>
where
    B: Body,
    B::Error: Into<BoxError>,
    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Polls for the next frame.
    ///
    /// Until the decorated reader reports EOF, each successful non-empty read is returned as one
    /// data frame. A read error is resolved in this order:
    ///
    /// 1. If the inner body produced an error, that original error is returned.
    /// 2. Otherwise, if the body had already yielded data, the reader's `io::Error` is returned.
    /// 3. Otherwise the error is dropped and the inner body is polled for its remaining frames.
    ///
    /// After EOF, the inner body's remaining frames (e.g. trailers) are forwarded, with their
    /// data converted to `Bytes`.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        if !*this.read_all_data {
            // `split()` below moves the filled bytes out and leaves only the spare capacity
            // behind; once none is left, reserve a fresh chunk.
            // NOTE(review): a small non-zero leftover is not topped up, so the next read (and
            // frame) may presumably be small — confirm whether that matters.
            if this.buf.capacity() == 0 {
                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
            }

            let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);

            match ready!(result) {
                Ok(0) => {
                    // EOF from the decorated reader: move on to the body's remaining frames.
                    *this.read_all_data = true;
                }
                Ok(_) => {
                    // Hand out everything read so far as a single data frame.
                    let chunk = this.buf.split().freeze();
                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
                }
                Err(err) => {
                    // Recover the original body error stashed by `StreamErrorIntoIoError`, if any.
                    let body_error: Option<B::Error> = M::get_pin_mut(this.read.as_mut())
                        .get_pin_mut()
                        .project()
                        .error
                        .take();

                    // Whether the body yielded at least one data chunk before this error.
                    let read_some_data = M::get_pin_mut(this.read.as_mut())
                        .get_pin_mut()
                        .project()
                        .read_some_data;

                    if let Some(body_error) = body_error {
                        return Poll::Ready(Some(Err(body_error.into())));
                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
                        // SENTINEL_ERROR_CODE only gets used when storing
                        // an underlying body error
                        unreachable!()
                    } else if *read_some_data {
                        return Poll::Ready(Some(Err(err.into())));
                    }
                    // No body error and no data read yet: the error is dropped and we fall
                    // through to polling the body's remaining frames. `read_all_data` stays
                    // `false`, so the reader is polled again on the next call.
                    // NOTE(review): presumably this tolerates decorators that fail on an empty
                    // body — confirm.
                }
            }
        }

        // Poll any remaining frames, such as trailers. The chain reaches through
        // `StreamReader` -> `StreamErrorIntoIoError` -> `BodyIntoStream`.
        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
        body.poll_frame(cx).map(|option| {
            option.map(|result| {
                result
                    // Convert the body's `Data` buffer type into `Bytes`.
                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
                    .map_err(|err| err.into())
            })
        })
    }
}
249
pin_project! {
    /// Adapter exposing a `Body` as a `Stream` of its data chunks.
    ///
    /// The stream ends at the body's end or at its first non-data frame (e.g. trailers). That
    /// frame is buffered so this type's `Body` impl can yield it afterwards.
    pub(crate) struct BodyIntoStream<B>
    where
        B: Body,
    {
        #[pin]
        body: B,
        // Set once the stream has ended: the body returned `None` or a non-data frame.
        yielded_all_data: bool,
        // The non-data frame that ended the stream, held until `poll_frame` yields it.
        non_data_frame: Option<Frame<B::Data>>,
    }
}
261
262#[allow(dead_code)]
263impl<B> BodyIntoStream<B>
264where
265    B: Body,
266{
267    pub(crate) fn new(body: B) -> Self {
268        Self {
269            body,
270            yielded_all_data: false,
271            non_data_frame: None,
272        }
273    }
274
275    /// Get a reference to the inner body
276    pub(crate) fn get_ref(&self) -> &B {
277        &self.body
278    }
279
280    /// Get a mutable reference to the inner body
281    pub(crate) fn get_mut(&mut self) -> &mut B {
282        &mut self.body
283    }
284
285    /// Get a pinned mutable reference to the inner body
286    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
287        self.project().body
288    }
289
290    /// Consume `self`, returning the inner body
291    pub(crate) fn into_inner(self) -> B {
292        self.body
293    }
294}
295
296impl<B> Stream for BodyIntoStream<B>
297where
298    B: Body,
299{
300    type Item = Result<B::Data, B::Error>;
301
302    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303        loop {
304            let this = self.as_mut().project();
305
306            if *this.yielded_all_data {
307                return Poll::Ready(None);
308            }
309
310            match std::task::ready!(this.body.poll_frame(cx)) {
311                Some(Ok(frame)) => match frame.into_data() {
312                    Ok(data) => return Poll::Ready(Some(Ok(data))),
313                    Err(frame) => {
314                        *this.yielded_all_data = true;
315                        *this.non_data_frame = Some(frame);
316                    }
317                },
318                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
319                None => {
320                    *this.yielded_all_data = true;
321                }
322            }
323        }
324    }
325}
326
327impl<B> Body for BodyIntoStream<B>
328where
329    B: Body,
330{
331    type Data = B::Data;
332    type Error = B::Error;
333
334    fn poll_frame(
335        mut self: Pin<&mut Self>,
336        cx: &mut Context<'_>,
337    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
338        // First drive the stream impl. This consumes all data frames and buffer at most one
339        // trailers frame.
340        if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
341            return Poll::Ready(Some(frame.map(Frame::data)));
342        }
343
344        let this = self.project();
345
346        // Yield the trailers frame `poll_next` hit.
347        if let Some(frame) = this.non_data_frame.take() {
348            return Poll::Ready(Some(Ok(frame)));
349        }
350
351        // Yield any remaining frames in the body. There shouldn't be any after the trailers but
352        // you never know.
353        this.body.poll_frame(cx)
354    }
355
356    #[inline]
357    fn size_hint(&self) -> http_body::SizeHint {
358        self.body.size_hint()
359    }
360}
361
pin_project! {
    /// Stream adapter that replaces `Err(E)` items with `io::Error`s, as `StreamReader`
    /// requires, while keeping the original error so it can be retrieved later.
    pub(crate) struct StreamErrorIntoIoError<S, E> {
        #[pin]
        inner: S,
        // The most recent error yielded by `inner`, held until it is taken.
        error: Option<E>,
        // Whether `inner` has yielded at least one `Ok` item.
        read_some_data: bool
    }
}
370
371impl<S, E> StreamErrorIntoIoError<S, E> {
372    pub(crate) fn new(inner: S) -> Self {
373        Self {
374            inner,
375            error: None,
376            read_some_data: false,
377        }
378    }
379
380    /// Get a reference to the inner body
381    pub(crate) fn get_ref(&self) -> &S {
382        &self.inner
383    }
384
385    /// Get a mutable reference to the inner inner
386    pub(crate) fn get_mut(&mut self) -> &mut S {
387        &mut self.inner
388    }
389
390    /// Get a pinned mutable reference to the inner inner
391    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
392        self.project().inner
393    }
394
395    /// Consume `self`, returning the inner inner
396    pub(crate) fn into_inner(self) -> S {
397        self.inner
398    }
399}
400
401impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
402where
403    S: Stream<Item = Result<T, E>>,
404{
405    type Item = Result<T, io::Error>;
406
407    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
408        let this = self.project();
409        match ready!(this.inner.poll_next(cx)) {
410            None => Poll::Ready(None),
411            Some(Ok(value)) => {
412                *this.read_some_data = true;
413                Poll::Ready(Some(Ok(value)))
414            }
415            Some(Err(err)) => {
416                *this.error = Some(err);
417                Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
418            }
419        }
420    }
421}
422
/// Raw OS error code carried by the `io::Error` that `StreamErrorIntoIoError` returns in place of
/// an inner stream error. The real error is kept in the adapter's `error` field.
///
/// NOTE(review): presumably an arbitrary negative value chosen so no real OS error collides with
/// it — confirm.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418;
424
/// Level of compression that data should be compressed with.
///
/// Defaults to [`CompressionLevel::Default`]. The enum is `#[non_exhaustive]`, so matches outside
/// this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
pub enum CompressionLevel {
    /// Fastest compression, usually producing a larger output.
    Fastest,
    /// Best compression, usually producing the smallest output.
    Best,
    /// The default quality defined by the selected compression algorithm.
    #[default]
    Default,
    /// A precise quality, expressed on the underlying compression algorithm's own scale.
    ///
    /// The interpretation of this value depends on the algorithm chosen and the specific
    /// implementation backing it.
    ///
    /// Qualities are implicitly clamped to the algorithm's maximum.
    Precise(i32),
}
446
447#[cfg(any(
448    feature = "compression-br",
449    feature = "compression-gzip",
450    feature = "compression-deflate",
451    feature = "compression-zstd"
452))]
453use async_compression::Level as AsyncCompressionLevel;
454
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Converts this level into the equivalent `async_compression` level.
    ///
    /// Each variant maps one-to-one; `Precise` forwards its value unchanged.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Default => AsyncCompressionLevel::Default,
            Self::Fastest => AsyncCompressionLevel::Fastest,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Precise(level) => AsyncCompressionLevel::Precise(level),
        }
    }
}