1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use std::fmt::{self, Display, Formatter};
8use std::str::FromStr;
9
10use indexmap::IndexMap;
11use salvo_core::http::body::ResBody;
12use salvo_core::http::header::{
13 ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue,
14};
15use salvo_core::http::{self, Mime, StatusCode, mime};
16use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
17
18mod encoder;
19mod stream;
20use encoder::Encoder;
21use stream::EncodeStream;
22
/// How aggressively the negotiated algorithm should compress.
///
/// Each variant is translated into a concrete codec setting by the `encoder`
/// module; the descriptions below state the intent of each choice.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Favor encoding speed over compression ratio.
    Fastest,
    /// Favor the smallest output over encoding speed.
    Minsize,
    /// The codec's own default trade-off; used when no level is chosen.
    #[default]
    Default,
    /// An explicit, algorithm-specific numeric level (valid range depends on the codec).
    Precise(u32),
}
40
/// A content coding this middleware is able to produce.
///
/// Every variant exists only when its cargo feature is enabled. The token sent in
/// `Content-Encoding` for each variant comes from the `Display` and `HeaderValue`
/// conversions implemented for this type.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CompressionAlgo {
    /// Brotli, sent as `br`.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Deflate, sent as `deflate`.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Gzip, sent as `gzip`.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, sent as `zstd`.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
