tower_async_http/compression/
mod.rs1pub mod predicate;
71
72mod body;
73mod layer;
74mod pin_project_cfg;
75mod service;
76
77#[doc(inline)]
78pub use self::{
79 body::CompressionBody,
80 layer::CompressionLayer,
81 predicate::{DefaultPredicate, Predicate},
82 service::Compression,
83};
84pub use crate::compression_utils::CompressionLevel;
85
#[cfg(test)]
mod tests {
    use super::*;

    use crate::compression::predicate::SizeAbove;
    use crate::test_helpers::{Body, WithTrailers};

    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
    use http::{HeaderMap, HeaderName, Request, Response};
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;
    use tower_async::{service_fn, Service, ServiceExt};

    /// Predicate that accepts every response, so tests can compress the tiny
    /// `"Hello, World!"` body produced by [`handle`].
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            true
        }
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = Compression::new(service_fn(handle)).compress_when(Always);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        // Collect the whole body so the trailers can be inspected alongside the data.
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        // Decoding synchronously with flate2 is simpler than async-compression, and blocking
        // is acceptable inside a test.
        let mut decoder = GzDecoder::new(&compressed_data[..]);
        let mut decompressed = String::new();
        decoder.read_to_string(&mut decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");

        // Trailers from the inner body must survive compression.
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = Compression::new(service_fn(handle)).compress_when(Always);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        // Inner service that returns a body it has already brotli-encoded itself.
        let svc = service_fn(|_| async {
            let mut buf = Vec::new();
            {
                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                // `shutdown` finishes the encoder, so the fixture is a complete brotli stream;
                // `flush` alone only emits the data written so far, without the stream end.
                enc.shutdown().await?;
            }

            let resp = Response::builder()
                .header(CONTENT_ENCODING, "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc);

        // The client prefers gzip, but the body is already encoded and must be left alone.
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();

        assert_eq!(
            res.headers()
                .get(CONTENT_ENCODING)
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        let data = res.into_body().collect().await.unwrap().to_bytes();

        let data = {
            let mut output_buf = Vec::new();
            let mut decoder = BrotliDecoder::new(&mut output_buf);
            decoder
                .write_all(&data)
                .await
                .expect("couldn't brotli-decode");
            // `shutdown` also verifies that the brotli stream was complete.
            decoder
                .shutdown()
                .await
                .expect("couldn't finish brotli-decoding");
            output_buf
        };

        assert_eq!(data, DATA.as_bytes());
    }

    /// Inner service shared by several tests: replies with `"Hello, World!"` and a
    /// `foo: bar` trailer.
    async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
        let mut trailers = HeaderMap::new();
        trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
        let body = Body::from("Hello, World!").with_trailers(trailers);
        Ok(Response::new(body))
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        /// Predicate that declines the first response, accepts the second, and keeps
        /// alternating from there.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicUsize>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &http::Response<B>) -> bool
            where
                B: http_body::Body,
            {
                // `fetch_add` returns the count before incrementing: even counts (0, 2, ...)
                // skip compression, odd counts compress.
                self.0.fetch_add(1, Ordering::SeqCst) % 2 != 0
            }
        }

        let svc_fn = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // First response: filtered out, so it arrives as plain text with no encoding header.
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());

        let data = res.into_body().collect().await.unwrap().to_bytes();
        let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
        assert_eq!(DATA, &still_uncompressed);

        // Second response: accepted, so it is brotli-encoded.
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert!(String::from_utf8(data.to_vec()).is_err());
    }

    /// Sends a response with the given `Content-Type` through a default-configured
    /// [`Compression`] for a client that accepts gzip. Returns the resulting
    /// `Content-Encoding` header, if any.
    ///
    /// The body is twice [`SizeAbove::DEFAULT_MIN_SIZE`] long so that its size is not what
    /// prevents compression.
    async fn default_encoding_for(content_type: &'static str) -> Option<http::HeaderValue> {
        let svc = Compression::new(service_fn(move |_req: Request<Body>| async move {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, content_type.parse().unwrap());
            Ok::<_, Infallible>(res)
        }));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        res.headers().get(CONTENT_ENCODING).cloned()
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        assert!(default_encoding_for("image/png").await.is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        assert_eq!(
            default_encoding_for("image/svg+xml").await.unwrap(),
            "gzip"
        );
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let svc = Compression::new(svc).quality(level);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.call(req).await.unwrap();
        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        // Encode the same payload directly with async-compression at the same level. The
        // middleware's output is expected to match it byte for byte.
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures::stream::once(async move {
                Ok::<_, std::io::Error>(DATA.as_bytes())
            }));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data,
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }
}