Skip to main content

tork_core/middleware/
compression.rs

1//! Response compression middleware.
2
3use std::io::Write;
4
5use bytes::Bytes;
6use flate2::write::GzEncoder;
7use flate2::Compression as CompressionLevel;
8use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, VARY};
9use http::HeaderValue;
10
11use crate::body::RespBody;
12use crate::constants::TEXT_EVENT_STREAM;
13use crate::error::Result;
14use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
15use crate::response::{into_body_bytes, Response};
16use crate::router::BoxFuture;
17
/// Content-coding token used for gzip (the `Content-Encoding` value this middleware emits).
const GZIP: &str = "gzip";

/// Default ceiling, in bytes, on a body eligible for compression: 8 MiB.
///
/// Compression holds the entire body in memory and allocates a second buffer for
/// the gzip output, so one very large response can multiply peak per-request
/// memory. Larger bodies are sent uncompressed, and when a `Content-Length`
/// already reveals the size they are not buffered at all.
const DEFAULT_MAXIMUM_SIZE: usize = 8 << 20;
28
/// Compresses response bodies when the client supports it.
///
/// When gzip is enabled, the client's `Accept-Encoding` admits gzip, the
/// response has no existing `Content-Encoding`, is not a Server-Sent Events
/// stream, and its body is between `minimum_size` and `maximum_size` bytes, the
/// body is gzip-compressed and `Content-Encoding` / `Content-Length` are set.
#[derive(Debug, Clone)]
pub struct Compression {
    /// Whether gzip encoding is enabled; `Compression::new` leaves it off.
    gzip: bool,
    /// Smallest body size, in bytes, that is compressed.
    minimum_size: usize,
    /// Largest body size, in bytes, that is compressed.
    maximum_size: usize,
}
40
41impl Compression {
42    /// Creates a compression middleware with no algorithm enabled yet.
43    pub fn new() -> Self {
44        Self {
45            gzip: false,
46            minimum_size: 0,
47            maximum_size: DEFAULT_MAXIMUM_SIZE,
48        }
49    }
50
51    /// Enables gzip compression.
52    pub fn gzip(mut self) -> Self {
53        self.gzip = true;
54        self
55    }
56
57    /// Sets the minimum body size (in bytes) eligible for compression.
58    pub fn minimum_size(mut self, bytes: usize) -> Self {
59        self.minimum_size = bytes;
60        self
61    }
62
63    /// Sets the maximum body size (in bytes) eligible for compression.
64    ///
65    /// Bodies larger than this are sent uncompressed to bound per-request memory;
66    /// when the response advertises a `Content-Length` over this limit, the body is
67    /// streamed through without being buffered. Use `usize::MAX` to lift the cap.
68    pub fn maximum_size(mut self, bytes: usize) -> Self {
69        self.maximum_size = bytes;
70        self
71    }
72}
73
74impl Default for Compression {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl Middleware for Compression {
81    fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
82        let gzip_enabled = self.gzip;
83        let minimum_size = self.minimum_size;
84        let maximum_size = self.maximum_size;
85        let accepts_gzip = request
86            .headers()
87            .get(ACCEPT_ENCODING)
88            .and_then(|value| value.to_str().ok())
89            .map(|value| value.to_ascii_lowercase().contains(GZIP))
90            .unwrap_or(false);
91
92        Box::pin(async move {
93            let mut response = next.run(request).await?;
94
95            // When gzip is enabled the same URL may be served compressed or not
96            // depending on the client's `Accept-Encoding`, so every eligible
97            // response must carry `Vary: Accept-Encoding` (not just the compressed
98            // one) or a cache could hand a compressed body to a client that did not
99            // ask for it.
100            if gzip_enabled && !is_event_stream(&response) {
101                append_vary_accept_encoding(response.headers_mut());
102            }
103
104            // Skip when gzip is off, unsupported, the body is already encoded, or
105            // the body is a stream (an event stream must not be buffered here, and
106            // streaming responses are not worth compressing frame by frame).
107            if !gzip_enabled
108                || !accepts_gzip
109                || response.headers().contains_key(CONTENT_ENCODING)
110                || is_event_stream(&response)
111            {
112                return Ok(response);
113            }
114
115            // If the response advertises a length over the cap, or is already below
116            // the minimum compression threshold, pass it through without buffering.
117            if let Some(length) = content_length(response.headers()) {
118                if length > maximum_size || length < minimum_size {
119                    return Ok(response);
120                }
121            }
122
123            let (mut parts, bytes) = into_body_bytes(response).await;
124            // Out of the eligible window: too small to be worth it, or large enough
125            // that compressing would add a second big buffer for little gain.
126            if bytes.len() < minimum_size || bytes.len() > maximum_size {
127                return Ok(Response::from_parts(parts, RespBody::new(bytes)));
128            }
129
130            match gzip(&bytes) {
131                Ok(compressed) => {
132                    parts
133                        .headers
134                        .insert(CONTENT_ENCODING, HeaderValue::from_static(GZIP));
135                    if let Ok(length) = HeaderValue::from_str(&compressed.len().to_string()) {
136                        parts.headers.insert(CONTENT_LENGTH, length);
137                    }
138                    Ok(Response::from_parts(
139                        parts,
140                        RespBody::new(Bytes::from(compressed)),
141                    ))
142                }
143                // On the unlikely encode failure, send the body uncompressed.
144                Err(_) => Ok(Response::from_parts(parts, RespBody::new(bytes))),
145            }
146        })
147    }
148
149    fn name(&self) -> &'static str {
150        "Compression"
151    }
152
153    fn duplicate_policy(&self) -> DuplicatePolicy {
154        DuplicatePolicy::Reject
155    }
156}
157
158/// Adds `Accept-Encoding` to the response's `Vary` header unless already listed.
159fn append_vary_accept_encoding(headers: &mut http::HeaderMap) {
160    let already_present = headers
161        .get_all(VARY)
162        .iter()
163        .filter_map(|value| value.to_str().ok())
164        .any(|value| value.to_ascii_lowercase().contains("accept-encoding"));
165    if !already_present {
166        headers.append(VARY, HeaderValue::from_static("Accept-Encoding"));
167    }
168}
169
170/// Parses the response's `Content-Length` header, if present and valid.
171fn content_length(headers: &http::HeaderMap) -> Option<usize> {
172    headers
173        .get(CONTENT_LENGTH)
174        .and_then(|value| value.to_str().ok())
175        .and_then(|value| value.trim().parse::<usize>().ok())
176}
177
178/// Reports whether the response is a Server-Sent Events stream.
179///
180/// Such a body is unbounded and must not be buffered for compression.
181fn is_event_stream(response: &Response) -> bool {
182    response
183        .headers()
184        .get(CONTENT_TYPE)
185        .and_then(|value| value.to_str().ok())
186        .map(|value| value.starts_with(TEXT_EVENT_STREAM))
187        .unwrap_or(false)
188}
189
190/// Gzip-compresses a byte slice.
191fn gzip(data: &[u8]) -> std::io::Result<Vec<u8>> {
192    let mut encoder = GzEncoder::new(
193        Vec::with_capacity(data.len() / 2 + 16),
194        CompressionLevel::default(),
195    );
196    encoder.write_all(data)?;
197    encoder.finish()
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty-bodied response carrying the given `Content-Type`.
    fn response_with_content_type(value: &'static str) -> Response {
        let mut response = http::Response::new(RespBody::new(Bytes::new()));
        response
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(value));
        response
    }

    #[test]
    fn event_stream_is_detected_and_bypasses_compression() {
        assert!(is_event_stream(&response_with_content_type(
            TEXT_EVENT_STREAM
        )));
        assert!(!is_event_stream(&response_with_content_type(
            "application/json"
        )));
        // A response without a content type is not treated as an event stream.
        assert!(!is_event_stream(&http::Response::new(RespBody::new(
            Bytes::new()
        ))));
    }

    #[test]
    fn content_length_parses_only_valid_values() {
        let mut headers = http::HeaderMap::new();
        assert_eq!(content_length(&headers), None);
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("1024"));
        assert_eq!(content_length(&headers), Some(1024));
        // Surrounding whitespace is tolerated.
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static(" 42 "));
        assert_eq!(content_length(&headers), Some(42));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("not-a-number"));
        assert_eq!(content_length(&headers), None);
    }

    #[test]
    fn vary_accept_encoding_is_appended_once() {
        let mut headers = http::HeaderMap::new();
        headers.insert(VARY, HeaderValue::from_static("Origin"));
        append_vary_accept_encoding(&mut headers);
        append_vary_accept_encoding(&mut headers);
        let values: Vec<&str> = headers
            .get_all(VARY)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .collect();
        assert_eq!(values, ["Origin", "Accept-Encoding"]);
    }

    #[test]
    fn gzip_round_trips_through_flate2() {
        use flate2::read::GzDecoder;
        use std::io::Read;

        let original = b"hello world, this is a test that compresses well. ".repeat(20);
        let compressed = gzip(&original).expect("gzip must succeed");
        // Compressed data should be smaller than original (highly repetitive).
        assert!(compressed.len() < original.len());

        // Decode again so the test actually proves a lossless round trip.
        let mut decoded = Vec::new();
        GzDecoder::new(compressed.as_slice())
            .read_to_end(&mut decoded)
            .expect("gzip output must decode");
        assert_eq!(decoded, original);
    }

    #[test]
    fn gzip_of_empty_input_is_a_valid_stream() {
        use flate2::read::GzDecoder;
        use std::io::Read;

        let compressed = gzip(&[]).expect("gzip must succeed");
        // Even an empty payload produces a non-empty gzip stream (header + trailer).
        assert!(!compressed.is_empty());

        let mut decoded = Vec::new();
        GzDecoder::new(compressed.as_slice())
            .read_to_end(&mut decoded)
            .expect("gzip output must decode");
        assert!(decoded.is_empty());
    }
}