static_web_server/
compression.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Auto-compression module to compress response bodies.
7//!
8
9// Part of the file is borrowed from <https://github.com/seanmonstar/warp/pull/513>*
10
11#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24    header::{CONTENT_ENCODING, CONTENT_LENGTH},
25    Body, Method, Request, Response, StatusCode,
26};
27use lazy_static::lazy_static;
28use mime_guess::{mime, Mime};
29use pin_project::pin_project;
30use std::collections::HashSet;
31use std::pin::Pin;
32use std::task::{Context, Poll};
33use tokio_util::io::{ReaderStream, StreamReader};
34
35use crate::{
36    error_page,
37    handler::RequestHandlerOpts,
38    headers_ext::{AcceptEncoding, ContentCoding},
39    http_ext::MethodExt,
40    settings::CompressionLevel,
41    Error, Result,
42};
43
44lazy_static! {
45    /// Contains a fixed list of common text-based MIME types that aren't recognizable in a generic way.
46    static ref TEXT_MIME_TYPES: HashSet<&'static str> = [
47        "application/rtf",
48        "application/javascript",
49        "application/json",
50        "application/xml",
51        "font/ttf",
52        "application/font-sfnt",
53        "application/vnd.ms-fontobject",
54        "application/wasm",
55    ].into_iter().collect();
56}
57
58/// List of encodings that can be handled given enabled features.
59const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
60    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
61    ContentCoding::DEFLATE,
62    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
63    ContentCoding::GZIP,
64    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
65    ContentCoding::BROTLI,
66    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
67    ContentCoding::ZSTD,
68];
69
70/// Initializes dynamic compression.
71pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
72    handler_opts.compression = enabled;
73    handler_opts.compression_level = level;
74
75    const FORMATS: &[&str] = &[
76        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
77        "deflate",
78        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
79        "gzip",
80        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
81        "brotli",
82        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
83        "zstd",
84    ];
85    tracing::info!(
86        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
87        FORMATS.join(",")
88    );
89}
90
91/// Post-processing to dynamically compress the response if necessary.
92pub(crate) fn post_process<T>(
93    opts: &RequestHandlerOpts,
94    req: &Request<T>,
95    mut resp: Response<Body>,
96) -> Result<Response<Body>, Error> {
97    if !opts.compression {
98        return Ok(resp);
99    }
100
101    let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
102    if is_precompressed {
103        return Ok(resp);
104    }
105
106    // Compression content encoding varies so use a `Vary` header
107    let value = resp.headers().get(hyper::header::VARY).map_or(
108        HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
109        |h| {
110            let mut s = h.to_str().unwrap_or_default().to_owned();
111            s.push(',');
112            s.push_str(hyper::header::ACCEPT_ENCODING.as_str());
113            HeaderValue::from_str(s.as_str()).unwrap()
114        },
115    );
116    resp.headers_mut().insert(hyper::header::VARY, value);
117
118    // Auto compression based on the `Accept-Encoding` header
119    match auto(req.method(), req.headers(), opts.compression_level, resp) {
120        Ok(resp) => Ok(resp),
121        Err(err) => {
122            tracing::error!("error during body compression: {:?}", err);
123            error_page::error_response(
124                req.uri(),
125                req.method(),
126                &StatusCode::INTERNAL_SERVER_ERROR,
127                &opts.page404,
128                &opts.page50x,
129            )
130        }
131    }
132}
133
134/// Create a wrapping handler that compresses the Body of a [`hyper::Response`]
135/// using gzip, `deflate`, `brotli` or `zstd` if is specified in the `Accept-Encoding` header, adding
136/// `content-encoding: <coding>` to the Response's [`HeaderMap`].
137/// It also provides the ability to apply compression for text-based MIME types only.
138pub fn auto(
139    method: &Method,
140    headers: &HeaderMap<HeaderValue>,
141    level: CompressionLevel,
142    resp: Response<Body>,
143) -> Result<Response<Body>> {
144    // Skip compression for HEAD and OPTIONS request methods
145    if method.is_head() || method.is_options() {
146        return Ok(resp);
147    }
148
149    // Compress response based on Accept-Encoding header
150    if let Some(encoding) = get_preferred_encoding(headers) {
151        tracing::trace!(
152            "preferred encoding selected from the accept-encoding header: {:?}",
153            encoding
154        );
155
156        // Skip compression for non-text-based MIME types
157        if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
158            if !is_text(Mime::from(content_type)) {
159                return Ok(resp);
160            }
161        }
162
163        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
164        if encoding == ContentCoding::GZIP {
165            let (head, body) = resp.into_parts();
166            return Ok(gzip(head, body.into(), level));
167        }
168
169        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
170        if encoding == ContentCoding::DEFLATE {
171            let (head, body) = resp.into_parts();
172            return Ok(deflate(head, body.into(), level));
173        }
174
175        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
176        if encoding == ContentCoding::BROTLI {
177            let (head, body) = resp.into_parts();
178            return Ok(brotli(head, body.into(), level));
179        }
180
181        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
182        if encoding == ContentCoding::ZSTD {
183            let (head, body) = resp.into_parts();
184            return Ok(zstd(head, body.into(), level));
185        }
186
187        tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
188    }
189
190    Ok(resp)
191}
192
193/// Checks whether the MIME type corresponds to any of the known text types.
194fn is_text(mime: Mime) -> bool {
195    mime.type_() == mime::TEXT
196        || mime
197            .suffix()
198            .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
199        || TEXT_MIME_TYPES.contains(mime.essence_str())
200}
201
/// Compresses the body of a [`Response`] on the fly using gzip and sets
/// `content-encoding: gzip` on the Response's [`HeaderMap`].
///
/// The `content-length` header is removed from `head` and the returned response
/// streams the encoder output as its body.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Value handed to `CompressionLevel::into_algorithm_level` for this algorithm
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Any coding already present is taken out of the map and combined with gzip
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::GZIP);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
228
/// Compresses the body of a [`Response`] on the fly using deflate and sets
/// `content-encoding: deflate` on the Response's [`HeaderMap`].
///
/// The `content-length` header is removed from `head` and the returned response
/// streams the encoder output as its body.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Value handed to `CompressionLevel::into_algorithm_level` for this algorithm
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Any coding already present is taken out of the map and combined with deflate
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::DEFLATE);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
258
/// Compresses the body of a [`Response`] on the fly using brotli and sets
/// `content-encoding: br` on the Response's [`HeaderMap`].
///
/// The `content-length` header is removed from `head` and the returned response
/// streams the encoder output as its body.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Value handed to `CompressionLevel::into_algorithm_level` for this algorithm
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Any coding already present is taken out of the map and combined with brotli
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::BROTLI);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
286
/// Compresses the body of a [`Response`] on the fly using zstd and sets
/// `content-encoding: zstd` on the Response's [`HeaderMap`].
///
/// The `content-length` header is removed from `head` and the returned response
/// streams the encoder output as its body.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    // Value handed to `CompressionLevel::into_algorithm_level` for this algorithm
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // Any coding already present is taken out of the map and combined with zstd
    let previous = head.headers.remove(CONTENT_ENCODING);
    let content_encoding = create_encoding_header(previous, ContentCoding::ZSTD);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
