1#[cfg(any(feature = "compression", feature = "compression-brotli"))]
12use async_compression::tokio::bufread::BrotliEncoder;
13#[cfg(any(feature = "compression", feature = "compression-deflate"))]
14use async_compression::tokio::bufread::DeflateEncoder;
15#[cfg(any(feature = "compression", feature = "compression-gzip"))]
16use async_compression::tokio::bufread::GzipEncoder;
17#[cfg(any(feature = "compression", feature = "compression-zstd"))]
18use async_compression::tokio::bufread::ZstdEncoder;
19
20use bytes::Bytes;
21use futures_util::Stream;
22use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
23use hyper::{
24 header::{CONTENT_ENCODING, CONTENT_LENGTH},
25 Body, Method, Request, Response, StatusCode,
26};
27use lazy_static::lazy_static;
28use mime_guess::{mime, Mime};
29use pin_project::pin_project;
30use std::collections::HashSet;
31use std::pin::Pin;
32use std::task::{Context, Poll};
33use tokio_util::io::{ReaderStream, StreamReader};
34
35use crate::{
36 error_page,
37 handler::RequestHandlerOpts,
38 headers_ext::{AcceptEncoding, ContentCoding},
39 http_ext::MethodExt,
40 settings::CompressionLevel,
41 Error, Result,
42};
43
44lazy_static! {
45 static ref TEXT_MIME_TYPES: HashSet<&'static str> = [
47 "application/rtf",
48 "application/javascript",
49 "application/json",
50 "application/xml",
51 "font/ttf",
52 "application/font-sfnt",
53 "application/vnd.ms-fontobject",
54 "application/wasm",
55 ].into_iter().collect();
56}
57
/// Content codings this build is able to produce on the fly.
///
/// Each entry is only present when its cargo feature (or the umbrella
/// `compression` feature) is enabled. `get_preferred_encoding` and
/// `get_encodings` use it to filter the client's `accept-encoding` list.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
69
70pub fn init(enabled: bool, level: CompressionLevel, handler_opts: &mut RequestHandlerOpts) {
72 handler_opts.compression = enabled;
73 handler_opts.compression_level = level;
74
75 const FORMATS: &[&str] = &[
76 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
77 "deflate",
78 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
79 "gzip",
80 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
81 "brotli",
82 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
83 "zstd",
84 ];
85 tracing::info!(
86 "auto compression: enabled={enabled}, formats={}, compression level={level:?}",
87 FORMATS.join(",")
88 );
89}
90
91pub(crate) fn post_process<T>(
93 opts: &RequestHandlerOpts,
94 req: &Request<T>,
95 mut resp: Response<Body>,
96) -> Result<Response<Body>, Error> {
97 if !opts.compression {
98 return Ok(resp);
99 }
100
101 let is_precompressed = resp.headers().get(CONTENT_ENCODING).is_some();
102 if is_precompressed {
103 return Ok(resp);
104 }
105
106 let value = resp.headers().get(hyper::header::VARY).map_or(
108 HeaderValue::from_name(hyper::header::ACCEPT_ENCODING),
109 |h| {
110 let mut s = h.to_str().unwrap_or_default().to_owned();
111 s.push(',');
112 s.push_str(hyper::header::ACCEPT_ENCODING.as_str());
113 HeaderValue::from_str(s.as_str()).unwrap()
114 },
115 );
116 resp.headers_mut().insert(hyper::header::VARY, value);
117
118 match auto(req.method(), req.headers(), opts.compression_level, resp) {
120 Ok(resp) => Ok(resp),
121 Err(err) => {
122 tracing::error!("error during body compression: {:?}", err);
123 error_page::error_response(
124 req.uri(),
125 req.method(),
126 &StatusCode::INTERNAL_SERVER_ERROR,
127 &opts.page404,
128 &opts.page50x,
129 )
130 }
131 }
132}
133
134pub fn auto(
139 method: &Method,
140 headers: &HeaderMap<HeaderValue>,
141 level: CompressionLevel,
142 resp: Response<Body>,
143) -> Result<Response<Body>> {
144 if method.is_head() || method.is_options() {
146 return Ok(resp);
147 }
148
149 if let Some(encoding) = get_preferred_encoding(headers) {
151 tracing::trace!(
152 "preferred encoding selected from the accept-encoding header: {:?}",
153 encoding
154 );
155
156 if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
158 if !is_text(Mime::from(content_type)) {
159 return Ok(resp);
160 }
161 }
162
163 #[cfg(any(feature = "compression", feature = "compression-gzip"))]
164 if encoding == ContentCoding::GZIP {
165 let (head, body) = resp.into_parts();
166 return Ok(gzip(head, body.into(), level));
167 }
168
169 #[cfg(any(feature = "compression", feature = "compression-deflate"))]
170 if encoding == ContentCoding::DEFLATE {
171 let (head, body) = resp.into_parts();
172 return Ok(deflate(head, body.into(), level));
173 }
174
175 #[cfg(any(feature = "compression", feature = "compression-brotli"))]
176 if encoding == ContentCoding::BROTLI {
177 let (head, body) = resp.into_parts();
178 return Ok(brotli(head, body.into(), level));
179 }
180
181 #[cfg(any(feature = "compression", feature = "compression-zstd"))]
182 if encoding == ContentCoding::ZSTD {
183 let (head, body) = resp.into_parts();
184 return Ok(zstd(head, body.into(), level));
185 }
186
187 tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
188 }
189
190 Ok(resp)
191}
192
193fn is_text(mime: Mime) -> bool {
195 mime.type_() == mime::TEXT
196 || mime
197 .suffix()
198 .is_some_and(|suffix| suffix == mime::XML || suffix == mime::JSON)
199 || TEXT_MIME_TYPES.contains(mime.essence_str())
200}
201
/// Wraps the response body in a streaming GZIP encoder.
///
/// The `content-encoding` header becomes any existing value with `gzip`
/// appended (see `create_encoding_header`). `content-length` is removed, since
/// the encoded size is not known ahead of time.
///
/// `DEFAULT_COMPRESSION_LEVEL` is passed to `CompressionLevel::into_algorithm_level`.
/// It is presumably used when `level` requests the default setting; confirm
/// against `CompressionLevel`.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using GZIP");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = GzipEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let headers = &mut head.headers;
    let content_encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::GZIP);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
