Skip to main content

static_web_server/
compression.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
6//! Auto-compression module to compress responses body.
7//!
8
9// Part of the file is borrowed from <https://github.com/seanmonstar/warp/pull/513>*
10
11#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24    Body, Method, Request, Response, StatusCode,
25    header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::Mime;
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34    Error, Result, error_page,
35    handler::RequestHandlerOpts,
36    headers_ext::{AcceptEncoding, ContentCoding},
37    http_ext::MethodExt,
38    mime_ext::MimeExt,
39    settings::CompressionLevel,
40};
41
/// Minimum response body size in bytes below which dynamic compression is skipped.
///
/// Only enforced when the response declares a parseable `Content-Length`;
/// bodies of unknown length are compressed regardless of their size.
const MIN_COMPRESS_SIZE: usize = 200;

/// List of encodings that can be handled given enabled features.
///
/// Codings parsed from a request's `Accept-Encoding` header are filtered against
/// this list with `contains`, so the order of the entries here does not affect
/// which coding ends up being preferred.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
56
57/// Initializes dynamic compression.
58pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
59    handler_opts.compression = enabled;
60    handler_opts.compression_level = level;
61
62    const FORMATS: &[&str] = &[
63        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
64        "deflate",
65        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
66        "gzip",
67        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
68        "brotli",
69        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
70        "zstd",
71    ];
72    tracing::info!(
73        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
74        FORMATS.join(",")
75    );
76}
77
78/// Post-processing to dynamically compress the response if necessary.
79pub(crate) fn post_process<T>(
80    opts: &RequestHandlerOpts,
81    req: &Request<T>,
82    mut resp: Response<Body>,
83) -> Result<Response<Body>, Error> {
84    if !opts.compression {
85        return Ok(resp);
86    }
87
88    let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
89    if is_precompressed {
90        return Ok(resp);
91    }
92
93    // Compression content encoding varies so use a `Vary` header
94    let enc = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
95    let value = resp.headers().get(hyper::header::VARY).map_or(enc, |h| {
96        let mut a = h.to_str().unwrap_or_default().to_owned();
97        let b = hyper::header::ACCEPT_ENCODING.as_str();
98        if !a.contains(b) {
99            if !a.is_empty() {
100                a.push(',');
101            }
102            a.push_str(b);
103        }
104        HeaderValue::from_str(a.as_str()).unwrap()
105    });
106
107    resp.headers_mut().insert(hyper::header::VARY, value);
108
109    // Auto compression based on the `Accept-Encoding` header
110    match auto(req.method(), req.headers(), opts.compression_level, resp) {
111        Ok(resp) => Ok(resp),
112        Err(err) => {
113            tracing::error!("error during body compression: {:?}", err);
114            error_page::error_response(
115                req.uri(),
116                req.method(),
117                &StatusCode::INTERNAL_SERVER_ERROR,
118                &opts.page404,
119                &opts.page50x,
120            )
121        }
122    }
123}
124
125/// Create a wrapping handler that compresses the Body of a [`hyper::Response`]
126/// using gzip, `deflate`, `brotli` or `zstd` if is specified in the `Accept-Encoding` header, adding
127/// `content-encoding: <coding>` to the Response's [`HeaderMap`].
128/// It also provides the ability to apply compression for text-based MIME types only.
129pub fn auto(
130    method: &Method,
131    headers: &HeaderMap<HeaderValue>,
132    level: CompressionLevel,
133    resp: Response<Body>,
134) -> Result<Response<Body>> {
135    // Skip compression for HEAD and OPTIONS request methods
136    if method.is_head() || method.is_options() {
137        return Ok(resp);
138    }
139
140    // Compress response based on Accept-Encoding header
141    if let Some(encoding) = get_preferred_encoding(headers) {
142        tracing::trace!(
143            "preferred encoding selected from the accept-encoding header: {:?}",
144            encoding
145        );
146
147        // Skip compression for non-text-based MIME types
148        if let Some(content_type) = resp.headers().typed_get::<ContentType>()
149            && !Mime::from(content_type).is_compressible()
150        {
151            return Ok(resp);
152        }
153
154        // Skip compression for responses below the minimum size threshold.
155        // Tiny payloads gain no benefit and the compression overhead can
156        // make them larger than the original.
157        if let Some(content_length) = resp
158            .headers()
159            .get(CONTENT_LENGTH)
160            .and_then(|v| v.to_str().ok())
161            .and_then(|v| v.parse::<usize>().ok())
162            && content_length < MIN_COMPRESS_SIZE
163        {
164            tracing::trace!(
165                "skipping compression: content-length ({content_length}) below minimum ({MIN_COMPRESS_SIZE})",
166            );
167            return Ok(resp);
168        }
169
170        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
171        if encoding == ContentCoding::GZIP {
172            let (head, body) = resp.into_parts();
173            return Ok(gzip(head, body.into(), level));
174        }
175
176        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
177        if encoding == ContentCoding::DEFLATE {
178            let (head, body) = resp.into_parts();
179            return Ok(deflate(head, body.into(), level));
180        }
181
182        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
183        if encoding == ContentCoding::BROTLI {
184            let (head, body) = resp.into_parts();
185            return Ok(brotli(head, body.into(), level));
186        }
187
188        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
189        if encoding == ContentCoding::ZSTD {
190            let (head, body) = resp.into_parts();
191            return Ok(zstd(head, body.into(), level));
192        }
193
194        tracing::trace!(
195            "no compression feature matched the preferred encoding, probably not enabled or unsupported"
196        );
197    }
198
199    Ok(resp)
200}
201
/// Compresses the body of a [`Response`] with gzip.
///
/// The body is streamed through a gzip encoder. `gzip` is appended to any
/// existing `content-encoding` value (or becomes its value), and
/// `content-length` is dropped because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback handed to `CompressionLevel::into_algorithm_level`
    // (presumably used for the default level — confirm in `settings`).
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let headers = &mut head.headers;
    let encoding = create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::GZIP);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
