tork_core/middleware/
compression.rs1use std::io::Write;
4
5use bytes::Bytes;
6use flate2::write::GzEncoder;
7use flate2::Compression as CompressionLevel;
8use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, VARY};
9use http::HeaderValue;
10
11use crate::body::RespBody;
12use crate::constants::TEXT_EVENT_STREAM;
13use crate::error::Result;
14use crate::middleware::{DuplicatePolicy, Middleware, Next, Request};
15use crate::response::{into_body_bytes, Response};
16use crate::router::BoxFuture;
17
/// Content-coding token for gzip: matched against `Accept-Encoding` and written to
/// `Content-Encoding` on compressed responses.
const GZIP: &str = "gzip";

/// Default largest body size, in bytes, that will be compressed (8 MiB).
const DEFAULT_MAXIMUM_SIZE: usize = 8 << 20;
28
/// Configuration for the response-compression middleware.
///
/// Construct with `Compression::new()` (or `Default`) and chain the builder methods;
/// gzip stays disabled until the `gzip()` builder method is called.
#[derive(Debug, Clone)]
pub struct Compression {
    /// Whether gzip compression is enabled at all.
    gzip: bool,
    /// Bodies shorter than this many bytes are sent uncompressed.
    minimum_size: usize,
    /// Bodies longer than this many bytes are sent uncompressed.
    maximum_size: usize,
}
40
41impl Compression {
42 pub fn new() -> Self {
44 Self {
45 gzip: false,
46 minimum_size: 0,
47 maximum_size: DEFAULT_MAXIMUM_SIZE,
48 }
49 }
50
51 pub fn gzip(mut self) -> Self {
53 self.gzip = true;
54 self
55 }
56
57 pub fn minimum_size(mut self, bytes: usize) -> Self {
59 self.minimum_size = bytes;
60 self
61 }
62
63 pub fn maximum_size(mut self, bytes: usize) -> Self {
69 self.maximum_size = bytes;
70 self
71 }
72}
73
74impl Default for Compression {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Middleware for Compression {
81 fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
82 let gzip_enabled = self.gzip;
83 let minimum_size = self.minimum_size;
84 let maximum_size = self.maximum_size;
85 let accepts_gzip = request
86 .headers()
87 .get(ACCEPT_ENCODING)
88 .and_then(|value| value.to_str().ok())
89 .map(|value| value.to_ascii_lowercase().contains(GZIP))
90 .unwrap_or(false);
91
92 Box::pin(async move {
93 let mut response = next.run(request).await?;
94
95 if gzip_enabled && !is_event_stream(&response) {
101 append_vary_accept_encoding(response.headers_mut());
102 }
103
104 if !gzip_enabled
108 || !accepts_gzip
109 || response.headers().contains_key(CONTENT_ENCODING)
110 || is_event_stream(&response)
111 {
112 return Ok(response);
113 }
114
115 if let Some(length) = content_length(response.headers()) {
118 if length > maximum_size || length < minimum_size {
119 return Ok(response);
120 }
121 }
122
123 let (mut parts, bytes) = into_body_bytes(response).await;
124 if bytes.len() < minimum_size || bytes.len() > maximum_size {
127 return Ok(Response::from_parts(parts, RespBody::new(bytes)));
128 }
129
130 match gzip(&bytes) {
131 Ok(compressed) => {
132 parts
133 .headers
134 .insert(CONTENT_ENCODING, HeaderValue::from_static(GZIP));
135 if let Ok(length) = HeaderValue::from_str(&compressed.len().to_string()) {
136 parts.headers.insert(CONTENT_LENGTH, length);
137 }
138 Ok(Response::from_parts(
139 parts,
140 RespBody::new(Bytes::from(compressed)),
141 ))
142 }
143 Err(_) => Ok(Response::from_parts(parts, RespBody::new(bytes))),
145 }
146 })
147 }
148
149 fn name(&self) -> &'static str {
150 "Compression"
151 }
152
153 fn duplicate_policy(&self) -> DuplicatePolicy {
154 DuplicatePolicy::Reject
155 }
156}
157
158fn append_vary_accept_encoding(headers: &mut http::HeaderMap) {
160 let already_present = headers
161 .get_all(VARY)
162 .iter()
163 .filter_map(|value| value.to_str().ok())
164 .any(|value| value.to_ascii_lowercase().contains("accept-encoding"));
165 if !already_present {
166 headers.append(VARY, HeaderValue::from_static("Accept-Encoding"));
167 }
168}
169
170fn content_length(headers: &http::HeaderMap) -> Option<usize> {
172 headers
173 .get(CONTENT_LENGTH)
174 .and_then(|value| value.to_str().ok())
175 .and_then(|value| value.trim().parse::<usize>().ok())
176}
177
178fn is_event_stream(response: &Response) -> bool {
182 response
183 .headers()
184 .get(CONTENT_TYPE)
185 .and_then(|value| value.to_str().ok())
186 .map(|value| value.starts_with(TEXT_EVENT_STREAM))
187 .unwrap_or(false)
188}
189
190fn gzip(data: &[u8]) -> std::io::Result<Vec<u8>> {
192 let mut encoder = GzEncoder::new(
193 Vec::with_capacity(data.len() / 2 + 16),
194 CompressionLevel::default(),
195 );
196 encoder.write_all(data)?;
197 encoder.finish()
198}
199
#[cfg(test)]
mod tests {
    use std::io::Read;

    use flate2::read::GzDecoder;

    use super::*;

    /// Builds an empty-bodied response carrying the given `Content-Type`.
    fn response_with_content_type(value: &'static str) -> Response {
        let mut response = http::Response::new(RespBody::new(Bytes::new()));
        response
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static(value));
        response
    }

    #[test]
    fn event_stream_is_detected_and_bypasses_compression() {
        assert!(is_event_stream(&response_with_content_type(
            TEXT_EVENT_STREAM
        )));
        assert!(!is_event_stream(&response_with_content_type(
            "application/json"
        )));
        assert!(!is_event_stream(&http::Response::new(RespBody::new(
            Bytes::new()
        ))));
    }

    #[test]
    fn content_length_parses_only_valid_values() {
        let mut headers = http::HeaderMap::new();
        assert_eq!(content_length(&headers), None);
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("1024"));
        assert_eq!(content_length(&headers), Some(1024));
        headers.insert(CONTENT_LENGTH, HeaderValue::from_static("not-a-number"));
        assert_eq!(content_length(&headers), None);
    }

    #[test]
    fn gzip_round_trips_through_flate2() {
        let original = b"hello world, this is a test that compresses well. ".repeat(20);
        let compressed = gzip(&original).expect("gzip must succeed");
        assert!(compressed.len() < original.len());

        // Actually decode the output so the test proves a round trip, not just shrinkage.
        let mut decoded = Vec::new();
        GzDecoder::new(compressed.as_slice())
            .read_to_end(&mut decoded)
            .expect("gzip output must decode");
        assert_eq!(decoded, original);
    }
}