1#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24 header::{CONTENT_ENCODING, CONTENT_LENGTH},
25 Body, Method, Request, Response, StatusCode,
26};
27use lazy_static::lazy_static;
28use mime_guess::{mime, Mime};
29use pin_project::pin_project;
30use std::collections::HashSet;
31use std::pin::Pin;
32use std::task::{Context, Poll};
33use tokio_util::io::{ReaderStream, StreamReader};
34
35use crate::{
36 error_page,
37 handler::RequestHandlerOpts,
38 headers_ext::{AcceptEncoding, ContentCoding},
39 http_ext::MethodExt,
40 settings::CompressionLevel,
41 Error, Result,
42};
43
44lazy_static! {
45 static ref TEXT_MIME_TYPES: HashSet<&'static str> = [
47 "application/rtf",
48 "application/javascript",
49 "application/json",
50 "application/xml",
51 "font/ttf",
52 "application/font-sfnt",
53 "application/vnd.ms-fontobject",
54 "application/wasm",
55 ].into_iter().collect();
56}
57
58const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
60 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
61 ContentCoding::DEFLATE,
62 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
63 ContentCoding::GZIP,
64 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
65 ContentCoding::BROTLI,
66 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
67 ContentCoding::ZSTD,
68];
69
70pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
72 handler_opts.compression = enabled;
73 handler_opts.compression_level = level;
74
75 const FORMATS: &[&str] = &[
76 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
77 "deflate",
78 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
79 "gzip",
80 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
81 "brotli",
82 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
83 "zstd",
84 ];
85 server_info!(
86 "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
87 FORMATS.join(",")
88 );
89}
90
91pub(crate) fn post_process<T>(
93 opts: &RequestHandlerOpts,
94 req: &Request<T>,
95 mut resp: Response<Body>,
96) -> Result<Response<Body>, Error> {
97 if !opts.compression {
98 return Ok(resp);
99 }
100
101 let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
102 if is_precompressed {
103 return Ok(resp);
104 }
105
106 resp.headers_mut().insert(
108 hyper::header::VARY,
109 HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
110 );
111
112 match auto(req.method(), req.headers(), opts.compression_level, resp) {
114 Ok(resp) => Ok(resp),
115 Err(err) => {
116 tracing::error!("error during body compression: {:?}", err);
117 error_page::error_response(
118 req.uri(),
119 req.method(),
120 &StatusCode::INTERNAL_SERVER_ERROR,
121 &opts.page404,
122 &opts.page50x,
123 )
124 }
125 }
126}
127
128pub fn auto(
133 method: &Method,
134 headers: &HeaderMap<HeaderValue>,
135 level: CompressionLevel,
136 resp: Response<Body>,
137) -> Result<Response<Body>> {
138 if method.is_head() || method.is_options() {
140 return Ok(resp);
141 }
142
143 if let Some(encoding) = get_preferred_encoding(headers) {
145 tracing::trace!(
146 "preferred encoding selected from the accept-encoding header: {:?}",
147 encoding
148 );
149
150 if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
152 if !is_text(Mime::from(content_type)) {
153 return Ok(resp);
154 }
155 }
156
157 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
158 if encoding == ContentCoding::GZIP {
159 let (head, body) = resp.into_parts();
160 return Ok(gzip(head, body.into(), level));
161 }
162
163 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
164 if encoding == ContentCoding::DEFLATE {
165 let (head, body) = resp.into_parts();
166 return Ok(deflate(head, body.into(), level));
167 }
168
169 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
170 if encoding == ContentCoding::BROTLI {
171 let (head, body) = resp.into_parts();
172 return Ok(brotli(head, body.into(), level));
173 }
174
175 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
176 if encoding == ContentCoding::ZSTD {
177 let (head, body) = resp.into_parts();
178 return Ok(zstd(head, body.into(), level));
179 }
180
181 tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
182 }
183
184 Ok(resp)
185}
186
187fn is_text(mime: Mime) -> bool {
189 mime.type_() == mime::TEXT
190 || mime
191 .suffix()
192 .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
193 || TEXT_MIME_TYPES.contains(mime.essence_str())
194}
195
/// Builds a response whose body is compressed on the fly with GZIP.
///
/// The body is streamed through a `GzipEncoder`. The `Content-Encoding` header is set to
/// (or extended with) the GZIP coding and `Content-Length` is removed, since the body is
/// re-encoded while streaming.
///
/// `level` is translated with `into_algorithm_level`, passing `4` as the default level.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    // Header bookkeeping: record the new coding and drop the stale length.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let encoding_header = create_encoding_header(previous, ContentCoding::GZIP);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding_header);

    // Body pipeline: stream -> reader -> encoder -> stream.
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
222
/// Builds a response whose body is compressed on the fly with DEFLATE.
///
/// The body is streamed through a `DeflateEncoder`. The `Content-Encoding` header is set to
/// (or extended with) the DEFLATE coding and `Content-Length` is removed, since the body is
/// re-encoded while streaming.
///
/// `level` is translated with `into_algorithm_level`, passing `4` as the default level.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    // Header bookkeeping: record the new coding and drop the stale length.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let encoding_header = create_encoding_header(previous, ContentCoding::DEFLATE);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding_header);

    // Body pipeline: stream -> reader -> encoder -> stream.
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
252
/// Builds a response whose body is compressed on the fly with Brotli.
///
/// The body is streamed through a `BrotliEncoder`. The `Content-Encoding` header is set to
/// (or extended with) the Brotli coding and `Content-Length` is removed, since the body is
/// re-encoded while streaming.
///
/// `level` is translated with `into_algorithm_level`, passing `4` as the default level.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    // Header bookkeeping: record the new coding and drop the stale length.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let encoding_header = create_encoding_header(previous, ContentCoding::BROTLI);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding_header);

    // Body pipeline: stream -> reader -> encoder -> stream.
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
280
/// Builds a response whose body is compressed on the fly with Zstandard.
///
/// The body is streamed through a `ZstdEncoder`. The `Content-Encoding` header is set to
/// (or extended with) the ZSTD coding and `Content-Length` is removed, since the body is
/// re-encoded while streaming.
///
/// `level` is translated with `into_algorithm_level`, passing `3` as the default level.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    // Header bookkeeping: record the new coding and drop the stale length.
    let previous = head.headers.remove(CONTENT_ENCODING);
    let encoding_header = create_encoding_header(previous, ContentCoding::ZSTD);
    head.headers.remove(CONTENT_LENGTH);
    head.headers.insert(CONTENT_ENCODING, encoding_header);

    // Body pipeline: stream -> reader -> encoder -> stream.
    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