228
/// Compresses the body of a [`Response`] with deflate.
///
/// The body is streamed through a deflate encoder. `deflate` is appended to any
/// existing `content-encoding` value (or becomes its value), and
/// `content-length` is dropped because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback handed to `CompressionLevel::into_algorithm_level`
    // (presumably used for the default level — confirm in `settings`).
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let headers = &mut head.headers;
    let encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::DEFLATE);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
258
/// Compresses the body of a [`Response`] with brotli.
///
/// The body is streamed through a brotli encoder. `br` is appended to any
/// existing `content-encoding` value (or becomes its value), and
/// `content-length` is dropped because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback handed to `CompressionLevel::into_algorithm_level`
    // (presumably used for the default level — confirm in `settings`).
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let headers = &mut head.headers;
    let encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::BROTLI);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
286
/// Compresses the body of a [`Response`] with zstd.
///
/// The body is streamed through a zstd encoder. `zstd` is appended to any
/// existing `content-encoding` value (or becomes its value), and
/// `content-length` is dropped because the encoded size is not known up front.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback handed to `CompressionLevel::into_algorithm_level`
    // (presumably used for the default level — confirm in `settings`).
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let headers = &mut head.headers;
    let encoding = create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::ZSTD);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
313
314/// Given an optional existing encoding header, appends to the existing or creates a new one.
315pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
316    if let Some(val) = existing
317        && let Ok(str_val) = val.to_str()
318    {
319        return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
320            .unwrap_or_else(|_| coding.into());
321    }
322    coding.into()
323}
324
325/// Try to get the preferred `content-encoding` via the `accept-encoding` header.
326#[inline(always)]
327pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
328    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
329        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
330
331        for encoding in accept_encoding.sorted_encodings() {
332            if AVAILABLE_ENCODINGS.contains(&encoding) {
333                return Some(encoding);
334            }
335        }
336    }
337    None
338}
339
340/// Get the `content-encodings` via the `accept-encoding` header.
341#[inline(always)]
342pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
343    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
344        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
345
346        return accept_encoding
347            .sorted_encodings()
348            .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
349            .collect::<Vec<_>>();
350    }
351    vec![]
352}
353
354/// A wrapper around any type that implements [`Stream`](futures_util::Stream) to be
355/// compatible with async_compression's `Stream` based encoders.
356#[pin_project]
357#[derive(Debug)]
358pub struct CompressableBody<S, E>
359where
360    S: Stream<Item = Result<Bytes, E>>,
361    E: std::error::Error,
362{
363    #[pin]
364    body: S,
365}
366
367impl<S, E> Stream for CompressableBody<S, E>
368where
369    S: Stream<Item = Result<Bytes, E>>,
370    E: std::error::Error,
371{
372    type Item = std::io::Result<Bytes>;
373
374    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
375        use std::io::{Error, ErrorKind};
376
377        let pin = self.project();
378        S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
379    }
380}
381
382impl From<Body> for CompressableBody<Body, hyper::Error> {
383    #[inline(always)]
384    fn from(body: Body) -> Self {
385        CompressableBody { body }
386    }
387}
388
#[cfg(test)]
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_TYPE};

    /// Builds a `text/html` response with a `size`-byte body and a matching
    /// `Content-Length` header.
    fn text_response_with_size(size: usize) -> Response<Body> {
        let mut resp = Response::new(Body::from(vec![b'x'; size]));
        let headers = resp.headers_mut();
        headers.insert(CONTENT_TYPE, "text/html".parse().unwrap());
        headers.insert(CONTENT_LENGTH, size.to_string().parse().unwrap());
        resp
    }

    /// Builds a small `text/html` response that carries no `Content-Length`.
    fn text_response_without_length() -> Response<Body> {
        let mut resp = Response::new(Body::from("hello world"));
        resp.headers_mut()
            .insert(CONTENT_TYPE, "text/html".parse().unwrap());
        resp
    }

    /// Request headers advertising gzip support only.
    fn accept_gzip_headers() -> HeaderMap<HeaderValue> {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT_ENCODING, "gzip".parse().unwrap());
        headers
    }

    /// Runs [`auto`] for a gzip-accepting request made with `method` and reports
    /// whether the returned response was given a `Content-Encoding` header.
    fn gets_encoded(method: &Method, resp: Response<Body>) -> bool {
        let out = auto(method, &accept_gzip_headers(), CompressionLevel::Default, resp).unwrap();
        out.headers().contains_key(CONTENT_ENCODING)
    }

    #[test]
    fn small_response_below_threshold_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE - 1);
        assert!(
            !gets_encoded(&Method::GET, resp),
            "responses below {MIN_COMPRESS_SIZE} bytes must not be compressed"
        );
    }

    #[test]
    fn response_at_threshold_is_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE);
        assert!(
            gets_encoded(&Method::GET, resp),
            "responses at exactly {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_above_threshold_is_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE + 1);
        assert!(
            gets_encoded(&Method::GET, resp),
            "responses above {MIN_COMPRESS_SIZE} bytes must be compressed"
        );
    }

    #[test]
    fn response_without_content_length_is_compressed() {
        let resp = text_response_without_length();
        assert!(
            gets_encoded(&Method::GET, resp),
            "responses without Content-Length must still be compressed"
        );
    }

    #[test]
    fn small_response_head_method_is_not_compressed() {
        let resp = text_response_with_size(MIN_COMPRESS_SIZE - 1);
        assert!(
            !gets_encoded(&Method::HEAD, resp),
            "HEAD requests are never compressed regardless of size"
        );
    }

    #[test]
    fn non_compressible_content_type_is_not_compressed() {
        let size = MIN_COMPRESS_SIZE + 100;
        let mut resp = Response::new(Body::from(vec![b'x'; size]));
        let headers = resp.headers_mut();
        headers.insert(CONTENT_TYPE, "image/png".parse().unwrap());
        headers.insert(CONTENT_LENGTH, size.to_string().parse().unwrap());

        assert!(
            !gets_encoded(&Method::GET, resp),
            "non-compressible content-types are never compressed"
        );
    }
}