1#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24 Body, Method, Request, Response, StatusCode,
25 header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::Mime;
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34 Error, Result, error_page,
35 handler::RequestHandlerOpts,
36 headers_ext::{AcceptEncoding, ContentCoding},
37 http_ext::MethodExt,
38 mime_ext::MimeExt,
39 settings::CompressionLevel,
40};
41
/// Smallest declared `Content-Length` (in bytes) that `auto` will compress.
///
/// Responses whose `Content-Length` header parses to a value below this are sent as-is;
/// responses without a `Content-Length` are not affected by this threshold.
const MIN_COMPRESS_SIZE: usize = 200;
44
45const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
47 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
48 ContentCoding::DEFLATE,
49 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
50 ContentCoding::GZIP,
51 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
52 ContentCoding::BROTLI,
53 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
54 ContentCoding::ZSTD,
55];
56
57pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
59 handler_opts.compression = enabled;
60 handler_opts.compression_level = level;
61
62 const FORMATS: &[&str] = &[
63 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
64 "deflate",
65 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
66 "gzip",
67 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
68 "brotli",
69 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
70 "zstd",
71 ];
72 tracing::info!(
73 "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
74 FORMATS.join(",")
75 );
76}
77
78pub(crate) fn post_process<T>(
80 opts: &RequestHandlerOpts,
81 req: &Request<T>,
82 mut resp: Response<Body>,
83) -> Result<Response<Body>, Error> {
84 if !opts.compression {
85 return Ok(resp);
86 }
87
88 let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
89 if is_precompressed {
90 return Ok(resp);
91 }
92
93 let enc = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
95 let value = resp.headers().get(hyper::header::VARY).map_or(enc, |h| {
96 let mut a = h.to_str().unwrap_or_default().to_owned();
97 let b = hyper::header::ACCEPT_ENCODING.as_str();
98 if !a.contains(b) {
99 if !a.is_empty() {
100 a.push(',');
101 }
102 a.push_str(b);
103 }
104 HeaderValue::from_str(a.as_str()).unwrap()
105 });
106
107 resp.headers_mut().insert(hyper::header::VARY, value);
108
109 match auto(req.method(), req.headers(), opts.compression_level, resp) {
111 Ok(resp) => Ok(resp),
112 Err(err) => {
113 tracing::error!("error during body compression: {:?}", err);
114 error_page::error_response(
115 req.uri(),
116 req.method(),
117 &StatusCode::INTERNAL_SERVER_ERROR,
118 &opts.page404,
119 &opts.page50x,
120 )
121 }
122 }
123}
124
125pub fn auto(
130 method: &Method,
131 headers: &HeaderMap<HeaderValue>,
132 level: CompressionLevel,
133 resp: Response<Body>,
134) -> Result<Response<Body>> {
135 if method.is_head() || method.is_options() {
137 return Ok(resp);
138 }
139
140 if let Some(encoding) = get_preferred_encoding(headers) {
142 tracing::trace!(
143 "preferred encoding selected from the accept-encoding header: {:?}",
144 encoding
145 );
146
147 if let Some(content_type) = resp.headers().typed_get::<ContentType>()
149 && !Mime::from(content_type).is_compressible()
150 {
151 return Ok(resp);
152 }
153
154 if let Some(content_length) = resp
158 .headers()
159 .get(CONTENT_LENGTH)
160 .and_then(|v| v.to_str().ok())
161 .and_then(|v| v.parse::<usize>().ok())
162 && content_length < MIN_COMPRESS_SIZE
163 {
164 tracing::trace!(
165 "skipping compression: content-length ({content_length}) below minimum ({MIN_COMPRESS_SIZE})",
166 );
167 return Ok(resp);
168 }
169
170 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
171 if encoding == ContentCoding::GZIP {
172 let (head, body) = resp.into_parts();
173 return Ok(gzip(head, body.into(), level));
174 }
175
176 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
177 if encoding == ContentCoding::DEFLATE {
178 let (head, body) = resp.into_parts();
179 return Ok(deflate(head, body.into(), level));
180 }
181
182 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
183 if encoding == ContentCoding::BROTLI {
184 let (head, body) = resp.into_parts();
185 return Ok(brotli(head, body.into(), level));
186 }
187
188 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
189 if encoding == ContentCoding::ZSTD {
190 let (head, body) = resp.into_parts();
191 return Ok(zstd(head, body.into(), level));
192 }
193
194 tracing::trace!(
195 "no compression feature matched the preferred encoding, probably not enabled or unsupported"
196 );
197 }
198
199 Ok(resp)
200}
201
/// Re-encodes the response body with GZIP while it is streamed.
///
/// `Content-Length` is removed, because the compressed size is not known up front, and
/// `gzip` is appended to any existing `Content-Encoding` via `create_encoding_header`.
/// `level` is resolved through `CompressionLevel::into_algorithm_level` with 4 as the
/// fallback (presumably used for the default level — confirm in `settings`).
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::GZIP),
    );

    Response::from_parts(head, compressed)
}
228
/// Re-encodes the response body with DEFLATE while it is streamed.
///
/// `Content-Length` is removed, because the compressed size is not known up front, and
/// `deflate` is appended to any existing `Content-Encoding` via `create_encoding_header`.
/// `level` is resolved through `CompressionLevel::into_algorithm_level` with 4 as the
/// fallback (presumably used for the default level — confirm in `settings`).
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::DEFLATE),
    );

    Response::from_parts(head, compressed)
}
258
/// Re-encodes the response body with Brotli while it is streamed.
///
/// `Content-Length` is removed, because the compressed size is not known up front, and
/// the Brotli coding is appended to any existing `Content-Encoding` via
/// `create_encoding_header`. `level` is resolved through
/// `CompressionLevel::into_algorithm_level` with 4 as the fallback (presumably used for the
/// default level — confirm in `settings`).
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::BROTLI),
    );

    Response::from_parts(head, compressed)
}
286
/// Re-encodes the response body with Zstandard while it is streamed.
///
/// `Content-Length` is removed, because the compressed size is not known up front, and
/// `zstd` is appended to any existing `Content-Encoding` via `create_encoding_header`.
/// `level` is resolved through `CompressionLevel::into_algorithm_level` with 3 as the
/// fallback (presumably used for the default level — confirm in `settings`).
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous_encoding, ContentCoding::ZSTD),
    );

    Response::from_parts(head, compressed)
}
313
314pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
316 if let Some(val) = existing
317 && let Ok(str_val) = val.to_str()
318 {
319 return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
320 .unwrap_or_else(|_| coding.into());
321 }
322 coding.into()
323}
324
325#[inline(always)]
327pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
328 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
329 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
330
331 for encoding in accept_encoding.sorted_encodings() {
332 if AVAILABLE_ENCODINGS.contains(&encoding) {
333 return Some(encoding);
334 }
335 }
336 }
337 None
338}
339
340#[inline(always)]
342pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
343 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
344 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
345
346 return accept_encoding
347 .sorted_encodings()
348 .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
349 .collect::<Vec<_>>();
350 }
351 vec![]
352}
353
354#[pin_project]
357#[derive(Debug)]
358pub struct CompressableBody<S, E>
359where
360 S: Stream<Item = Result<Bytes, E>>,
361 E: std::error::Error,
362{
363 #[pin]
364 body: S,
365}
366
367impl<S, E> Stream for CompressableBody<S, E>
368where
369 S: Stream<Item = Result<Bytes, E>>,
370 E: std::error::Error,
371{
372 type Item = std::io::Result<Bytes>;
373
374 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
375 use std::io::{Error, ErrorKind};
376
377 let pin = self.project();
378 S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
379 }
380}
381
382impl From<Body> for CompressableBody<Body, hyper::Error> {
383 #[inline(always)]
384 fn from(body: Body) -> Self {
385 CompressableBody { body }
386 }
387}
388
#[cfg(test)]
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_TYPE};

    /// Builds a response with `body` and `content_type`, declaring `Content-Length` only
    /// when `content_length` is given.
    fn build_response(
        body: Body,
        content_type: &str,
        content_length: Option<usize>,
    ) -> Response<Body> {
        let mut resp = Response::new(body);
        let headers = resp.headers_mut();
        headers.insert(CONTENT_TYPE, content_type.parse().unwrap());
        if let Some(len) = content_length {
            headers.insert(CONTENT_LENGTH, len.to_string().parse().unwrap());
        }
        resp
    }

    fn text_response_with_size(size: usize) -> Response<Body> {
        build_response(Body::from(vec![b'x'; size]), "text/html", Some(size))
    }

    fn text_response_without_length() -> Response<Body> {
        build_response(Body::from("hello world"), "text/html", None)
    }

    /// Runs `auto` for `method` with `Accept-Encoding: gzip` at the default level.
    fn auto_with_gzip(method: Method, resp: Response<Body>) -> Response<Body> {
        let mut request_headers = HeaderMap::new();
        request_headers.insert(ACCEPT_ENCODING, "gzip".parse().unwrap());
        auto(&method, &request_headers, CompressionLevel::Default, resp).unwrap()
    }

    fn is_encoded(resp: &Response<Body>) -> bool {
        resp.headers().get(CONTENT_ENCODING).is_some()
    }

    #[test]
    fn small_response_below_threshold_is_not_compressed() {
        let out = auto_with_gzip(Method::GET, text_response_with_size(MIN_COMPRESS_SIZE - 1));
        assert!(
            !is_encoded(&out),
            "responses below {MIN_COMPRESS_SIZE} bytes must not be compressed"
        );
    }

    #[test]
    fn response_at_threshold_is_compressed() {
        let out = auto_with_gzip(Method::GET, text_response_with_size(MIN_COMPRESS_SIZE));
        assert!(
            is_encoded(&out),
            "responses at exactly {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_above_threshold_is_compressed() {
        let out = auto_with_gzip(Method::GET, text_response_with_size(MIN_COMPRESS_SIZE + 1));
        assert!(
            is_encoded(&out),
            "responses above {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_without_content_length_is_compressed() {
        let out = auto_with_gzip(Method::GET, text_response_without_length());
        assert!(
            is_encoded(&out),
            "responses without Content-Length must still be compressed"
        );
    }

    #[test]
    fn small_response_head_method_is_not_compressed() {
        let out = auto_with_gzip(Method::HEAD, text_response_with_size(MIN_COMPRESS_SIZE - 1));
        assert!(
            !is_encoded(&out),
            "HEAD requests are never compressed regardless of size"
        );
    }

    #[test]
    fn non_compressible_content_type_is_not_compressed() {
        let size = MIN_COMPRESS_SIZE + 100;
        let png = build_response(Body::from(vec![b'x'; size]), "image/png", Some(size));
        let out = auto_with_gzip(Method::GET, png);
        assert!(
            !is_encoded(&out),
            "non-compressible content-types are never compressed"
        );
    }
}