228
/// Wraps the response body in a streaming DEFLATE encoder.
///
/// The `content-encoding` header becomes any existing value with `deflate`
/// appended (see `create_encoding_header`). `content-length` is removed, since
/// the encoded size is not known ahead of time.
///
/// `DEFAULT_COMPRESSION_LEVEL` is passed to `CompressionLevel::into_algorithm_level`.
/// It is presumably used when `level` requests the default setting; confirm
/// against `CompressionLevel`.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using DEFLATE");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = DeflateEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let headers = &mut head.headers;
    let content_encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::DEFLATE);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
258
/// Wraps the response body in a streaming Brotli encoder.
///
/// The `content-encoding` header becomes any existing value with the Brotli
/// coding appended (see `create_encoding_header`). `content-length` is
/// removed, since the encoded size is not known ahead of time.
///
/// `DEFAULT_COMPRESSION_LEVEL` is passed to `CompressionLevel::into_algorithm_level`.
/// It is presumably used when `level` requests the default setting; confirm
/// against `CompressionLevel`.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 4;

    tracing::trace!("compressing response body on the fly using BROTLI");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = BrotliEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let headers = &mut head.headers;
    let content_encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::BROTLI);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
286
/// Wraps the response body in a streaming Zstandard encoder.
///
/// The `content-encoding` header becomes any existing value with `zstd`
/// appended (see `create_encoding_header`). `content-length` is removed, since
/// the encoded size is not known ahead of time.
///
/// `DEFAULT_COMPRESSION_LEVEL` is passed to `CompressionLevel::into_algorithm_level`.
/// It is presumably used when `level` requests the default setting; confirm
/// against `CompressionLevel`.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
    level: CompressionLevel,
) -> Response<Body> {
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    tracing::trace!("compressing response body on the fly using ZSTD");

    let quality = level.into_algorithm_level(DEFAULT_COMPRESSION_LEVEL);
    let encoder = ZstdEncoder::with_quality(StreamReader::new(body), quality);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    let headers = &mut head.headers;
    let content_encoding =
        create_encoding_header(headers.remove(CONTENT_ENCODING), ContentCoding::ZSTD);
    headers.remove(CONTENT_LENGTH);
    headers.insert(CONTENT_ENCODING, content_encoding);

    Response::from_parts(head, compressed)
}
313
314pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
316 if let Some(val) = existing {
317 if let Ok(str_val) = val.to_str() {
318 return HeaderValue::from_str(&[str_val, ", ", coding.as_str()].concat())
319 .unwrap_or_else(|_| coding.into());
320 }
321 }
322 coding.into()
323}
324
325#[inline(always)]
327pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
328 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
329 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
330
331 for encoding in accept_encoding.sorted_encodings() {
332 if AVAILABLE_ENCODINGS.contains(&encoding) {
333 return Some(encoding);
334 }
335 }
336 }
337 None
338}
339
340#[inline(always)]
342pub fn get_encodings(headers: &HeaderMap<HeaderValue>) -> Vec<ContentCoding> {
343 if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
344 tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
345
346 return accept_encoding
347 .sorted_encodings()
348 .filter(|encoding| AVAILABLE_ENCODINGS.contains(encoding))
349 .collect::<Vec<_>>();
350 }
351 vec![]
352}
353
/// Adapter around a fallible stream of `Bytes` chunks.
///
/// Its `Stream` impl yields `std::io::Result<Bytes>`, which is the shape
/// `StreamReader` needs in the `gzip`/`deflate`/`brotli`/`zstd` encoders.
/// `E` appears only in the bounds. It names the wrapped stream's error type so
/// that the `Stream` impl can map it to an I/O error.
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    // Pinned projection so `poll_next` can poll the inner stream in place.
    #[pin]
    body: S,
}
366
367impl<S, E> Stream for CompressableBody<S, E>
368where
369 S: Stream<Item = Result<Bytes, E>>,
370 E: std::error::Error,
371{
372 type Item = std::io::Result<Bytes>;
373
374 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
375 use std::io::{Error, ErrorKind};
376
377 let pin = self.project();
378 S::poll_next(pin.body, ctx).map_err(|_| Error::from(ErrorKind::InvalidData))
379 }
380}
381
382impl From<Body> for CompressableBody<Body, hyper::Error> {
383 #[inline(always)]
384 fn from(body: Body) -> Self {
385 CompressableBody { body }
386 }
387}