1#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24 Body, Method, Request, Response, StatusCode,
25 header::{CONTENT_ENCODING, CONTENT_LENGTH},
26};
27use mime_guess::{Mime, mime};
28use pin_project::pin_project;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use tokio_util::io::{ReaderStream, StreamReader};
32
33use crate::{
34 Error, Result, error_page,
35 handler::RequestHandlerOpts,
36 headers_ext::{AcceptEncoding, ContentCoding},
37 http_ext::MethodExt,
38 settings::CompressionLevel,
39};
40
/// Non-`text/*` MIME essences that are still worth compressing on the fly.
///
/// Checked by `is_text` after the `text/*` type and `+xml`/`+json` suffix rules.
/// Declared as a slice (like `AVAILABLE_ENCODINGS`) so entries can be added or
/// removed without keeping a hand-written array length in sync.
const TEXT_MIME_TYPES: &[&str] = &[
    "application/rtf",
    "application/javascript",
    "application/json",
    "application/xml",
    "font/ttf",
    "application/font-sfnt",
    "application/vnd.ms-fontobject",
    "application/wasm",
];
52
/// Content codings this build is able to produce.
///
/// Each entry is gated by its own Cargo feature; the umbrella `compression`
/// feature enables all of them. Request `Accept-Encoding` values are filtered
/// against this list when choosing an encoding.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
64
65pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
67 handler_opts.compression = enabled;
68 handler_opts.compression_level = level;
69
70 const FORMATS: &[&str] = &[
71 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
72 "deflate",
73 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
74 "gzip",
75 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
76 "brotli",
77 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
78 "zstd",
79 ];
80 tracing::info!(
81 "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
82 FORMATS.join(",")
83 );
84}
85
86pub(crate) fn post_process<T>(
88 opts: &RequestHandlerOpts,
89 req: &Request<T>,
90 mut resp: Response<Body>,
91) -> Result<Response<Body>, Error> {
92 if !opts.compression {
93 return Ok(resp);
94 }
95
96 let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
97 if is_precompressed {
98 return Ok(resp);
99 }
100
101 let value = resp.headers().get(hyper::header::VARY).map_or(
103 HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
104 |h| {
105 let mut s = h.to_str().unwrap_or_default().to_owned();
106 s.push(',');
107 s.push_str(hyper::header::ACCEPT_ENCODING.as_str());
108 HeaderValue::from_str(s.as_str()).unwrap()
109 },
110 );
111 resp.headers_mut().insert(hyper::header::VARY, value);
112
113 match auto(req.method(), req.headers(), opts.compression_level, resp) {
115 Ok(resp) => Ok(resp),
116 Err(err) => {
117 tracing::error!("error during body compression: {:?}", err);
118 error_page::error_response(
119 req.uri(),
120 req.method(),
121 &StatusCode::INTERNAL_SERVER_ERROR,
122 &opts.page404,
123 &opts.page50x,
124 )
125 }
126 }
127}
128
129pub fn auto(
134 method: &Method,
135 headers: &HeaderMap<HeaderValue>,
136 level: CompressionLevel,
137 resp: Response<Body>,
138) -> Result<Response<Body>> {
139 if method.is_head() || method.is_options() {
141 return Ok(resp);
142 }
143
144 if let Some(encoding) = get_preferred_encoding(headers) {
146 tracing::trace!(
147 "preferred encoding selected from the accept-encoding header: {:?}",
148 encoding
149 );
150
151 if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
153 if !is_text(Mime::from(content_type)) {
154 return Ok(resp);
155 }
156 }
157
158 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
159 if encoding == ContentCoding::GZIP {
160 let (head, body) = resp.into_parts();
161 return Ok(gzip(head, body.into(), level));
162 }
163
164 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
165 if encoding == ContentCoding::DEFLATE {
166 let (head, body) = resp.into_parts();
167 return Ok(deflate(head, body.into(), level));
168 }
169
170 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
171 if encoding == ContentCoding::BROTLI {
172 let (head, body) = resp.into_parts();
173 return Ok(brotli(head, body.into(), level));
174 }
175
176 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
177 if encoding == ContentCoding::ZSTD {
178 let (head, body) = resp.into_parts();
179 return Ok(zstd(head, body.into(), level));
180 }
181
182 tracing::trace!(
183 "no compression feature matched the preferred encoding, probably not enabled or unsupported"
184 );
185 }
186
187 Ok(resp)
188}
189
190fn is_text(mime: Mime) -> bool {
192 mime.type_() == mime::TEXT
193 || mime
194 .suffix()
195 .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
196 || TEXT_MIME_TYPES.contains(&mime.essence_str())
197}
198
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
/// Compresses the response body on the fly using GZIP.
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, passing
/// `4` as the algorithm default. The GZIP coding is appended to any existing
/// `Content-Encoding` value, and `Content-Length` is dropped because the size of
/// the streamed, encoded body is not known upfront.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding = create_encoding_header(previous_encoding, ContentCoding::GZIP);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
225
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
/// Compresses the response body on the fly using DEFLATE.
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, passing
/// `4` as the algorithm default. The DEFLATE coding is appended to any existing
/// `Content-Encoding` value, and `Content-Length` is dropped because the size of
/// the streamed, encoded body is not known upfront.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding = create_encoding_header(previous_encoding, ContentCoding::DEFLATE);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
255
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
/// Compresses the response body on the fly using Brotli.
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, passing
/// `4` as the algorithm default. The Brotli coding is appended to any existing
/// `Content-Encoding` value, and `Content-Length` is dropped because the size of
/// the streamed, encoded body is not known upfront.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding = create_encoding_header(previous_encoding, ContentCoding::BROTLI);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
283
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
/// Compresses the response body on the fly using Zstandard.
///
/// `level` is resolved through `CompressionLevel::into_algorithm_level`, passing
/// `3` as the algorithm default. The ZSTD coding is appended to any existing
/// `Content-Encoding` value, and `Content-Length` is dropped because the size of
/// the streamed, encoded body is not known upfront.
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding = create_encoding_header(previous_encoding, ContentCoding::ZSTD);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding);

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    Response::from_parts(head, Body::wrap_stream(ReaderStream::new(encoder)))
}
310
311pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
313 if let Some(val) = existing {
314 if let Ok(str_val) = val.to_str() {
315 return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
316 .unwrap_or_else(|_| coding.into());
317 }
318 }
319 coding.into()
320}
321
322#[inline(always)]
324pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
325 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
326 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
327
328 for encoding in accept_encoding.sorted_encodings() {
329 if AVAILABLE_ENCODINGS.contains(&encoding) {
330 return Some(encoding);
331 }
332 }
333 }
334 None
335}
336
337#[inline(always)]
339pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
340 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
341 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
342
343 return accept_encoding
344 .sorted_encodings()
345 .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
346 .collect::<Vec<_>>();
347 }
348 vec![]
349}
350
/// Stream adapter that turns a fallible byte stream (such as a hyper [`Body`])
/// into a stream of `std::io::Result<Bytes>`.
///
/// This adapter is the input to `StreamReader` in the encoder functions of this
/// module. The `where` clause also constrains `E` through `S`'s item type,
/// which is why the bounds live on the struct itself.
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    /// The wrapped stream, polled through a pinned projection.
    #[pin]
    body: S,
}
363
364impl<S, E> Stream for CompressableBody<S, E>
365where
366 S: Stream<Item = Result<Bytes, E>>,
367 E: std::error::Error,
368{
369 type Item = std::io::Result<Bytes>;
370
371 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
372 use std::io::{Error, ErrorKind};
373
374 let pin = self.project();
375 S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
376 }
377}
378
379impl From<Body> for CompressableBody<Body, hyper::Error> {
380 #[inline(always)]
381 fn from(body: Body) -> Self {
382 CompressableBody { body }
383 }
384}