salvo_compression/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3//! Compression middleware for the Salvo web framework.
4//!
5//! Read more: <https://salvo.rs>
6
7use std::fmt::{self, Display, Formatter};
8use std::str::FromStr;
9
10use indexmap::IndexMap;
11
12use salvo_core::http::body::ResBody;
13use salvo_core::http::header::{
14    ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue,
15};
16use salvo_core::http::{self, Mime, StatusCode, mime};
17use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
18
19mod encoder;
20mod stream;
21use encoder::Encoder;
22use stream::EncodeStream;
23
/// How hard the selected algorithm should try to shrink the data.
///
/// [`CompressionLevel::Default`] is the `Default` value and defers to whatever the chosen
/// algorithm considers its standard quality.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Favor speed over ratio; the output is usually larger.
    Fastest,
    /// Favor ratio over speed; the output is usually the smallest.
    Minsize,
    /// The algorithm's own default quality.
    #[default]
    Default,
    /// An explicit quality expressed on the underlying algorithm's native scale.
    ///
    /// How the number is interpreted depends on the algorithm and the implementation
    /// backing it; values above the algorithm's maximum are implicitly clamped.
    Precise(u32),
}
41
/// A content coding this middleware can produce.
///
/// Each variant only exists when the cargo feature of the same name is enabled. Its string
/// form (see the `FromStr` and `Display` implementations) is the HTTP content-coding token.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CompressionAlgo {
    /// Brotli, sent as the `br` content coding.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Deflate, sent as the `deflate` content coding.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Gzip, sent as the `gzip` content coding.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, sent as the `zstd` content coding.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
