tower_async_http/compression/
mod.rs

1//! Middleware that compresses response bodies.
2//!
3//! # Example
4//!
5//! Example showing how to respond with the compressed contents of a file.
6//!
7//! ```rust
8//! use bytes::{Bytes, BytesMut};
9//! use http::{Request, Response, header::ACCEPT_ENCODING};
10//! use http_body::Frame;
11//! use http_body_util::{Full, BodyExt, StreamBody, combinators::UnsyncBoxBody};
12//! use std::convert::Infallible;
13//! use tokio::fs::{self, File};
14//! use tokio_util::io::ReaderStream;
15//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
16//! use tower_async_http::{compression::CompressionLayer};
17//! use futures_util::TryStreamExt;
18//!
19//! type BoxBody = UnsyncBoxBody<Bytes, std::io::Error>;
20//!
21//! # #[tokio::main]
22//! # async fn main() -> Result<(), BoxError> {
23//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<BoxBody>, Infallible> {
24//!     // Open the file.
25//!     let file = File::open("Cargo.toml").await.expect("file missing");
26//!     // Convert the file into a `Stream` of `Bytes`.
27//!     let stream = ReaderStream::new(file);
28//!     // Convert the stream into a stream of data `Frame`s.
29//!     let stream = stream.map_ok(Frame::data);
30//!     // Convert the `Stream` into a `Body`.
31//!     let body = StreamBody::new(stream);
//!     // Erase the type because it's very hard to name in the function signature.
33//!     let body = body.boxed_unsync();
34//!     // Create response.
35//!     Ok(Response::new(body))
36//! }
37//!
38//! let mut service = ServiceBuilder::new()
39//!     // Compress responses based on the `Accept-Encoding` header.
40//!     .layer(CompressionLayer::new())
41//!     .service_fn(handle);
42//!
43//! // Call the service.
44//! let request = Request::builder()
45//!     .header(ACCEPT_ENCODING, "gzip")
46//!     .body(Full::<Bytes>::default())?;
47//!
48//! let response = service
49//!     .call(request)
50//!     .await?;
51//!
52//! assert_eq!(response.headers()["content-encoding"], "gzip");
53//!
54//! // Read the body
55//! let bytes = response
56//!     .into_body()
57//!     .collect()
58//!     .await?
59//!     .to_bytes();
60//!
61//! // The compressed body should be smaller 🤞
62//! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len();
63//! assert!(bytes.len() < uncompressed_len);
64//! #
65//! # Ok(())
66//! # }
67//! ```
68//!
69
70pub mod predicate;
71
72mod body;
73mod layer;
74mod pin_project_cfg;
75mod service;
76
77#[doc(inline)]
78pub use self::{
79    body::CompressionBody,
80    layer::CompressionLayer,
81    predicate::{DefaultPredicate, Predicate},
82    service::Compression,
83};
84pub use crate::compression_utils::CompressionLevel;
85
#[cfg(test)]
mod tests {
    use super::*;

    use crate::compression::predicate::SizeAbove;
    use crate::test_helpers::{Body, WithTrailers};

    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
    use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;
    use tower_async::{service_fn, Service, ServiceExt};

    /// Predicate that compresses every response, bypassing the default size
    /// and content-type checks so small test bodies still get compressed.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            true
        }
    }

    /// Test handler: responds with `"Hello, World!"` followed by a `foo: bar` trailer.
    async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
        let mut trailers = HeaderMap::new();
        trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
        let body = Body::from("Hello, World!").with_trailers(trailers);
        Ok(Response::builder().body(body).unwrap())
    }

    /// Builds a response whose body is twice `SizeAbove::DEFAULT_MIN_SIZE` (so the
    /// size check alone would allow compression), tagged with `content_type`.
    fn large_response_with_content_type(content_type: &'static str) -> Response<Body> {
        let mut res = Response::new(Body::from(
            "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
        ));
        res.headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(content_type));
        res
    }

    /// Decodes a brotli stream, panicking if it is malformed or truncated
    /// (`shutdown` fails unless the decoder saw the end of the stream).
    async fn brotli_decode(data: &[u8]) -> Vec<u8> {
        let mut output = Vec::new();
        let mut decoder = BrotliDecoder::new(&mut output);
        decoder
            .write_all(data)
            .await
            .expect("couldn't brotli-decode");
        decoder
            .shutdown()
            .await
            .expect("brotli stream was not terminated");
        output
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        // read the compressed body
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        // decompress the body
        // doing this with flate2 as that is much easier than async-compression and blocking during
        // tests is fine
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");

        // trailers are maintained
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = service_fn(handle);
        let svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        // read the compressed body (trailers must be taken before `to_bytes` consumes it)
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        // decompress the body
        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");

        // trailers are maintained, same as for gzip
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        let svc = service_fn(|_| async {
            let buf = {
                let mut buf = Vec::new();

                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                // `shutdown` (unlike `flush`) writes the final brotli block, so the
                // fixture is a complete, self-terminating stream.
                enc.shutdown().await?;
                buf
            };

            let resp = Response::builder()
                .header("content-encoding", "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc);

        // call the service
        //
        // note: the accept-encoding doesn't match the content-encoding above, so that
        // we're able to see if the compression layer triggered or not
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();

        // check we didn't recompress
        assert_eq!(
            res.headers()
                .get("content-encoding")
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        // read the compressed body and decode it; it must be the untouched original
        let data = res.into_body().collect().await.unwrap().to_bytes();
        let data = brotli_decode(&data).await;

        assert_eq!(data, DATA.as_bytes());
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        let svc_fn = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        /// Predicate that compresses every other response, starting with the
        /// second; the call counter is shared between clones.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicU64>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &http::Response<B>) -> bool
            where
                B: http_body::Body,
            {
                // `fetch_add` returns the previous count: 0 -> skip, 1 -> compress, ...
                self.0.fetch_add(1, Ordering::SeqCst) % 2 != 0
            }
        }

        let svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // First response: the predicate declines, so the body passes through untouched.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());

        let data = res.into_body().collect().await.unwrap().to_bytes();
        let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
        assert_eq!(DATA, &still_uncompressed);

        // Second response: the predicate accepts, so the body is brotli-compressed.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert!(String::from_utf8(data.to_vec()).is_err());
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(large_response_with_content_type("image/png"))
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(large_response_with_content_type("image/svg+xml"))
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        let svc = Compression::new(svc).quality(level);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();

        // read the compressed body
        let body = res.into_body();
        let compressed_data = body.collect().await.unwrap().to_bytes();

        // build the compressed body with the same quality level
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures::stream::once(async move {
                Ok::<_, std::io::Error>(DATA.as_bytes())
            }));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data,
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }
}