1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use std::fmt::{self, Display, Formatter};
8use std::str::FromStr;
9
10use indexmap::IndexMap;
11
12use salvo_core::http::body::ResBody;
use salvo_core::http::header::{
    ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, HeaderValue,
    VARY,
};
16use salvo_core::http::{self, Mime, StatusCode, mime};
17use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
18
19mod encoder;
20mod stream;
21use encoder::Encoder;
22use stream::EncodeStream;
23
/// How hard the selected encoder should work.
///
/// The mapping of each variant to a concrete encoder setting lives in the
/// encoder module (not shown here); the notes below describe the intent only.
#[non_exhaustive]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Favour speed over compression ratio (presumably the encoder's fastest
    /// setting — confirm in the encoder module).
    Fastest,
    /// Favour the smallest output over speed (presumably the encoder's best
    /// ratio setting — confirm in the encoder module).
    Minsize,
    /// The encoder's default trade-off. This is the `Default` variant.
    #[default]
    Default,
    /// An explicit, algorithm-specific level. NOTE(review): the valid range
    /// depends on the algorithm and is not checked here.
    Precise(u32),
}
41
/// A content coding this middleware can produce.
///
/// Each variant exists only when its cargo feature is enabled.
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
#[non_exhaustive]
pub enum CompressionAlgo {
    /// Brotli, advertised as the `br` content coding.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// The `deflate` content coding.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// The `gzip` content coding.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, advertised as the `zstd` content coding.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
66
67impl FromStr for CompressionAlgo {
68 type Err = String;
69
70 fn from_str(s: &str) -> Result<Self, Self::Err> {
71 match s {
72 #[cfg(feature = "brotli")]
73 "br" => Ok(Self::Brotli),
74 #[cfg(feature = "brotli")]
75 "brotli" => Ok(Self::Brotli),
76
77 #[cfg(feature = "deflate")]
78 "deflate" => Ok(Self::Deflate),
79
80 #[cfg(feature = "gzip")]
81 "gzip" => Ok(Self::Gzip),
82
83 #[cfg(feature = "zstd")]
84 "zstd" => Ok(Self::Zstd),
85 _ => Err(format!("unknown compression algorithm: {s}")),
86 }
87 }
88}
89
90impl Display for CompressionAlgo {
91 #[allow(unreachable_patterns)]
92 #[allow(unused_variables)]
93 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
94 match self {
95 #[cfg(feature = "brotli")]
96 Self::Brotli => write!(f, "br"),
97 #[cfg(feature = "deflate")]
98 Self::Deflate => write!(f, "deflate"),
99 #[cfg(feature = "gzip")]
100 Self::Gzip => write!(f, "gzip"),
101 #[cfg(feature = "zstd")]
102 Self::Zstd => write!(f, "zstd"),
103 _ => unreachable!(),
104 }
105 }
106}
107
108impl From<CompressionAlgo> for HeaderValue {
109 #[inline]
110 fn from(algo: CompressionAlgo) -> Self {
111 match algo {
112 #[cfg(feature = "brotli")]
113 CompressionAlgo::Brotli => Self::from_static("br"),
114 #[cfg(feature = "deflate")]
115 CompressionAlgo::Deflate => Self::from_static("deflate"),
116 #[cfg(feature = "gzip")]
117 CompressionAlgo::Gzip => Self::from_static("gzip"),
118 #[cfg(feature = "zstd")]
119 CompressionAlgo::Zstd => Self::from_static("zstd"),
120 }
121 }
122}
123
124#[derive(Clone, Debug)]
126#[non_exhaustive]
127pub struct Compression {
128 pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
130 pub content_types: Vec<Mime>,
132 pub min_length: usize,
134 pub force_priority: bool,
136}
137
138impl Default for Compression {
139 fn default() -> Self {
140 #[allow(unused_mut)]
141 let mut algos = IndexMap::new();
142 #[cfg(feature = "zstd")]
143 algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
144 #[cfg(feature = "gzip")]
145 algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
146 #[cfg(feature = "deflate")]
147 algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
148 #[cfg(feature = "brotli")]
149 algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
150 Self {
151 algos,
152 content_types: vec![
153 mime::TEXT_STAR,
154 mime::APPLICATION_JAVASCRIPT,
155 mime::APPLICATION_JSON,
156 mime::IMAGE_SVG,
157 "application/wasm".parse().expect("invalid mime type"),
158 "application/xml".parse().expect("invalid mime type"),
159 "application/rss+xml".parse().expect("invalid mime type"),
160 ],
161 min_length: 0,
162 force_priority: false,
163 }
164 }
165}
166
167impl Compression {
168 #[inline]
170 #[must_use]
171 pub fn new() -> Self {
172 Default::default()
173 }
174
175 #[inline]
177 #[must_use]
178 pub fn disable_all(mut self) -> Self {
179 self.algos.clear();
180 self
181 }
182
183 #[cfg(feature = "gzip")]
185 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
186 #[inline]
187 #[must_use]
188 pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
189 self.algos.insert(CompressionAlgo::Gzip, level);
190 self
191 }
192 #[cfg(feature = "gzip")]
194 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
195 #[inline]
196 #[must_use]
197 pub fn disable_gzip(mut self) -> Self {
198 self.algos.shift_remove(&CompressionAlgo::Gzip);
199 self
200 }
201 #[cfg(feature = "zstd")]
203 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
204 #[inline]
205 #[must_use]
206 pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
207 self.algos.insert(CompressionAlgo::Zstd, level);
208 self
209 }
210 #[cfg(feature = "zstd")]
212 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
213 #[inline]
214 #[must_use]
215 pub fn disable_zstd(mut self) -> Self {
216 self.algos.shift_remove(&CompressionAlgo::Zstd);
217 self
218 }
219 #[cfg(feature = "brotli")]
221 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
222 #[inline]
223 #[must_use]
224 pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
225 self.algos.insert(CompressionAlgo::Brotli, level);
226 self
227 }
228 #[cfg(feature = "brotli")]
230 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
231 #[inline]
232 #[must_use]
233 pub fn disable_brotli(mut self) -> Self {
234 self.algos.shift_remove(&CompressionAlgo::Brotli);
235 self
236 }
237
238 #[cfg(feature = "deflate")]
240 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
241 #[inline]
242 #[must_use]
243 pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
244 self.algos.insert(CompressionAlgo::Deflate, level);
245 self
246 }
247
248 #[cfg(feature = "deflate")]
250 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
251 #[inline]
252 #[must_use]
253 pub fn disable_deflate(mut self) -> Self {
254 self.algos.shift_remove(&CompressionAlgo::Deflate);
255 self
256 }
257
258 #[inline]
261 #[must_use]
262 pub fn min_length(mut self, size: usize) -> Self {
263 self.min_length = size;
264 self
265 }
266 #[inline]
268 #[must_use]
269 pub fn force_priority(mut self, force_priority: bool) -> Self {
270 self.force_priority = force_priority;
271 self
272 }
273
274 #[inline]
276 #[must_use]
277 pub fn content_types(mut self, content_types: &[Mime]) -> Self {
278 self.content_types = content_types.to_vec();
279 self
280 }
281
282 fn negotiate(
283 &self,
284 req: &Request,
285 res: &Response,
286 ) -> Option<(CompressionAlgo, CompressionLevel)> {
287 if req.headers().contains_key(&CONTENT_ENCODING) {
288 return None;
289 }
290
291 if !self.content_types.is_empty() {
292 let content_type = res
293 .headers()
294 .get(CONTENT_TYPE)
295 .and_then(|v| v.to_str().ok())
296 .unwrap_or_default();
297 if content_type.is_empty() {
298 return None;
299 }
300 if let Ok(content_type) = content_type.parse::<Mime>() {
301 if !self.content_types.iter().any(|citem| {
302 citem.type_() == content_type.type_()
303 && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
304 }) {
305 return None;
306 }
307 } else {
308 return None;
309 }
310 }
311 let header = req
312 .headers()
313 .get(ACCEPT_ENCODING)
314 .and_then(|v| v.to_str().ok())?;
315
316 let accept_algos = http::parse_accept_encoding(header)
317 .into_iter()
318 .filter_map(|(algo, level)| {
319 if let Ok(algo) = algo.parse::<CompressionAlgo>() {
320 Some((algo, level))
321 } else {
322 None
323 }
324 })
325 .collect::<Vec<_>>();
326 if self.force_priority {
327 let accept_algos = accept_algos
328 .into_iter()
329 .map(|(algo, _)| algo)
330 .collect::<Vec<_>>();
331 self.algos
332 .iter()
333 .find(|(algo, _level)| accept_algos.contains(algo))
334 .map(|(algo, level)| (*algo, *level))
335 } else {
336 accept_algos
337 .into_iter()
338 .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
339 }
340 }
341}
342
343#[async_trait]
344impl Handler for Compression {
345 async fn handle(
346 &self,
347 req: &mut Request,
348 depot: &mut Depot,
349 res: &mut Response,
350 ctrl: &mut FlowCtrl,
351 ) {
352 ctrl.call_next(req, depot, res).await;
353 if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
354 return;
355 }
356
357 if let Some(StatusCode::SWITCHING_PROTOCOLS | StatusCode::NO_CONTENT) = res.status_code {
358 return;
359 }
360
361 match res.take_body() {
362 ResBody::None => {
363 return;
364 }
365 ResBody::Once(bytes) => {
366 if self.min_length > 0 && bytes.len() < self.min_length {
367 res.body(ResBody::Once(bytes));
368 return;
369 }
370 if let Some((algo, level)) = self.negotiate(req, res) {
371 res.stream(EncodeStream::new(algo, level, Some(bytes)));
372 res.headers_mut().append(CONTENT_ENCODING, algo.into());
373 } else {
374 res.body(ResBody::Once(bytes));
375 return;
376 }
377 }
378 ResBody::Chunks(chunks) => {
379 if self.min_length > 0 {
380 let len: usize = chunks.iter().map(|c| c.len()).sum();
381 if len < self.min_length {
382 res.body(ResBody::Chunks(chunks));
383 return;
384 }
385 }
386 if let Some((algo, level)) = self.negotiate(req, res) {
387 res.stream(EncodeStream::new(algo, level, chunks));
388 res.headers_mut().append(CONTENT_ENCODING, algo.into());
389 } else {
390 res.body(ResBody::Chunks(chunks));
391 return;
392 }
393 }
394 ResBody::Hyper(body) => {
395 if let Some((algo, level)) = self.negotiate(req, res) {
396 res.stream(EncodeStream::new(algo, level, body));
397 res.headers_mut().append(CONTENT_ENCODING, algo.into());
398 } else {
399 res.body(ResBody::Hyper(body));
400 return;
401 }
402 }
403 ResBody::Stream(body) => {
404 let body = body.into_inner();
405 if let Some((algo, level)) = self.negotiate(req, res) {
406 res.stream(EncodeStream::new(algo, level, body));
407 res.headers_mut().append(CONTENT_ENCODING, algo.into());
408 } else {
409 res.body(ResBody::stream(body));
410 return;
411 }
412 }
413 body => {
414 res.body(body);
415 return;
416 }
417 }
418 res.headers_mut().remove(CONTENT_LENGTH);
419 }
420}
421
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    #[handler]
    async fn hello_html(res: &mut Response) {
        res.render(Text::Html("<html><body>hello</body></html>"));
    }

    /// Mounts `target` at `/hello` behind `comp` and issues a GET carrying the
    /// given `Accept-Encoding` header.
    async fn fetch(
        comp: Compression,
        target: impl Handler,
        accept_encoding: &'static str,
    ) -> Response {
        let router = Router::with_hoop(comp).push(Router::with_path("hello").get(target));
        TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, accept_encoding, true)
            .send(router)
            .await
    }

    /// Asserts that `/hello` comes back encoded as `expected` and still decodes
    /// to the plain `"hello"` body.
    async fn assert_encoded_round_trip(accept_encoding: &'static str, expected: &str) {
        let mut res = fetch(Compression::new().min_length(1), hello, accept_encoding).await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), expected);
        assert_eq!(res.take_string().await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn test_gzip() {
        assert_encoded_round_trip("gzip", "gzip").await;
    }

    #[tokio::test]
    async fn test_brotli() {
        assert_encoded_round_trip("br", "br").await;
    }

    #[tokio::test]
    async fn test_deflate() {
        assert_encoded_round_trip("deflate", "deflate").await;
    }

    #[tokio::test]
    async fn test_zstd() {
        assert_encoded_round_trip("zstd", "zstd").await;
    }

    #[tokio::test]
    async fn test_min_length_not_compress() {
        let res = fetch(Compression::new().min_length(10), hello, "gzip").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn test_min_length_should_compress() {
        let res = fetch(Compression::new().min_length(1), hello, "gzip").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[tokio::test]
    async fn test_content_types_should_compress() {
        let comp = Compression::new()
            .min_length(1)
            .content_types(&[mime::TEXT_HTML]);
        let res = fetch(comp, hello_html, "gzip").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[tokio::test]
    async fn test_content_types_not_compress() {
        let comp = Compression::new()
            .min_length(1)
            .content_types(&[mime::APPLICATION_JSON]);
        let res = fetch(comp, hello_html, "gzip").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn test_force_priority() {
        let comp = Compression::new()
            .disable_all()
            .enable_brotli(CompressionLevel::Default)
            .enable_gzip(CompressionLevel::Default)
            .min_length(1)
            .force_priority(true);
        let mut res = fetch(comp, hello, "gzip, br").await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        assert_eq!(res.take_string().await.unwrap(), "hello");
    }
}