static_web_server/
compression.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Auto-compression module that compresses response bodies on the fly.
7//!
8
9// Part of the file is borrowed from <https://github.com/seanmonstar/warp/pull/513>*
10
11#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24    header::{CONTENT_ENCODING, CONTENT_LENGTH},
25    Body, Method, Request, Response, StatusCode,
26};
27use lazy_static::lazy_static;
28use mime_guess::{mime, Mime};
29use pin_project::pin_project;
30use std::collections::HashSet;
31use std::pin::Pin;
32use std::task::{Context, Poll};
33use tokio_util::io::{ReaderStream, StreamReader};
34
35use crate::{
36    error_page,
37    handler::RequestHandlerOpts,
38    headers_ext::{AcceptEncoding, ContentCoding},
39    http_ext::MethodExt,
40    settings::CompressionLevel,
41    Error, Result,
42};
43
44lazy_static! {
45    /// Contains a fixed list of common text-based MIME types that aren't recognizable in a generic way.
46    static ref TEXT_MIME_TYPES: HashSet<&'static str> = [
47        "application/rtf",
48        "application/javascript",
49        "application/json",
50        "application/xml",
51        "font/ttf",
52        "application/font-sfnt",
53        "application/vnd.ms-fontobject",
54        "application/wasm",
55    ].into_iter().collect();
56}
57
/// List of encodings that can be handled given enabled features.
///
/// Each entry is compiled in only when its dedicated Cargo feature
/// (`compression-deflate`, `compression-gzip`, `compression-brotli`, `compression-zstd`)
/// or the umbrella `compression` feature is enabled. The encodings requested via
/// `Accept-Encoding` are filtered against this list before one is selected.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
69
70/// Initializes dynamic compression.
71pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
72    handler_opts.compression = enabled;
73    handler_opts.compression_level = level;
74
75    const FORMATS: &[&str] = &[
76        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
77        "deflate",
78        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
79        "gzip",
80        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
81        "brotli",
82        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
83        "zstd",
84    ];
85    server_info!(
86        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
87        FORMATS.join(",")
88    );
89}
90
91/// Post-processing to dynamically compress the response if necessary.
92pub(crate) fn post_process<T>(
93    opts: &RequestHandlerOpts,
94    req: &Request<T>,
95    mut resp: Response<Body>,
96) -> Result<Response<Body>, Error> {
97    if !opts.compression {
98        return Ok(resp);
99    }
100
101    let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
102    if is_precompressed {
103        return Ok(resp);
104    }
105
106    // Compression content encoding varies so use a `Vary` header
107    resp.headers_mut().insert(
108        hyper::header::VARY,
109        HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
110    );
111
112    // Auto compression based on the `Accept-Encoding` header
113    match auto(req.method(), req.headers(), opts.compression_level, resp) {
114        Ok(resp) => Ok(resp),
115        Err(err) => {
116            tracing::error!("error during body compression: {:?}", err);
117            error_page::error_response(
118                req.uri(),
119                req.method(),
120                &StatusCode::INTERNAL_SERVER_ERROR,
121                &opts.page404,
122                &opts.page50x,
123            )
124        }
125    }
126}
127
128/// Create a wrapping handler that compresses the Body of a [`hyper::Response`]
129/// using gzip, `deflate`, `brotli` or `zstd` if is specified in the `Accept-Encoding` header, adding
130/// `content-encoding: <coding>` to the Response's [`HeaderMap`].
131/// It also provides the ability to apply compression for text-based MIME types only.
132pub fn auto(
133    method: &Method,
134    headers: &HeaderMap<HeaderValue>,
135    level: CompressionLevel,
136    resp: Response<Body>,
137) -> Result<Response<Body>> {
138    // Skip compression for HEAD and OPTIONS request methods
139    if method.is_head() || method.is_options() {
140        return Ok(resp);
141    }
142
143    // Compress response based on Accept-Encoding header
144    if let Some(encoding) = get_preferred_encoding(headers) {
145        tracing::trace!(
146            "preferred encoding selected from the accept-encoding header: {:?}",
147            encoding
148        );
149
150        // Skip compression for non-text-based MIME types
151        if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
152            if !is_text(Mime::from(content_type)) {
153                return Ok(resp);
154            }
155        }
156
157        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
158        if encoding == ContentCoding::GZIP {
159            let (head, body) = resp.into_parts();
160            return Ok(gzip(head, body.into(), level));
161        }
162
163        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
164        if encoding == ContentCoding::DEFLATE {
165            let (head, body) = resp.into_parts();
166            return Ok(deflate(head, body.into(), level));
167        }
168
169        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
170        if encoding == ContentCoding::BROTLI {
171            let (head, body) = resp.into_parts();
172            return Ok(brotli(head, body.into(), level));
173        }
174
175        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
176        if encoding == ContentCoding::ZSTD {
177            let (head, body) = resp.into_parts();
178            return Ok(zstd(head, body.into(), level));
179        }
180
181        tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
182    }
183
184    Ok(resp)
185}
186
187/// Checks whether the MIME type corresponds to any of the known text types.
188fn is_text(mime: Mime) -> bool {
189    mime.type_() == mime::TEXT
190        || mime
191            .suffix()
192            .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
193        || TEXT_MIME_TYPES.contains(mime.essence_str())
194}
195
/// Compresses the body of a [`Response`] on the fly using gzip.
///
/// The `Content-Length` header is removed and `gzip` is set as (or appended to)
/// the `Content-Encoding` header of the response's [`HeaderMap`].
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback level handed to `into_algorithm_level`
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    // Update the headers: the original length no longer describes the streamed body
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::GZIP),
    );

    // Wrap the original body stream with the encoder
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
222
/// Compresses the body of a [`Response`] on the fly using deflate.
///
/// The `Content-Length` header is removed and `deflate` is set as (or appended to)
/// the `Content-Encoding` header of the response's [`HeaderMap`].
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback level handed to `into_algorithm_level`
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    // Update the headers: the original length no longer describes the streamed body
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::DEFLATE),
    );

    // Wrap the original body stream with the encoder
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
252
/// Compresses the body of a [`Response`] on the fly using brotli.
///
/// The `Content-Length` header is removed and the brotli coding is set as (or appended to)
/// the `Content-Encoding` header of the response's [`HeaderMap`].
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback level handed to `into_algorithm_level`
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    // Update the headers: the original length no longer describes the streamed body
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::BROTLI),
    );

    // Wrap the original body stream with the encoder
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
280
/// Compresses the body of a [`Response`] on the fly using zstd.
///
/// The `Content-Length` header is removed and `zstd` is set as (or appended to)
/// the `Content-Encoding` header of the response's [`HeaderMap`].
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Fallback level handed to `into_algorithm_level`
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    // Update the headers: the original length no longer describes the streamed body
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::ZSTD),
    );

    // Wrap the original body stream with the encoder
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
307
308/// Given an optional existing encoding header, appends to the existing or creates a new one.
309pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
310    if let Some(val) = existing {
311        if let Ok(str_val) = val.to_str() {
312            return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
313                .unwrap_or_else(|_| coding.into());
314        }
315    }
316    coding.into()
317}
318
319/// Try to get the preferred `content-encoding` via the `accept-encoding` header.
320#[inline(always)]
321pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
322    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
323        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
324
325        for encoding in accept_encoding.sorted_encodings() {
326            if AVAILABLE_ENCODINGS.contains(&encoding) {
327                return Some(encoding);
328            }
329        }
330    }
331    None
332}
333
334/// Get the `content-encodings` via the `accept-encoding` header.
335#[inline(always)]
336pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
337    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
338        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
339
340        return accept_encoding
341            .sorted_encodings()
342            .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
343            .collect::<Vec<_>>();
344    }
345    vec![]
346}
347
/// A wrapper around any type that implements [`Stream`](futures_util::Stream) to be
/// compatible with async_compression's `Stream` based encoders.
///
/// The wrapped stream has to yield `Result<Bytes, E>` items where `E` implements
/// [`std::error::Error`].
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    /// The wrapped stream, marked `#[pin]` so it can be polled through a pinned wrapper.
    #[pin]
    body: S,
}
360
361impl<S, E> Stream for CompressableBody<S, E>
362where
363    S: Stream<Item = Result<Bytes, E>>,
364    E: std::error::Error,
365{
366    type Item = std::io::Result<Bytes>;
367
368    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
369        use std::io::{Error, ErrorKind};
370
371        let pin = self.project();
372        S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
373    }
374}
375
376impl From<Body> for CompressableBody<Body, hyper::Error> {
377    #[inline(always)]
378    fn from(body: Body) -> Self {
379        CompressableBody { body }
380    }
381}