307
308pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
310 if let Some(val) = existing {
311 if let Ok(str_val) = val.to_str() {
312 return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
313 .unwrap_or_else(|_| coding.into());
314 }
315 }
316 coding.into()
317}
318
319#[inline(always)]
321pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
322 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
323 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
324
325 for encoding in accept_encoding.sorted_encodings() {
326 if AVAILABLE_ENCODINGS.contains(&encoding) {
327 return Some(encoding);
328 }
329 }
330 }
331 None
332}
333
334#[inline(always)]
336pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
337 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
338 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
339
340 return accept_encoding
341 .sorted_encodings()
342 .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
343 .collect::<Vec<_>>();
344 }
345 vec![]
346}
347
348#[pin_project]
351#[derive(Debug)]
352pub struct CompressableBody<S, E>
353where
354 S: Stream<Item = Result<Bytes, E>>,
355 E: std::error::Error,
356{
357 #[pin]
358 body: S,
359}
360
361impl<S, E> Stream for CompressableBody<S, E>
362where
363 S: Stream<Item = Result<Bytes, E>>,
364 E: std::error::Error,
365{
366 type Item = std::io::Result<Bytes>;
367
368 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
369 use std::io::{Error, ErrorKind};
370
371 let pin = self.project();
372 S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
373 }
374}
375
376impl From<Body> for CompressableBody<Body, hyper::Error> {
377 #[inline(always)]
378 fn from(body: Body) -> Self {
379 CompressableBody { body }
380 }
381}