1pub mod predicate;
73
74mod body;
75mod future;
76mod layer;
77mod pin_project_cfg;
78mod service;
79
80#[doc(inline)]
81pub use self::{
82 body::CompressionBody,
83 future::ResponseFuture,
84 layer::CompressionLayer,
85 predicate::{DefaultPredicate, Predicate},
86 service::Compression,
87};
88pub use crate::compression_utils::CompressionLevel;
89
90#[cfg(test)]
91mod tests {
92 use crate::compression::predicate::SizeAbove;
93
94 use super::*;
95 use crate::test_helpers::{Body, WithTrailers};
96 use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
97 use bytes::Bytes;
98 use flate2::read::GzDecoder;
99 use http::header::{
100 ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
101 };
102 use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
103 use http_body::Body as _;
104 use http_body_util::BodyExt;
105 use std::convert::Infallible;
106 use std::io::Read;
107 use std::sync::{Arc, RwLock};
108 use tokio::io::{AsyncReadExt, AsyncWriteExt};
109 use tokio_util::io::StreamReader;
110 use tower::{service_fn, BoxError, Service, ServiceExt};
111
112 #[derive(Clone)]
114 struct Always;
115
116 impl Predicate for Always {
117 fn should_compress<B>(&self, _: &http::Response<B>) -> bool
118 where
119 B: http_body::Body,
120 {
121 true
122 }
123 }
124
125 #[tokio::test]
126 async fn gzip_works() {
127 let svc = service_fn(handle);
128 let mut svc = Compression::new(svc).compress_when(Always);
129
130 let req = Request::builder()
132 .header("accept-encoding", "gzip")
133 .body(Body::empty())
134 .unwrap();
135 let res = svc.ready().await.unwrap().call(req).await.unwrap();
136
137 let collected = res.into_body().collect().await.unwrap();
139 let trailers = collected.trailers().cloned().unwrap();
140 let compressed_data = collected.to_bytes();
141
142 let mut decoder = GzDecoder::new(&compressed_data[..]);
146 let mut decompressed = String::new();
147 decoder.read_to_string(&mut decompressed).unwrap();
148
149 assert_eq!(decompressed, "Hello, World!");
150
151 assert_eq!(trailers["foo"], "bar");
153 }
154
155 #[tokio::test]
156 async fn x_gzip_works() {
157 let svc = service_fn(handle);
158 let mut svc = Compression::new(svc).compress_when(Always);
159
160 let req = Request::builder()
162 .header("accept-encoding", "x-gzip")
163 .body(Body::empty())
164 .unwrap();
165 let res = svc.ready().await.unwrap().call(req).await.unwrap();
166
167 assert_eq!(
170 res.headers()
171 .get_all("content-encoding")
172 .iter()
173 .collect::<Vec<&HeaderValue>>(),
174 vec!(HeaderValue::from_static("gzip"))
175 );
176
177 let collected = res.into_body().collect().await.unwrap();
179 let trailers = collected.trailers().cloned().unwrap();
180 let compressed_data = collected.to_bytes();
181
182 let mut decoder = GzDecoder::new(&compressed_data[..]);
186 let mut decompressed = String::new();
187 decoder.read_to_string(&mut decompressed).unwrap();
188
189 assert_eq!(decompressed, "Hello, World!");
190
191 assert_eq!(trailers["foo"], "bar");
193 }
194
195 #[tokio::test]
196 async fn zstd_works() {
197 let svc = service_fn(handle);
198 let mut svc = Compression::new(svc).compress_when(Always);
199
200 let req = Request::builder()
202 .header("accept-encoding", "zstd")
203 .body(Body::empty())
204 .unwrap();
205 let res = svc.ready().await.unwrap().call(req).await.unwrap();
206
207 let body = res.into_body();
209 let compressed_data = body.collect().await.unwrap().to_bytes();
210
211 let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
213 let decompressed = String::from_utf8(decompressed).unwrap();
214
215 assert_eq!(decompressed, "Hello, World!");
216 }
217
218 #[tokio::test]
219 async fn no_recompress() {
220 const DATA: &str = "Hello, World! I'm already compressed with br!";
221
222 let svc = service_fn(|_| async {
223 let buf = {
224 let mut buf = Vec::new();
225
226 let mut enc = BrotliEncoder::new(&mut buf);
227 enc.write_all(DATA.as_bytes()).await?;
228 enc.flush().await?;
229 buf
230 };
231
232 let resp = Response::builder()
233 .header("content-encoding", "br")
234 .body(Body::from(buf))
235 .unwrap();
236 Ok::<_, std::io::Error>(resp)
237 });
238 let mut svc = Compression::new(svc);
239
240 let req = Request::builder()
245 .header("accept-encoding", "gzip")
246 .body(Body::empty())
247 .unwrap();
248 let res = svc.ready().await.unwrap().call(req).await.unwrap();
249
250 assert_eq!(
252 res.headers()
253 .get("content-encoding")
254 .and_then(|h| h.to_str().ok())
255 .unwrap_or_default(),
256 "br",
257 );
258
259 let body = res.into_body();
261 let data = body.collect().await.unwrap().to_bytes();
262
263 let data = {
265 let mut output_buf = Vec::new();
266 let mut decoder = BrotliDecoder::new(&mut output_buf);
267 decoder
268 .write_all(&data)
269 .await
270 .expect("couldn't brotli-decode");
271 decoder.flush().await.expect("couldn't flush");
272 output_buf
273 };
274
275 assert_eq!(data, DATA.as_bytes());
276 }
277
278 async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
279 let mut trailers = HeaderMap::new();
280 trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
281 let body = Body::from("Hello, World!").with_trailers(trailers);
282 Ok(Response::builder().body(body).unwrap())
283 }
284
285 #[tokio::test]
286 async fn will_not_compress_if_filtered_out() {
287 use predicate::Predicate;
288
289 const DATA: &str = "Hello world uncompressed";
290
291 let svc_fn = service_fn(|_| async {
292 let resp = Response::builder()
293 .body(Body::from(DATA.as_bytes()))
295 .unwrap();
296 Ok::<_, std::io::Error>(resp)
297 });
298
299 #[derive(Default, Clone)]
301 struct EveryOtherResponse(Arc<RwLock<u64>>);
302
303 #[allow(clippy::dbg_macro)]
304 impl Predicate for EveryOtherResponse {
305 fn should_compress<B>(&self, _: &http::Response<B>) -> bool
306 where
307 B: http_body::Body,
308 {
309 let mut guard = self.0.write().unwrap();
310 let should_compress = *guard % 2 != 0;
311 *guard += 1;
312 should_compress
313 }
314 }
315
316 let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());
317 let req = Request::builder()
318 .header("accept-encoding", "br")
319 .body(Body::empty())
320 .unwrap();
321 let res = svc.ready().await.unwrap().call(req).await.unwrap();
322
323 let body = res.into_body();
325 let data = body.collect().await.unwrap().to_bytes();
326 let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
327 assert_eq!(DATA, &still_uncompressed);
328
329 let req = Request::builder()
331 .header("accept-encoding", "br")
332 .body(Body::empty())
333 .unwrap();
334 let res = svc.ready().await.unwrap().call(req).await.unwrap();
335
336 let body = res.into_body();
338 let data = body.collect().await.unwrap().to_bytes();
339 assert!(String::from_utf8(data.to_vec()).is_err());
340 }
341
342 #[tokio::test]
343 async fn doesnt_compress_images() {
344 async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
345 let mut res = Response::new(Body::from(
346 "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
347 ));
348 res.headers_mut()
349 .insert(CONTENT_TYPE, "image/png".parse().unwrap());
350 Ok(res)
351 }
352
353 let svc = Compression::new(service_fn(handle));
354
355 let res = svc
356 .oneshot(
357 Request::builder()
358 .header(ACCEPT_ENCODING, "gzip")
359 .body(Body::empty())
360 .unwrap(),
361 )
362 .await
363 .unwrap();
364 assert!(res.headers().get(CONTENT_ENCODING).is_none());
365 }
366
367 #[tokio::test]
368 async fn does_compress_svg() {
369 async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
370 let mut res = Response::new(Body::from(
371 "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
372 ));
373 res.headers_mut()
374 .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
375 Ok(res)
376 }
377
378 let svc = Compression::new(service_fn(handle));
379
380 let res = svc
381 .oneshot(
382 Request::builder()
383 .header(ACCEPT_ENCODING, "gzip")
384 .body(Body::empty())
385 .unwrap(),
386 )
387 .await
388 .unwrap();
389 assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
390 }
391
392 #[tokio::test]
393 async fn does_compress_grpc_web() {
394 async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
395 let mut res = Response::new(Body::from(
396 "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
397 ));
398 res.headers_mut()
399 .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap());
400 Ok(res)
401 }
402
403 let svc = Compression::new(service_fn(handle));
404
405 let res = svc
406 .oneshot(
407 Request::builder()
408 .header(ACCEPT_ENCODING, "gzip")
409 .body(Body::empty())
410 .unwrap(),
411 )
412 .await
413 .unwrap();
414 assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
415 }
416
417 #[tokio::test]
418 async fn compress_with_quality() {
419 const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
420 let level = CompressionLevel::Best;
421
422 let svc = service_fn(|_| async {
423 let resp = Response::builder()
424 .body(Body::from(DATA.as_bytes()))
425 .unwrap();
426 Ok::<_, std::io::Error>(resp)
427 });
428
429 let mut svc = Compression::new(svc).quality(level);
430
431 let req = Request::builder()
433 .header("accept-encoding", "br")
434 .body(Body::empty())
435 .unwrap();
436 let res = svc.ready().await.unwrap().call(req).await.unwrap();
437
438 let body = res.into_body();
440 let compressed_data = body.collect().await.unwrap().to_bytes();
441
442 let compressed_with_level = {
444 use async_compression::tokio::bufread::BrotliEncoder;
445
446 let stream = Box::pin(futures_util::stream::once(async move {
447 Ok::<_, std::io::Error>(DATA.as_bytes())
448 }));
449 let reader = StreamReader::new(stream);
450 let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());
451
452 let mut buf = Vec::new();
453 enc.read_to_end(&mut buf).await.unwrap();
454 buf
455 };
456
457 assert_eq!(
458 compressed_data,
459 compressed_with_level.as_slice(),
460 "Compression level is not respected"
461 );
462 }
463
464 #[tokio::test]
465 async fn should_not_compress_ranges() {
466 let svc = service_fn(|_| async {
467 let mut res = Response::new(Body::from("Hello"));
468 let headers = res.headers_mut();
469 headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap());
470 headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
471 Ok::<_, std::io::Error>(res)
472 });
473 let mut svc = Compression::new(svc).compress_when(Always);
474
475 let req = Request::builder()
477 .header(ACCEPT_ENCODING, "gzip")
478 .header(RANGE, "bytes=0-4")
479 .body(Body::empty())
480 .unwrap();
481 let res = svc.ready().await.unwrap().call(req).await.unwrap();
482 let headers = res.headers().clone();
483
484 let collected = res.into_body().collect().await.unwrap().to_bytes();
486
487 assert_eq!(headers[ACCEPT_RANGES], "bytes");
488 assert!(!headers.contains_key(CONTENT_ENCODING));
489 assert_eq!(collected, "Hello");
490 }
491
492 #[tokio::test]
493 async fn should_strip_accept_ranges_header_when_compressing() {
494 let svc = service_fn(|_| async {
495 let mut res = Response::new(Body::from("Hello, World!"));
496 res.headers_mut()
497 .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
498 Ok::<_, std::io::Error>(res)
499 });
500 let mut svc = Compression::new(svc).compress_when(Always);
501
502 let req = Request::builder()
504 .header(ACCEPT_ENCODING, "gzip")
505 .body(Body::empty())
506 .unwrap();
507 let res = svc.ready().await.unwrap().call(req).await.unwrap();
508 let headers = res.headers().clone();
509
510 let collected = res.into_body().collect().await.unwrap();
512 let compressed_data = collected.to_bytes();
513
514 let mut decoder = GzDecoder::new(&compressed_data[..]);
518 let mut decompressed = String::new();
519 decoder.read_to_string(&mut decompressed).unwrap();
520
521 assert!(!headers.contains_key(ACCEPT_RANGES));
522 assert_eq!(headers[CONTENT_ENCODING], "gzip");
523 assert_eq!(decompressed, "Hello, World!");
524 }
525
526 #[tokio::test]
527 async fn trailers_with_empty_body() {
528 let svc = service_fn(|_req: Request<Body>| async {
529 let mut trailers = HeaderMap::new();
530 trailers.insert("grpc-status", "0".parse().unwrap());
531 trailers.insert("grpc-message", "OK".parse().unwrap());
532 let body = Body::empty().with_trailers(trailers);
533 Ok::<_, Infallible>(Response::builder().body(body).unwrap())
534 });
535 let mut svc = Compression::new(svc).compress_when(Always);
536
537 let req = Request::builder()
538 .header("accept-encoding", "gzip")
539 .body(Body::empty())
540 .unwrap();
541 let res = svc.ready().await.unwrap().call(req).await.unwrap();
542
543 let collected = res.into_body().collect().await.unwrap();
544 let trailers = collected.trailers().cloned().unwrap();
545 assert_eq!(trailers["grpc-status"], "0");
546 assert_eq!(trailers["grpc-message"], "OK");
547 }
548
549 #[tokio::test]
550 async fn trailers_with_streamed_body() {
551 let svc = service_fn(|_req: Request<Body>| async {
553 let stream = futures_util::stream::iter(vec![
554 Ok::<_, BoxError>(Bytes::from("chunk1")),
555 Ok(Bytes::from("chunk2")),
556 Ok(Bytes::from("chunk3")),
557 ]);
558 let mut trailers = HeaderMap::new();
559 trailers.insert("grpc-status", "0".parse().unwrap());
560 let body = Body::from_stream(stream).with_trailers(trailers);
561 Ok::<_, Infallible>(Response::builder().body(body).unwrap())
562 });
563 let mut svc = Compression::new(svc).compress_when(Always);
564
565 let req = Request::builder()
566 .header("accept-encoding", "gzip")
567 .body(Body::empty())
568 .unwrap();
569 let res = svc.ready().await.unwrap().call(req).await.unwrap();
570
571 let collected = res.into_body().collect().await.unwrap();
572 let trailers = collected.trailers().cloned().unwrap();
573 let compressed_data = collected.to_bytes();
574
575 let mut decoder = GzDecoder::new(&compressed_data[..]);
576 let mut decompressed = String::new();
577 decoder.read_to_string(&mut decompressed).unwrap();
578
579 assert_eq!(decompressed, "chunk1chunk2chunk3");
580 assert_eq!(trailers["grpc-status"], "0");
581 }
582
583 #[tokio::test]
584 async fn trailers_with_grpc_web_content_type() {
585 let svc = service_fn(|_req: Request<Body>| async {
586 let mut trailers = HeaderMap::new();
587 trailers.insert("grpc-status", "0".parse().unwrap());
588 let body = Body::from("a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize))
589 .with_trailers(trailers);
590 let mut res = Response::new(body);
591 res.headers_mut()
592 .insert(CONTENT_TYPE, "application/grpc-web+proto".parse().unwrap());
593 Ok::<_, Infallible>(res)
594 });
595 let mut svc = Compression::new(svc).compress_when(Always);
596
597 let req = Request::builder()
598 .header("accept-encoding", "gzip")
599 .body(Body::empty())
600 .unwrap();
601 let res = svc.ready().await.unwrap().call(req).await.unwrap();
602
603 let collected = res.into_body().collect().await.unwrap();
604 let trailers = collected.trailers().cloned().unwrap();
605 assert_eq!(trailers["grpc-status"], "0");
606 }
607
608 #[tokio::test]
609 async fn size_hint_identity() {
610 let msg = "Hello, world!";
611 let svc = service_fn(|_| async { Ok::<_, std::io::Error>(Response::new(Body::from(msg))) });
612 let mut svc = Compression::new(svc);
613
614 let req = Request::new(Body::empty());
615 let res = svc.ready().await.unwrap().call(req).await.unwrap();
616 let body = res.into_body();
617 assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64);
618 }
619
620 #[tokio::test]
621 async fn wildcard_q_zero_returns_406() {
622 let svc = service_fn(handle);
623 let mut svc = Compression::new(svc).compress_when(Always);
624
625 let req = Request::builder()
626 .header("accept-encoding", "*;q=0")
627 .body(Body::empty())
628 .unwrap();
629 let res = svc.ready().await.unwrap().call(req).await.unwrap();
630
631 assert_eq!(res.status(), http::StatusCode::NOT_ACCEPTABLE);
632 assert!(res
633 .headers()
634 .get_all(http::header::VARY)
635 .iter()
636 .any(|v| v.to_str().unwrap().contains("accept-encoding")));
637 }
638
639 #[tokio::test]
640 async fn wildcard_q_zero_with_gzip_picks_gzip() {
641 let svc = service_fn(handle);
642 let mut svc = Compression::new(svc).compress_when(Always);
643
644 let req = Request::builder()
645 .header("accept-encoding", "*;q=0,gzip")
646 .body(Body::empty())
647 .unwrap();
648 let res = svc.ready().await.unwrap().call(req).await.unwrap();
649
650 assert_eq!(res.status(), http::StatusCode::OK);
651 assert_eq!(
652 res.headers()
653 .get("content-encoding")
654 .and_then(|v| v.to_str().ok()),
655 Some("gzip")
656 );
657 }
658
659 #[tokio::test]
660 async fn wildcard_alone_compresses() {
661 let svc = service_fn(handle);
662 let mut svc = Compression::new(svc).compress_when(Always);
663
664 let req = Request::builder()
665 .header("accept-encoding", "*")
666 .body(Body::empty())
667 .unwrap();
668 let res = svc.ready().await.unwrap().call(req).await.unwrap();
669
670 assert_eq!(res.status(), http::StatusCode::OK);
671 assert!(res.headers().contains_key(CONTENT_ENCODING));
673 }
674
675 #[tokio::test]
676 async fn identity_q_zero_alone_returns_406() {
677 let svc = service_fn(handle);
678 let mut svc = Compression::new(svc).compress_when(Always);
679
680 let req = Request::builder()
681 .header("accept-encoding", "identity;q=0")
682 .body(Body::empty())
683 .unwrap();
684 let res = svc.ready().await.unwrap().call(req).await.unwrap();
685
686 assert_eq!(res.status(), http::StatusCode::NOT_ACCEPTABLE);
687 }
688
689 #[tokio::test]
690 async fn identity_q_zero_with_gzip_picks_gzip() {
691 let svc = service_fn(handle);
692 let mut svc = Compression::new(svc).compress_when(Always);
693
694 let req = Request::builder()
695 .header("accept-encoding", "identity;q=0,gzip")
696 .body(Body::empty())
697 .unwrap();
698 let res = svc.ready().await.unwrap().call(req).await.unwrap();
699
700 assert_eq!(res.status(), http::StatusCode::OK);
701 assert_eq!(
702 res.headers()
703 .get("content-encoding")
704 .and_then(|v| v.to_str().ok()),
705 Some("gzip")
706 );
707 }
708}