tower_async_http/compression/
service.rs

1use super::body::BodyInner;
2use super::{CompressionBody, CompressionLayer};
3use crate::compression::predicate::{DefaultPredicate, Predicate};
4use crate::compression::CompressionLevel;
5use crate::compression_utils::WrapBody;
6use crate::{compression_utils::AcceptEncoding, content_encoding::Encoding};
7use http::{header, Request, Response};
8use http_body::Body;
9use tower_async_service::Service;
10
11/// Compress response bodies of the underlying service.
12///
13/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
14/// `Content-Encoding` header to responses.
15///
16/// See the [module docs](crate::compression) for more details.
17#[derive(Clone, Copy)]
18pub struct Compression<S, P = DefaultPredicate> {
19    pub(crate) inner: S,
20    pub(crate) accept: AcceptEncoding,
21    pub(crate) predicate: P,
22    pub(crate) quality: CompressionLevel,
23}
24
25impl<S> Compression<S, DefaultPredicate> {
26    /// Creates a new `Compression` wrapping the `service`.
27    pub fn new(service: S) -> Compression<S, DefaultPredicate> {
28        Self {
29            inner: service,
30            accept: AcceptEncoding::default(),
31            predicate: DefaultPredicate::default(),
32            quality: CompressionLevel::default(),
33        }
34    }
35}
36
37impl<S, P> Compression<S, P> {
38    define_inner_service_accessors!();
39
40    /// Returns a new [`Layer`] that wraps services with a `Compression` middleware.
41    ///
42    /// [`Layer`]: tower_async_layer::Layer
43    pub fn layer() -> CompressionLayer {
44        CompressionLayer::new()
45    }
46
47    /// Sets whether to enable the gzip encoding.
48    #[cfg(feature = "compression-gzip")]
49    pub fn gzip(mut self, enable: bool) -> Self {
50        self.accept.set_gzip(enable);
51        self
52    }
53
54    /// Sets whether to enable the Deflate encoding.
55    #[cfg(feature = "compression-deflate")]
56    pub fn deflate(mut self, enable: bool) -> Self {
57        self.accept.set_deflate(enable);
58        self
59    }
60
61    /// Sets whether to enable the Brotli encoding.
62    #[cfg(feature = "compression-br")]
63    pub fn br(mut self, enable: bool) -> Self {
64        self.accept.set_br(enable);
65        self
66    }
67
68    /// Sets whether to enable the Zstd encoding.
69    #[cfg(feature = "compression-zstd")]
70    pub fn zstd(mut self, enable: bool) -> Self {
71        self.accept.set_zstd(enable);
72        self
73    }
74
75    /// Sets the compression quality.
76    pub fn quality(mut self, quality: CompressionLevel) -> Self {
77        self.quality = quality;
78        self
79    }
80
81    /// Disables the gzip encoding.
82    ///
83    /// This method is available even if the `gzip` crate feature is disabled.
84    pub fn no_gzip(mut self) -> Self {
85        self.accept.set_gzip(false);
86        self
87    }
88
89    /// Disables the Deflate encoding.
90    ///
91    /// This method is available even if the `deflate` crate feature is disabled.
92    pub fn no_deflate(mut self) -> Self {
93        self.accept.set_deflate(false);
94        self
95    }
96
97    /// Disables the Brotli encoding.
98    ///
99    /// This method is available even if the `br` crate feature is disabled.
100    pub fn no_br(mut self) -> Self {
101        self.accept.set_br(false);
102        self
103    }
104
105    /// Disables the Zstd encoding.
106    ///
107    /// This method is available even if the `zstd` crate feature is disabled.
108    pub fn no_zstd(mut self) -> Self {
109        self.accept.set_zstd(false);
110        self
111    }
112
113    /// Replace the current compression predicate.
114    ///
115    /// Predicates are used to determine whether a response should be compressed or not.
116    ///
117    /// The default predicate is [`DefaultPredicate`]. See its documentation for more
118    /// details on which responses it wont compress.
119    ///
120    /// # Changing the compression predicate
121    ///
122    /// ```
123    /// use tower_async_http::compression::{
124    ///     Compression,
125    ///     predicate::{Predicate, NotForContentType, DefaultPredicate},
126    /// };
127    /// use tower_async::util::service_fn;
128    ///
129    /// // Placeholder service_fn
130    /// let service = service_fn(|_: ()| async {
131    ///     Ok::<_, std::io::Error>(http::Response::new(()))
132    /// });
133    ///
134    /// // build our custom compression predicate
135    /// // its recommended to still include `DefaultPredicate` as part of
136    /// // custom predicates
137    /// let predicate = DefaultPredicate::new()
138    ///     // don't compress responses who's `content-type` starts with `application/json`
139    ///     .and(NotForContentType::new("application/json"));
140    ///
141    /// let service = Compression::new(service).compress_when(predicate);
142    /// ```
143    ///
144    /// See [`predicate`](super::predicate) for more utilities for building compression predicates.
145    ///
146    /// Responses that are already compressed (ie have a `content-encoding` header) will _never_ be
147    /// recompressed, regardless what they predicate says.
148    pub fn compress_when<C>(self, predicate: C) -> Compression<S, C>
149    where
150        C: Predicate,
151    {
152        Compression {
153            inner: self.inner,
154            accept: self.accept,
155            predicate,
156            quality: self.quality,
157        }
158    }
159}
160
161impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for Compression<S, P>
162where
163    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
164    ResBody: Body,
165    P: Predicate,
166{
167    type Response = Response<CompressionBody<ResBody>>;
168    type Error = S::Error;
169
170    #[allow(unreachable_code, unused_mut, unused_variables, unreachable_patterns)]
171    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
172        let encoding = Encoding::from_headers(req.headers(), self.accept);
173
174        let res = self.inner.call(req).await?;
175
176        // never recompress responses that are already compressed
177        let should_compress = !res.headers().contains_key(header::CONTENT_ENCODING)
178            && self.predicate.should_compress(&res);
179
180        let (mut parts, body) = res.into_parts();
181
182        let body = match (should_compress, encoding) {
183            // if compression is _not_ support or the client doesn't accept it
184            (false, _) | (_, Encoding::Identity) => {
185                return Ok(Response::from_parts(
186                    parts,
187                    CompressionBody::new(BodyInner::identity(body)),
188                ))
189            }
190
191            #[cfg(feature = "compression-gzip")]
192            (_, Encoding::Gzip) => {
193                CompressionBody::new(BodyInner::gzip(WrapBody::new(body, self.quality)))
194            }
195            #[cfg(feature = "compression-deflate")]
196            (_, Encoding::Deflate) => {
197                CompressionBody::new(BodyInner::deflate(WrapBody::new(body, self.quality)))
198            }
199            #[cfg(feature = "compression-br")]
200            (_, Encoding::Brotli) => {
201                CompressionBody::new(BodyInner::brotli(WrapBody::new(body, self.quality)))
202            }
203            #[cfg(feature = "compression-zstd")]
204            (_, Encoding::Zstd) => {
205                CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality)))
206            }
207            #[cfg(feature = "fs")]
208            (true, _) => {
209                // This should never happen because the `AcceptEncoding` struct which is used to determine
210                // `self.encoding` will only enable the different compression algorithms if the
211                // corresponding crate feature has been enabled. This means
212                // Encoding::[Gzip|Brotli|Deflate] should be impossible at this point without the
213                // features enabled.
214                //
215                // The match arm is still required though because the `fs` feature uses the
216                // Encoding struct independently and requires no compression logic to be enabled.
217                // This means a combination of an individual compression feature and `fs` will fail
218                // to compile without this branch even though it will never be reached.
219                //
220                // To safeguard against refactors that changes this relationship or other bugs the
221                // server will return an uncompressed response instead of panicking since that could
222                // become a ddos attack vector.
223                return Ok(Response::from_parts(
224                    parts,
225                    CompressionBody::new(BodyInner::identity(body)),
226                ));
227            }
228        };
229
230        parts.headers.remove(header::CONTENT_LENGTH);
231
232        parts
233            .headers
234            .insert(header::CONTENT_ENCODING, encoding.into_header_value());
235
236        let res = Response::from_parts(parts, body);
237        Ok(res)
238    }
239}