viz_core/middleware/
compression.rs

1//! Compression Middleware.
2
3use std::str::FromStr;
4
5use async_compression::tokio::bufread;
6use tokio_util::io::{ReaderStream, StreamReader};
7
8use crate::{
9    header::{HeaderValue, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH},
10    Body, Handler, IntoResponse, Request, Response, Result, Transform,
11};
12
/// Configuration for the response-compression middleware.
///
/// Passing a handler to [`Transform::transform`] wraps it in a
/// [`CompressionMiddleware`]. That middleware compresses response bodies
/// according to the request's `Accept-Encoding` header. The type currently
/// carries no options.
#[derive(Clone, Copy, Debug, Default)]
pub struct Config;
16
17impl<H> Transform<H> for Config
18where
19    H: Clone,
20{
21    type Output = CompressionMiddleware<H>;
22
23    fn transform(&self, h: H) -> Self::Output {
24        CompressionMiddleware { h }
25    }
26}
27
/// Handler wrapper created by [`Config`]'s `Transform` impl.
///
/// It compresses the responses produced by the wrapped handler.
#[derive(Debug, Clone)]
pub struct CompressionMiddleware<H> {
    /// The wrapped (inner) handler.
    h: H,
}
33
34#[crate::async_trait]
35impl<H, O> Handler<Request> for CompressionMiddleware<H>
36where
37    H: Handler<Request, Output = Result<O>>,
38    O: IntoResponse,
39{
40    type Output = Result<Response>;
41
42    async fn call(&self, req: Request) -> Self::Output {
43        let accept_encoding = req
44            .headers()
45            .get(ACCEPT_ENCODING)
46            .map(HeaderValue::to_str)
47            .and_then(Result::ok)
48            .and_then(parse_accept_encoding);
49
50        let raw = self.h.call(req).await?;
51
52        Ok(match accept_encoding {
53            Some(algo) => Compress::new(raw, algo).into_response(),
54            None => raw.into_response(),
55        })
56    }
57}
58
59/// Compresses the response body with the specified algorithm
60/// and sets the `Content-Encoding` header.
61#[derive(Debug)]
62pub struct Compress<T> {
63    inner: T,
64    algo: ContentCoding,
65}
66
67impl<T> Compress<T> {
68    /// Creates a compressed response with the specified algorithm.
69    pub const fn new(inner: T, algo: ContentCoding) -> Self {
70        Self { inner, algo }
71    }
72}
73
74impl<T: IntoResponse> IntoResponse for Compress<T> {
75    fn into_response(self) -> Response {
76        let mut res = self.inner.into_response();
77
78        match self.algo {
79            ContentCoding::Gzip | ContentCoding::Deflate | ContentCoding::Brotli => {
80                res = res.map(|body| {
81                    let body = StreamReader::new(body);
82                    if self.algo == ContentCoding::Gzip {
83                        Body::from_stream(ReaderStream::new(bufread::GzipEncoder::new(body)))
84                    } else if self.algo == ContentCoding::Deflate {
85                        Body::from_stream(ReaderStream::new(bufread::DeflateEncoder::new(body)))
86                    } else {
87                        Body::from_stream(ReaderStream::new(bufread::BrotliEncoder::new(body)))
88                    }
89                });
90                res.headers_mut()
91                    .append(CONTENT_ENCODING, HeaderValue::from_static(self.algo.into()));
92                res.headers_mut().remove(CONTENT_LENGTH);
93                res
94            }
95            ContentCoding::Any => res,
96        }
97    }
98}
99
/// A content coding as it appears in the `Accept-Encoding` request header.
///
/// Header tokens are parsed into this type via `FromStr`. The
/// `From<ContentCoding> for &'static str` impl converts it back to its token,
/// which is used as the `Content-Encoding` response header value. See the
/// [MDN `Accept-Encoding` reference][mdn].
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentCoding {
    /// `gzip`
    Gzip,
    /// `deflate`
    Deflate,
    /// `br` (Brotli)
    Brotli,
    /// `*`: any coding. Selecting it leaves the response body uncompressed.
    Any,
}
114
115impl FromStr for ContentCoding {
116    type Err = ();
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        if s.eq_ignore_ascii_case("deflate") {
120            Ok(Self::Deflate)
121        } else if s.eq_ignore_ascii_case("gzip") {
122            Ok(Self::Gzip)
123        } else if s.eq_ignore_ascii_case("br") {
124            Ok(Self::Brotli)
125        } else if s == "*" {
126            Ok(Self::Any)
127        } else {
128            Err(())
129        }
130    }
131}
132
133impl From<ContentCoding> for &'static str {
134    fn from(cc: ContentCoding) -> Self {
135        match cc {
136            ContentCoding::Gzip => "gzip",
137            ContentCoding::Deflate => "deflate",
138            ContentCoding::Brotli => "br",
139            ContentCoding::Any => "*",
140        }
141    }
142}
143
144#[allow(clippy::cast_sign_loss)]
145#[allow(clippy::cast_possible_truncation)]
146fn parse_accept_encoding(s: &str) -> Option<ContentCoding> {
147    s.split(',')
148        .map(str::trim)
149        .filter_map(|v| match v.split_once(";q=") {
150            None => v.parse::<ContentCoding>().ok().map(|c| (c, 100)),
151            Some((c, q)) => Some((
152                c.parse::<ContentCoding>().ok()?,
153                q.parse::<f32>()
154                    .ok()
155                    .filter(|v| *v >= 0. && *v <= 1.)
156                    .map(|v| (v * 100.) as u8)?,
157            )),
158        })
159        .max_by_key(|(_, q)| *q)
160        .map(|(c, _)| c)
161}