65
66impl FromStr for CompressionAlgo {
67 type Err = String;
68
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 match s {
71 #[cfg(feature = "brotli")]
72 "br" => Ok(Self::Brotli),
73 #[cfg(feature = "brotli")]
74 "brotli" => Ok(Self::Brotli),
75
76 #[cfg(feature = "deflate")]
77 "deflate" => Ok(Self::Deflate),
78
79 #[cfg(feature = "gzip")]
80 "gzip" => Ok(Self::Gzip),
81
82 #[cfg(feature = "zstd")]
83 "zstd" => Ok(Self::Zstd),
84 _ => Err(format!("unknown compression algorithm: {s}")),
85 }
86 }
87}
88
89impl Display for CompressionAlgo {
90 #[allow(unreachable_patterns)]
91 #[allow(unused_variables)]
92 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93 match self {
94 #[cfg(feature = "brotli")]
95 Self::Brotli => write!(f, "br"),
96 #[cfg(feature = "deflate")]
97 Self::Deflate => write!(f, "deflate"),
98 #[cfg(feature = "gzip")]
99 Self::Gzip => write!(f, "gzip"),
100 #[cfg(feature = "zstd")]
101 Self::Zstd => write!(f, "zstd"),
102 _ => unreachable!(),
103 }
104 }
105}
106
107impl From<CompressionAlgo> for HeaderValue {
108 #[inline]
109 fn from(algo: CompressionAlgo) -> Self {
110 match algo {
111 #[cfg(feature = "brotli")]
112 CompressionAlgo::Brotli => Self::from_static("br"),
113 #[cfg(feature = "deflate")]
114 CompressionAlgo::Deflate => Self::from_static("deflate"),
115 #[cfg(feature = "gzip")]
116 CompressionAlgo::Gzip => Self::from_static("gzip"),
117 #[cfg(feature = "zstd")]
118 CompressionAlgo::Zstd => Self::from_static("zstd"),
119 }
120 }
121}
122
123#[derive(Clone, Debug)]
125#[non_exhaustive]
126pub struct Compression {
127 pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
129 pub content_types: Vec<Mime>,
131 pub min_length: usize,
133 pub force_priority: bool,
135}
136
137impl Default for Compression {
138 fn default() -> Self {
139 #[allow(unused_mut)]
140 let mut algos = IndexMap::new();
141 #[cfg(feature = "zstd")]
142 algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
143 #[cfg(feature = "gzip")]
144 algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
145 #[cfg(feature = "deflate")]
146 algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
147 #[cfg(feature = "brotli")]
148 algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
149 Self {
150 algos,
151 content_types: vec![
152 mime::TEXT_STAR,
153 mime::APPLICATION_JAVASCRIPT,
154 mime::APPLICATION_JSON,
155 mime::IMAGE_SVG,
156 "application/wasm".parse().expect("invalid mime type"),
157 "application/xml".parse().expect("invalid mime type"),
158 "application/rss+xml".parse().expect("invalid mime type"),
159 ],
160 min_length: 0,
161 force_priority: false,
162 }
163 }
164}
165
166impl Compression {
167 #[inline]
169 #[must_use]
170 pub fn new() -> Self {
171 Default::default()
172 }
173
174 #[inline]
176 #[must_use]
177 pub fn disable_all(mut self) -> Self {
178 self.algos.clear();
179 self
180 }
181
182 #[cfg(feature = "gzip")]
184 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
185 #[inline]
186 #[must_use]
187 pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
188 self.algos.insert(CompressionAlgo::Gzip, level);
189 self
190 }
191 #[cfg(feature = "gzip")]
193 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
194 #[inline]
195 #[must_use]
196 pub fn disable_gzip(mut self) -> Self {
197 self.algos.shift_remove(&CompressionAlgo::Gzip);
198 self
199 }
200 #[cfg(feature = "zstd")]
202 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
203 #[inline]
204 #[must_use]
205 pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
206 self.algos.insert(CompressionAlgo::Zstd, level);
207 self
208 }
209 #[cfg(feature = "zstd")]
211 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
212 #[inline]
213 #[must_use]
214 pub fn disable_zstd(mut self) -> Self {
215 self.algos.shift_remove(&CompressionAlgo::Zstd);
216 self
217 }
218 #[cfg(feature = "brotli")]
220 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
221 #[inline]
222 #[must_use]
223 pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
224 self.algos.insert(CompressionAlgo::Brotli, level);
225 self
226 }
227 #[cfg(feature = "brotli")]
229 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
230 #[inline]
231 #[must_use]
232 pub fn disable_brotli(mut self) -> Self {
233 self.algos.shift_remove(&CompressionAlgo::Brotli);
234 self
235 }
236
237 #[cfg(feature = "deflate")]
239 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
240 #[inline]
241 #[must_use]
242 pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
243 self.algos.insert(CompressionAlgo::Deflate, level);
244 self
245 }
246
247 #[cfg(feature = "deflate")]
249 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
250 #[inline]
251 #[must_use]
252 pub fn disable_deflate(mut self) -> Self {
253 self.algos.shift_remove(&CompressionAlgo::Deflate);
254 self
255 }
256
257 #[inline]
260 #[must_use]
261 pub fn min_length(mut self, size: usize) -> Self {
262 self.min_length = size;
263 self
264 }
265 #[inline]
267 #[must_use]
268 pub fn force_priority(mut self, force_priority: bool) -> Self {
269 self.force_priority = force_priority;
270 self
271 }
272
273 #[inline]
275 #[must_use]
276 pub fn content_types(mut self, content_types: &[Mime]) -> Self {
277 self.content_types = content_types.to_vec();
278 self
279 }
280
281 fn negotiate(
282 &self,
283 req: &Request,
284 res: &Response,
285 ) -> Option<(CompressionAlgo, CompressionLevel)> {
286 if req.headers().contains_key(&CONTENT_ENCODING) {
287 return None;
288 }
289
290 if !self.content_types.is_empty() {
291 let content_type = res
292 .headers()
293 .get(CONTENT_TYPE)
294 .and_then(|v| v.to_str().ok())
295 .unwrap_or_default();
296 if content_type.is_empty() {
297 return None;
298 }
299 if let Ok(content_type) = content_type.parse::<Mime>() {
300 if !self.content_types.iter().any(|citem| {
301 citem.type_() == content_type.type_()
302 && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
303 }) {
304 return None;
305 }
306 } else {
307 return None;
308 }
309 }
310 let header = req
311 .headers()
312 .get(ACCEPT_ENCODING)
313 .and_then(|v| v.to_str().ok())?;
314
315 let accept_algos = http::parse_accept_encoding(header)
316 .into_iter()
317 .filter_map(|(algo, level)| {
318 if let Ok(algo) = algo.parse::<CompressionAlgo>() {
319 Some((algo, level))
320 } else {
321 None
322 }
323 })
324 .collect::<Vec<_>>();
325 if self.force_priority {
326 let accept_algos = accept_algos
327 .into_iter()
328 .map(|(algo, _)| algo)
329 .collect::<Vec<_>>();
330 self.algos
331 .iter()
332 .find(|(algo, _level)| accept_algos.contains(algo))
333 .map(|(algo, level)| (*algo, *level))
334 } else {
335 accept_algos
336 .into_iter()
337 .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
338 }
339 }
340}
341
342#[async_trait]
343impl Handler for Compression {
344 async fn handle(
345 &self,
346 req: &mut Request,
347 depot: &mut Depot,
348 res: &mut Response,
349 ctrl: &mut FlowCtrl,
350 ) {
351 ctrl.call_next(req, depot, res).await;
352 if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
353 return;
354 }
355
356 if let Some(StatusCode::SWITCHING_PROTOCOLS | StatusCode::NO_CONTENT) = res.status_code {
357 return;
358 }
359
360 match res.take_body() {
361 ResBody::None => {
362 return;
363 }
364 ResBody::Once(bytes) => {
365 if self.min_length > 0 && bytes.len() < self.min_length {
366 res.body(ResBody::Once(bytes));
367 return;
368 }
369 if let Some((algo, level)) = self.negotiate(req, res) {
370 res.stream(EncodeStream::new(algo, level, Some(bytes)));
371 res.headers_mut().append(CONTENT_ENCODING, algo.into());
372 } else {
373 res.body(ResBody::Once(bytes));
374 return;
375 }
376 }
377 ResBody::Chunks(chunks) => {
378 if self.min_length > 0 {
379 let len: usize = chunks.iter().map(|c| c.len()).sum();
380 if len < self.min_length {
381 res.body(ResBody::Chunks(chunks));
382 return;
383 }
384 }
385 if let Some((algo, level)) = self.negotiate(req, res) {
386 res.stream(EncodeStream::new(algo, level, chunks));
387 res.headers_mut().append(CONTENT_ENCODING, algo.into());
388 } else {
389 res.body(ResBody::Chunks(chunks));
390 return;
391 }
392 }
393 ResBody::Hyper(body) => {
394 if let Some((algo, level)) = self.negotiate(req, res) {
395 res.stream(EncodeStream::new(algo, level, body));
396 res.headers_mut().append(CONTENT_ENCODING, algo.into());
397 } else {
398 res.body(ResBody::Hyper(body));
399 return;
400 }
401 }
402 ResBody::Stream(body) => {
403 let body = body.into_inner();
404 if let Some((algo, level)) = self.negotiate(req, res) {
405 res.stream(EncodeStream::new(algo, level, body));
406 res.headers_mut().append(CONTENT_ENCODING, algo.into());
407 } else {
408 res.body(ResBody::stream(body));
409 return;
410 }
411 }
412 body => {
413 res.body(body);
414 return;
415 }
416 }
417 res.headers_mut().remove(CONTENT_LENGTH);
418 }
419}
420
#[cfg(test)]
mod tests {
    // Each test is gated on the codec feature(s) it exercises, so in reduced feature
    // sets some fixtures and imports below go unused.
    #![allow(dead_code, unused_imports)]

    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_gzip() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "brotli")]
    #[tokio::test]
    async fn test_brotli() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_deflate() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "deflate", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "deflate");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "zstd")]
    #[tokio::test]
    async fn test_zstd() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "zstd", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "zstd");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_not_compress() {
        let comp_handler = Compression::new().min_length(10);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_should_compress() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[handler]
    async fn hello_html(res: &mut Response) {
        res.render(Text::Html("<html><body>hello</body></html>"));
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_should_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::TEXT_HTML]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_not_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::APPLICATION_JSON]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[cfg(all(feature = "brotli", feature = "gzip"))]
    #[tokio::test]
    async fn test_force_priority() {
        let comp_handler = Compression::new()
            .disable_all()
            .enable_brotli(CompressionLevel::Default)
            .enable_gzip(CompressionLevel::Default)
            .min_length(1)
            .force_priority(true);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip, br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }
}