Skip to main content

tower_http/compression/
layer.rs

1use super::{Compression, Predicate};
2use crate::compression::predicate::DefaultPredicate;
3use crate::compression::CompressionLevel;
4use crate::compression_utils::AcceptEncoding;
5use tower_layer::Layer;
6
7/// Compress response bodies of the underlying service.
8///
9/// This uses the `Accept-Encoding` header to pick an appropriate encoding and adds the
10/// `Content-Encoding` header to responses.
11///
12/// See the [module docs](crate::compression) for more details.
13#[derive(Clone, Debug, Default)]
14pub struct CompressionLayer<P = DefaultPredicate> {
15    accept: AcceptEncoding,
16    predicate: P,
17    quality: CompressionLevel,
18}
19
20impl<S, P> Layer<S> for CompressionLayer<P>
21where
22    P: Predicate,
23{
24    type Service = Compression<S, P>;
25
26    fn layer(&self, inner: S) -> Self::Service {
27        Compression {
28            inner,
29            accept: self.accept,
30            predicate: self.predicate.clone(),
31            quality: self.quality,
32        }
33    }
34}
35
36impl CompressionLayer {
37    /// Creates a new [`CompressionLayer`].
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Sets whether to enable the gzip encoding.
43    #[cfg(feature = "compression-gzip")]
44    pub fn gzip(mut self, enable: bool) -> Self {
45        self.accept.set_gzip(enable);
46        self
47    }
48
49    /// Sets whether to enable the Deflate encoding.
50    #[cfg(feature = "compression-deflate")]
51    pub fn deflate(mut self, enable: bool) -> Self {
52        self.accept.set_deflate(enable);
53        self
54    }
55
56    /// Sets whether to enable the Brotli encoding.
57    #[cfg(feature = "compression-br")]
58    pub fn br(mut self, enable: bool) -> Self {
59        self.accept.set_br(enable);
60        self
61    }
62
63    /// Sets whether to enable the Zstd encoding.
64    #[cfg(feature = "compression-zstd")]
65    pub fn zstd(mut self, enable: bool) -> Self {
66        self.accept.set_zstd(enable);
67        self
68    }
69
70    /// Sets the compression quality.
71    pub fn quality(mut self, quality: CompressionLevel) -> Self {
72        self.quality = quality;
73        self
74    }
75
76    /// Disables the gzip encoding.
77    ///
78    /// This method is available even if the `gzip` crate feature is disabled.
79    pub fn no_gzip(mut self) -> Self {
80        self.accept.set_gzip(false);
81        self
82    }
83
84    /// Disables the Deflate encoding.
85    ///
86    /// This method is available even if the `deflate` crate feature is disabled.
87    pub fn no_deflate(mut self) -> Self {
88        self.accept.set_deflate(false);
89        self
90    }
91
92    /// Disables the Brotli encoding.
93    ///
94    /// This method is available even if the `br` crate feature is disabled.
95    pub fn no_br(mut self) -> Self {
96        self.accept.set_br(false);
97        self
98    }
99
100    /// Disables the Zstd encoding.
101    ///
102    /// This method is available even if the `zstd` crate feature is disabled.
103    pub fn no_zstd(mut self) -> Self {
104        self.accept.set_zstd(false);
105        self
106    }
107
108    /// Replace the current compression predicate.
109    ///
110    /// See [`Compression::compress_when`] for more details.
111    pub fn compress_when<C>(self, predicate: C) -> CompressionLayer<C>
112    where
113        C: Predicate,
114    {
115        CompressionLayer {
116            accept: self.accept,
117            predicate,
118            quality: self.quality,
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header::ACCEPT_ENCODING, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use tokio::fs::File;
    use tokio_util::io::ReaderStream;
    use tower::{Service, ServiceBuilder, ServiceExt};

    /// Test handler that streams the crate's `Cargo.toml` back as the response body.
    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
        // Open the file.
        let file = File::open("Cargo.toml").await.expect("file missing");
        // Convert the file into a `Stream`.
        let stream = ReaderStream::new(file);
        // Convert the `Stream` into a `Body`.
        let body = Body::from_stream(stream);
        // Create response.
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> {
        use std::io::Read;

        // Drains a synchronous decoder into a byte vector.
        fn decode<R: Read>(mut r: R) -> std::io::Result<Vec<u8>> {
            let mut buf = Vec::new();
            r.read_to_end(&mut buf)?;
            Ok(buf)
        }

        // Read the source file once so we can verify each response round-trips to the same bytes.
        let expected = tokio::fs::read("Cargo.toml").await?;

        // Configure a layer that only offers deflate (every other encoding, including zstd, is
        // disabled), then confirm the response is actually deflate-encoded by decoding it and
        // comparing to the original content.
        let deflate_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_gzip()
            .no_zstd();

        let mut service = ServiceBuilder::new()
            .layer(deflate_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "deflate");

        let deflate_body = response.into_body().collect().await?.to_bytes();

        // The "deflate" Content-Encoding is RFC 1950 zlib framing (2-byte header + Adler-32),
        // not raw RFC 1951 deflate, so use ZlibDecoder rather than DeflateDecoder.
        let decoded = decode(flate2::bufread::ZlibDecoder::new(&deflate_body[..]))?;
        assert_eq!(decoded, expected);

        // Same check for a layer that only offers brotli.
        let br_only_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_gzip()
            .no_deflate()
            .no_zstd();

        let mut service = ServiceBuilder::new()
            .layer(br_only_layer)
            .service_fn(handle);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "gzip, deflate, br")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "br");

        let br_body = response.into_body().collect().await?.to_bytes();

        // 4096 is the decoder's internal read-buffer size, not a content-length bound.
        let decoded = decode(brotli::Decompressor::new(&br_body[..], 4096))?;
        assert_eq!(decoded, expected);

        Ok(())
    }

    /// Test ensuring that zstd compression will not exceed an 8MiB window size; browsers do not
    /// accept responses using 16MiB+ window sizes.
    #[tokio::test]
    async fn zstd_is_web_safe() -> Result<(), crate::BoxError> {
        // zstd may lower its window size if a larger one isn't beneficial and it knows the size
        // of the input; use an 18MiB body so it would want a >=16MiB window (though it might not
        // be able to see the input size here).
        const BODY_LEN: usize = 18 * 1024 * 1024;

        async fn zeroes(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from(vec![0u8; BODY_LEN])))
        }

        let zstd_layer = CompressionLayer::new()
            .quality(CompressionLevel::Best)
            .no_br()
            .no_deflate()
            .no_gzip();

        let mut service = ServiceBuilder::new().layer(zstd_layer).service_fn(zeroes);

        let request = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())?;

        let response = service.ready().await?.call(request).await?;

        assert_eq!(response.headers()["content-encoding"], "zstd");

        let body = response.into_body();
        let bytes = body.collect().await?.to_bytes();
        let mut dec = zstd::Decoder::new(&*bytes)?;
        dec.window_log_max(23)?; // Limit window size accepted by decoder to 2 ^ 23 bytes (8MiB)

        // Decode the whole stream and make sure nothing was lost: a truncated stream would
        // otherwise pass the window-size check silently.
        let decoded_len = std::io::copy(&mut dec, &mut std::io::sink())?;
        assert_eq!(decoded_len, BODY_LEN as u64);

        Ok(())
    }
}