66
67impl FromStr for CompressionAlgo {
68    type Err = String;
69
70    fn from_str(s: &str) -> Result<Self, Self::Err> {
71        match s {
72            #[cfg(feature = "brotli")]
73            "br" => Ok(Self::Brotli),
74            #[cfg(feature = "brotli")]
75            "brotli" => Ok(Self::Brotli),
76
77            #[cfg(feature = "deflate")]
78            "deflate" => Ok(Self::Deflate),
79
80            #[cfg(feature = "gzip")]
81            "gzip" => Ok(Self::Gzip),
82
83            #[cfg(feature = "zstd")]
84            "zstd" => Ok(Self::Zstd),
85            _ => Err(format!("unknown compression algorithm: {s}")),
86        }
87    }
88}
89
90impl Display for CompressionAlgo {
91    #[allow(unreachable_patterns)]
92    #[allow(unused_variables)]
93    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
94        match self {
95            #[cfg(feature = "brotli")]
96            Self::Brotli => write!(f, "br"),
97            #[cfg(feature = "deflate")]
98            Self::Deflate => write!(f, "deflate"),
99            #[cfg(feature = "gzip")]
100            Self::Gzip => write!(f, "gzip"),
101            #[cfg(feature = "zstd")]
102            Self::Zstd => write!(f, "zstd"),
103            _ => unreachable!(),
104        }
105    }
106}
107
108impl From<CompressionAlgo> for HeaderValue {
109    #[inline]
110    fn from(algo: CompressionAlgo) -> Self {
111        match algo {
112            #[cfg(feature = "brotli")]
113            CompressionAlgo::Brotli => Self::from_static("br"),
114            #[cfg(feature = "deflate")]
115            CompressionAlgo::Deflate => Self::from_static("deflate"),
116            #[cfg(feature = "gzip")]
117            CompressionAlgo::Gzip => Self::from_static("gzip"),
118            #[cfg(feature = "zstd")]
119            CompressionAlgo::Zstd => Self::from_static("zstd"),
120        }
121    }
122}
123
124/// Compression
125#[derive(Clone, Debug)]
126#[non_exhaustive]
127pub struct Compression {
128    /// Compression algorithms to use.
129    pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
130    /// Content types to compress.
131    pub content_types: Vec<Mime>,
132    /// Sets minimum compression size, if body is less than this value, no compression.
133    pub min_length: usize,
134    /// Ignore request algorithms order in `Accept-Encoding` header and always server's config.
135    pub force_priority: bool,
136}
137
138impl Default for Compression {
139    fn default() -> Self {
140        #[allow(unused_mut)]
141        let mut algos = IndexMap::new();
142        #[cfg(feature = "zstd")]
143        algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
144        #[cfg(feature = "gzip")]
145        algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
146        #[cfg(feature = "deflate")]
147        algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
148        #[cfg(feature = "brotli")]
149        algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
150        Self {
151            algos,
152            content_types: vec![
153                mime::TEXT_STAR,
154                mime::APPLICATION_JAVASCRIPT,
155                mime::APPLICATION_JSON,
156                mime::IMAGE_SVG,
157                "application/wasm".parse().expect("invalid mime type"),
158                "application/xml".parse().expect("invalid mime type"),
159                "application/rss+xml".parse().expect("invalid mime type"),
160            ],
161            min_length: 0,
162            force_priority: false,
163        }
164    }
165}
166
167impl Compression {
168    /// Create a new `Compression`.
169    #[inline]
170    #[must_use]
171    pub fn new() -> Self {
172        Default::default()
173    }
174
175    /// Remove all compression algorithms.
176    #[inline]
177    #[must_use]
178    pub fn disable_all(mut self) -> Self {
179        self.algos.clear();
180        self
181    }
182
183    /// Sets `Compression` with algos.
184    #[cfg(feature = "gzip")]
185    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
186    #[inline]
187    #[must_use]
188    pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
189        self.algos.insert(CompressionAlgo::Gzip, level);
190        self
191    }
192    /// Disable gzip compression.
193    #[cfg(feature = "gzip")]
194    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
195    #[inline]
196    #[must_use]
197    pub fn disable_gzip(mut self) -> Self {
198        self.algos.shift_remove(&CompressionAlgo::Gzip);
199        self
200    }
201    /// Enable zstd compression.
202    #[cfg(feature = "zstd")]
203    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
204    #[inline]
205    #[must_use]
206    pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
207        self.algos.insert(CompressionAlgo::Zstd, level);
208        self
209    }
210    /// Disable zstd compression.
211    #[cfg(feature = "zstd")]
212    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
213    #[inline]
214    #[must_use]
215    pub fn disable_zstd(mut self) -> Self {
216        self.algos.shift_remove(&CompressionAlgo::Zstd);
217        self
218    }
219    /// Enable brotli compression.
220    #[cfg(feature = "brotli")]
221    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
222    #[inline]
223    #[must_use]
224    pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
225        self.algos.insert(CompressionAlgo::Brotli, level);
226        self
227    }
228    /// Disable brotli compression.
229    #[cfg(feature = "brotli")]
230    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
231    #[inline]
232    #[must_use]
233    pub fn disable_brotli(mut self) -> Self {
234        self.algos.shift_remove(&CompressionAlgo::Brotli);
235        self
236    }
237
238    /// Enable deflate compression.
239    #[cfg(feature = "deflate")]
240    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
241    #[inline]
242    #[must_use]
243    pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
244        self.algos.insert(CompressionAlgo::Deflate, level);
245        self
246    }
247
248    /// Disable deflate compression.
249    #[cfg(feature = "deflate")]
250    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
251    #[inline]
252    #[must_use]
253    pub fn disable_deflate(mut self) -> Self {
254        self.algos.shift_remove(&CompressionAlgo::Deflate);
255        self
256    }
257
258    /// Sets minimum compression size, if body is less than this value, no compression
259    /// default is 1kb
260    #[inline]
261    #[must_use]
262    pub fn min_length(mut self, size: usize) -> Self {
263        self.min_length = size;
264        self
265    }
266    /// Sets `Compression` with force_priority.
267    #[inline]
268    #[must_use]
269    pub fn force_priority(mut self, force_priority: bool) -> Self {
270        self.force_priority = force_priority;
271        self
272    }
273
274    /// Sets `Compression` with content types list.
275    #[inline]
276    #[must_use]
277    pub fn content_types(mut self, content_types: &[Mime]) -> Self {
278        self.content_types = content_types.to_vec();
279        self
280    }
281
282    fn negotiate(
283        &self,
284        req: &Request,
285        res: &Response,
286    ) -> Option<(CompressionAlgo, CompressionLevel)> {
287        if req.headers().contains_key(&CONTENT_ENCODING) {
288            return None;
289        }
290
291        if !self.content_types.is_empty() {
292            let content_type = res
293                .headers()
294                .get(CONTENT_TYPE)
295                .and_then(|v| v.to_str().ok())
296                .unwrap_or_default();
297            if content_type.is_empty() {
298                return None;
299            }
300            if let Ok(content_type) = content_type.parse::<Mime>() {
301                if !self.content_types.iter().any(|citem| {
302                    citem.type_() == content_type.type_()
303                        && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
304                }) {
305                    return None;
306                }
307            } else {
308                return None;
309            }
310        }
311        let header = req
312            .headers()
313            .get(ACCEPT_ENCODING)
314            .and_then(|v| v.to_str().ok())?;
315
316        let accept_algos = http::parse_accept_encoding(header)
317            .into_iter()
318            .filter_map(|(algo, level)| {
319                if let Ok(algo) = algo.parse::<CompressionAlgo>() {
320                    Some((algo, level))
321                } else {
322                    None
323                }
324            })
325            .collect::<Vec<_>>();
326        if self.force_priority {
327            let accept_algos = accept_algos
328                .into_iter()
329                .map(|(algo, _)| algo)
330                .collect::<Vec<_>>();
331            self.algos
332                .iter()
333                .find(|(algo, _level)| accept_algos.contains(algo))
334                .map(|(algo, level)| (*algo, *level))
335        } else {
336            accept_algos
337                .into_iter()
338                .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
339        }
340    }
341}
342
343#[async_trait]
344impl Handler for Compression {
345    async fn handle(
346        &self,
347        req: &mut Request,
348        depot: &mut Depot,
349        res: &mut Response,
350        ctrl: &mut FlowCtrl,
351    ) {
352        ctrl.call_next(req, depot, res).await;
353        if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
354            return;
355        }
356
357        if let Some(StatusCode::SWITCHING_PROTOCOLS | StatusCode::NO_CONTENT) = res.status_code {
358            return;
359        }
360
361        match res.take_body() {
362            ResBody::None => {
363                return;
364            }
365            ResBody::Once(bytes) => {
366                if self.min_length > 0 && bytes.len() < self.min_length {
367                    res.body(ResBody::Once(bytes));
368                    return;
369                }
370                if let Some((algo, level)) = self.negotiate(req, res) {
371                    res.stream(EncodeStream::new(algo, level, Some(bytes)));
372                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
373                } else {
374                    res.body(ResBody::Once(bytes));
375                    return;
376                }
377            }
378            ResBody::Chunks(chunks) => {
379                if self.min_length > 0 {
380                    let len: usize = chunks.iter().map(|c| c.len()).sum();
381                    if len < self.min_length {
382                        res.body(ResBody::Chunks(chunks));
383                        return;
384                    }
385                }
386                if let Some((algo, level)) = self.negotiate(req, res) {
387                    res.stream(EncodeStream::new(algo, level, chunks));
388                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
389                } else {
390                    res.body(ResBody::Chunks(chunks));
391                    return;
392                }
393            }
394            ResBody::Hyper(body) => {
395                if let Some((algo, level)) = self.negotiate(req, res) {
396                    res.stream(EncodeStream::new(algo, level, body));
397                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
398                } else {
399                    res.body(ResBody::Hyper(body));
400                    return;
401                }
402            }
403            ResBody::Stream(body) => {
404                let body = body.into_inner();
405                if let Some((algo, level)) = self.negotiate(req, res) {
406                    res.stream(EncodeStream::new(algo, level, body));
407                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
408                } else {
409                    res.body(ResBody::stream(body));
410                    return;
411                }
412            }
413            body => {
414                res.body(body);
415                return;
416            }
417        }
418        res.headers_mut().remove(CONTENT_LENGTH);
419    }
420}
421
// Each test needs specific codec features; gating keeps `cargo test` compiling for any
// feature subset. The module is skipped entirely when no codec is enabled.
#[cfg(all(
    test,
    any(
        feature = "brotli",
        feature = "deflate",
        feature = "gzip",
        feature = "zstd"
    )
))]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_gzip() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "brotli")]
    #[tokio::test]
    async fn test_brotli() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_deflate() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "deflate", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "deflate");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "zstd")]
    #[tokio::test]
    async fn test_zstd() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "zstd", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "zstd");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_not_compress() {
        let comp_handler = Compression::new().min_length(10);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_should_compress() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[cfg(feature = "gzip")]
    #[handler]
    async fn hello_html(res: &mut Response) {
        res.render(Text::Html("<html><body>hello</body></html>"));
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_should_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::TEXT_HTML]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_not_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::APPLICATION_JSON]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[cfg(all(feature = "brotli", feature = "gzip"))]
    #[tokio::test]
    async fn test_force_priority() {
        let comp_handler = Compression::new()
            .disable_all()
            .enable_brotli(CompressionLevel::Default)
            .enable_gzip(CompressionLevel::Default)
            .min_length(1)
            .force_priority(true);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip, br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }
}