static_web_server/
compression.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Auto-compression module to compress response bodies.
7//!
8
9// Part of the file is borrowed from <https://github.com/seanmonstar/warp/pull/513>*
10
11#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24    Body, Method, Request, Response, StatusCode,
25    header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::{Mime, mime};
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34    Error, Result, error_page,
35    handler::RequestHandlerOpts,
36    headers_ext::{AcceptEncoding, ContentCoding},
37    http_ext::MethodExt,
38    settings::CompressionLevel,
39};
40
/// Contains a fixed list of common text-based MIME types that aren't recognizable in a generic way.
///
/// `is_text` consults this list by exact essence match, in addition to its generic
/// checks (a `text/*` type or an `+xml` / `+json` suffix).
const TEXT_MIME_TYPES: [&str; 8] = [
    "application/rtf",
    "application/javascript",
    "application/json",
    "application/xml",
    "font/ttf",
    "application/font-sfnt",
    "application/vnd.ms-fontobject",
    "application/wasm",
];
52
/// List of encodings that can be handled given enabled features.
///
/// Every entry is gated on the umbrella `compression` feature or on its own
/// `compression-<format>` feature, so the slice only lists codings that are compiled in.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
64
65/// Initializes dynamic compression.
66pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
67    handler_opts.compression = enabled;
68    handler_opts.compression_level = level;
69
70    const FORMATS: &[&str] = &[
71        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
72        "deflate",
73        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
74        "gzip",
75        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
76        "brotli",
77        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
78        "zstd",
79    ];
80    tracing::info!(
81        "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
82        FORMATS.join(",")
83    );
84}
85
86/// Post-processing to dynamically compress the response if necessary.
87pub(crate) fn post_process<T>(
88    opts: &RequestHandlerOpts,
89    req: &Request<T>,
90    mut resp: Response<Body>,
91) -> Result<Response<Body>, Error> {
92    if !opts.compression {
93        return Ok(resp);
94    }
95
96    let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
97    if is_precompressed {
98        return Ok(resp);
99    }
100
101    // Compression content encoding varies so use a `Vary` header
102    let value = resp.headers().get(hyper::header::VARY).map_or(
103        HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
104        |h| {
105            let mut s = h.to_str().unwrap_or_default().to_owned();
106            s.push(',');
107            s.push_str(hyper::header::ACCEPT_ENCODING.as_str());
108            HeaderValue::from_str(s.as_str()).unwrap()
109        },
110    );
111    resp.headers_mut().insert(hyper::header::VARY, value);
112
113    // Auto compression based on the `Accept-Encoding` header
114    match auto(req.method(), req.headers(), opts.compression_level, resp) {
115        Ok(resp) => Ok(resp),
116        Err(err) => {
117            tracing::error!("error during body compression: {:?}", err);
118            error_page::error_response(
119                req.uri(),
120                req.method(),
121                &StatusCode::INTERNAL_SERVER_ERROR,
122                &opts.page404,
123                &opts.page50x,
124            )
125        }
126    }
127}
128
129/// Create a wrapping handler that compresses the Body of a [`hyper::Response`]
130/// using gzip, `deflate`, `brotli` or `zstd` if is specified in the `Accept-Encoding` header, adding
131/// `content-encoding: <coding>` to the Response's [`HeaderMap`].
132/// It also provides the ability to apply compression for text-based MIME types only.
133pub fn auto(
134    method: &Method,
135    headers: &HeaderMap<HeaderValue>,
136    level: CompressionLevel,
137    resp: Response<Body>,
138) -> Result<Response<Body>> {
139    // Skip compression for HEAD and OPTIONS request methods
140    if method.is_head() || method.is_options() {
141        return Ok(resp);
142    }
143
144    // Compress response based on Accept-Encoding header
145    if let Some(encoding) = get_preferred_encoding(headers) {
146        tracing::trace!(
147            "preferred encoding selected from the accept-encoding header: {:?}",
148            encoding
149        );
150
151        // Skip compression for non-text-based MIME types
152        if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
153            if !is_text(Mime::from(content_type)) {
154                return Ok(resp);
155            }
156        }
157
158        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
159        if encoding == ContentCoding::GZIP {
160            let (head, body) = resp.into_parts();
161            return Ok(gzip(head, body.into(), level));
162        }
163
164        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
165        if encoding == ContentCoding::DEFLATE {
166            let (head, body) = resp.into_parts();
167            return Ok(deflate(head, body.into(), level));
168        }
169
170        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
171        if encoding == ContentCoding::BROTLI {
172            let (head, body) = resp.into_parts();
173            return Ok(brotli(head, body.into(), level));
174        }
175
176        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
177        if encoding == ContentCoding::ZSTD {
178            let (head, body) = resp.into_parts();
179            return Ok(zstd(head, body.into(), level));
180        }
181
182        tracing::trace!(
183            "no compression feature matched the preferred encoding, probably not enabled or unsupported"
184        );
185    }
186
187    Ok(resp)
188}
189
190/// Checks whether the MIME type corresponds to any of the known text types.
191fn is_text(mime: Mime) -> bool {
192    mime.type_() == mime::TEXT
193        || mime
194            .suffix()
195            .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
196        || TEXT_MIME_TYPES.contains(&mime.essence_str())
197}
198
/// Compresses the body of a [`Response`] on the fly using gzip,
/// adding `content-encoding: gzip` to the Response's [`HeaderMap`].
///
/// Any `content-length` header is removed from the response head.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    // Body stream -> buffered reader -> gzip encoder -> stream of compressed chunks
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // The original length does not describe the compressed payload
    head.headers.remove(CONTENT_LENGTH);
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::GZIP),
    );

    Response::from_parts(head, compressed)
}
225
/// Compresses the body of a [`Response`] on the fly using deflate,
/// adding `content-encoding: deflate` to the Response's [`HeaderMap`].
///
/// Any `content-length` header is removed from the response head.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    // Body stream -> buffered reader -> deflate encoder -> stream of compressed chunks
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // The original length does not describe the compressed payload
    head.headers.remove(CONTENT_LENGTH);
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::DEFLATE),
    );

    Response::from_parts(head, compressed)
}
255
/// Compresses the body of a [`Response`] on the fly using brotli,
/// adding `content-encoding: br` to the Response's [`HeaderMap`].
///
/// Any `content-length` header is removed from the response head.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    // Body stream -> buffered reader -> brotli encoder -> stream of compressed chunks
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // The original length does not describe the compressed payload
    head.headers.remove(CONTENT_LENGTH);
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::BROTLI),
    );

    Response::from_parts(head, compressed)
}
283
/// Compresses the body of a [`Response`] on the fly using zstd,
/// adding `content-encoding: zstd` to the Response's [`HeaderMap`].
///
/// Any `content-length` header is removed from the response head.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    // Body stream -> buffered reader -> zstd encoder -> stream of compressed chunks
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    // The original length does not describe the compressed payload
    head.headers.remove(CONTENT_LENGTH);
    let previous = head.headers.remove(CONTENT_ENCODING);
    head.headers.insert(
        CONTENT_ENCODING,
        create_encoding_header(previous, ContentCoding::ZSTD),
    );

    Response::from_parts(head, compressed)
}
310
311/// Given an optional existing encoding header, appends to the existing or creates a new one.
312pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
313    if let Some(val) = existing {
314        if let Ok(str_val) = val.to_str() {
315            return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
316                .unwrap_or_else(|_| coding.into());
317        }
318    }
319    coding.into()
320}
321
322/// Try to get the preferred `content-encoding` via the `accept-encoding` header.
323#[inline(always)]
324pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
325    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
326        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
327
328        for encoding in accept_encoding.sorted_encodings() {
329            if AVAILABLE_ENCODINGS.contains(&encoding) {
330                return Some(encoding);
331            }
332        }
333    }
334    None
335}
336
337/// Get the `content-encodings` via the `accept-encoding` header.
338#[inline(always)]
339pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
340    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
341        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
342
343        return accept_encoding
344            .sorted_encodings()
345            .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
346            .collect::<Vec<_>>();
347    }
348    vec![]
349}
350
/// A wrapper around any type that implements [`Stream`](futures_util::Stream) to be
/// compatible with async_compression's `Stream` based encoders.
///
/// Its own [`Stream`] implementation yields `std::io::Result<Bytes>` items, which lets the
/// compression functions of this module feed it to a `StreamReader`.
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    /// The wrapped body stream, projected as pinned so it can be polled in place.
    #[pin]
    body: S,
}
363
364impl<S, E> Stream for CompressableBody<S, E>
365where
366    S: Stream<Item = Result<Bytes, E>>,
367    E: std::error::Error,
368{
369    type Item = std::io::Result<Bytes>;
370
371    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372        use std::io::{Error, ErrorKind};
373
374        let pin = self.project();
375        S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
376    }
377}
378
379impl From<Body> for CompressableBody<Body, hyper::Error> {
380    #[inline(always)]
381    fn from(body: Body) -> Self {
382        CompressableBody { body }
383    }
384}