313
314/// Given an optional existing encoding header, appends to the existing or creates a new one.
315pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
316    if let Some(val) = existing {
317        if let Ok(str_val) = val.to_str() {
318            return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
319                .unwrap_or_else(|_| coding.into());
320        }
321    }
322    coding.into()
323}
324
325/// Try to get the preferred `content-encoding` via the `accept-encoding` header.
326#[inline(always)]
327pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
328    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
329        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
330
331        for encoding in accept_encoding.sorted_encodings() {
332            if AVAILABLE_ENCODINGS.contains(&encoding) {
333                return Some(encoding);
334            }
335        }
336    }
337    None
338}
339
340/// Get the `content-encodings` via the `accept-encoding` header.
341#[inline(always)]
342pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
343    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
344        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
345
346        return accept_encoding
347            .sorted_encodings()
348            .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
349            .collect::<Vec<_>>();
350    }
351    vec![]
352}
353
354/// A wrapper around any type that implements [`Stream`](futures_util::Stream) to be
355/// compatible with async_compression's `Stream` based encoders.
356#[pin_project]
357#[derive(Debug)]
358pub struct CompressableBody<S, E>
359where
360    S: Stream<Item = Result<Bytes, E>>,
361    E: std::error::Error,
362{
363    #[pin]
364    body: S,
365}
366
367impl<S, E> Stream for CompressableBody<S, E>
368where
369    S: Stream<Item = Result<Bytes, E>>,
370    E: std::error::Error,
371{
372    type Item = std::io::Result<Bytes>;
373
374    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
375        use std::io::{Error, ErrorKind};
376
377        let pin = self.project();
378        S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
379    }
380}
381
382impl From<Body> for CompressableBody<Body, hyper::Error> {
383    #[inline(always)]
384    fn from(body: Body) -> Self {
385        CompressableBody { body }
386    }
387}