salvo_compression/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
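//! Compression middleware for the Salvo web framework.
//!
//! The [`Compression`] handler negotiates a content coding from the request's
//! `Accept-Encoding` header and compresses eligible response bodies with it.
//! Each algorithm (gzip, brotli, deflate, zstd) is compiled in through the
//! cargo feature of the same name.
//!
//! Read more: <https://salvo.rs>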
6
7use std::fmt::{self, Display, Formatter};
8use std::str::FromStr;
9
10use indexmap::IndexMap;
11use salvo_core::http::body::ResBody;
12use salvo_core::http::header::{
13    ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue,
14};
15use salvo_core::http::{self, Mime, StatusCode, mime};
16use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
17
18mod encoder;
19mod stream;
20use encoder::Encoder;
21use stream::EncodeStream;
22
/// How much effort the encoder should spend compressing a response body.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Prioritize speed; the output is usually larger.
    Fastest,
    /// Prioritize ratio; the output is usually the smallest.
    Minsize,
    /// Use the level the selected algorithm defines as its default.
    #[default]
    Default,
    /// An explicit, algorithm-specific quality value.
    ///
    /// How the number is interpreted depends on the chosen algorithm and on
    /// the implementation backing it. Values above the algorithm's maximum
    /// are implicitly clamped.
    Precise(u32),
}
40
/// A compression algorithm (content coding) the middleware can apply.
///
/// Every variant only exists when the cargo feature of the same name is
/// enabled, so the set of variants depends on the build configuration.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgo {
    /// Brotli, advertised on the wire as `br`.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Deflate, advertised on the wire as `deflate`.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Gzip, advertised on the wire as `gzip`.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, advertised on the wire as `zstd`.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
65
66impl FromStr for CompressionAlgo {
67    type Err = String;
68
69    fn from_str(s: &str) -> Result<Self, Self::Err> {
70        match s {
71            #[cfg(feature = "brotli")]
72            "br" => Ok(Self::Brotli),
73            #[cfg(feature = "brotli")]
74            "brotli" => Ok(Self::Brotli),
75
76            #[cfg(feature = "deflate")]
77            "deflate" => Ok(Self::Deflate),
78
79            #[cfg(feature = "gzip")]
80            "gzip" => Ok(Self::Gzip),
81
82            #[cfg(feature = "zstd")]
83            "zstd" => Ok(Self::Zstd),
84            _ => Err(format!("unknown compression algorithm: {s}")),
85        }
86    }
87}
88
89impl Display for CompressionAlgo {
90    #[allow(unreachable_patterns)]
91    #[allow(unused_variables)]
92    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
93        match self {
94            #[cfg(feature = "brotli")]
95            Self::Brotli => write!(f, "br"),
96            #[cfg(feature = "deflate")]
97            Self::Deflate => write!(f, "deflate"),
98            #[cfg(feature = "gzip")]
99            Self::Gzip => write!(f, "gzip"),
100            #[cfg(feature = "zstd")]
101            Self::Zstd => write!(f, "zstd"),
102            _ => unreachable!(),
103        }
104    }
105}
106
107impl From<CompressionAlgo> for HeaderValue {
108    #[inline]
109    fn from(algo: CompressionAlgo) -> Self {
110        match algo {
111            #[cfg(feature = "brotli")]
112            CompressionAlgo::Brotli => Self::from_static("br"),
113            #[cfg(feature = "deflate")]
114            CompressionAlgo::Deflate => Self::from_static("deflate"),
115            #[cfg(feature = "gzip")]
116            CompressionAlgo::Gzip => Self::from_static("gzip"),
117            #[cfg(feature = "zstd")]
118            CompressionAlgo::Zstd => Self::from_static("zstd"),
119        }
120    }
121}
122
123/// Compression
124#[derive(Clone, Debug)]
125#[non_exhaustive]
126pub struct Compression {
127    /// Compression algorithms to use.
128    pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
129    /// Content types to compress.
130    pub content_types: Vec<Mime>,
131    /// Sets minimum compression size, if body is less than this value, no compression.
132    pub min_length: usize,
133    /// Ignore request algorithms order in `Accept-Encoding` header and always server's config.
134    pub force_priority: bool,
135}
136
137impl Default for Compression {
138    fn default() -> Self {
139        #[allow(unused_mut)]
140        let mut algos = IndexMap::new();
141        #[cfg(feature = "zstd")]
142        algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
143        #[cfg(feature = "gzip")]
144        algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
145        #[cfg(feature = "deflate")]
146        algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
147        #[cfg(feature = "brotli")]
148        algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
149        Self {
150            algos,
151            content_types: vec![
152                mime::TEXT_STAR,
153                mime::APPLICATION_JAVASCRIPT,
154                mime::APPLICATION_JSON,
155                mime::IMAGE_SVG,
156                "application/wasm".parse().expect("invalid mime type"),
157                "application/xml".parse().expect("invalid mime type"),
158                "application/rss+xml".parse().expect("invalid mime type"),
159            ],
160            min_length: 0,
161            force_priority: false,
162        }
163    }
164}
165
166impl Compression {
167    /// Create a new `Compression`.
168    #[inline]
169    #[must_use]
170    pub fn new() -> Self {
171        Default::default()
172    }
173
174    /// Remove all compression algorithms.
175    #[inline]
176    #[must_use]
177    pub fn disable_all(mut self) -> Self {
178        self.algos.clear();
179        self
180    }
181
182    /// Sets `Compression` with algos.
183    #[cfg(feature = "gzip")]
184    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
185    #[inline]
186    #[must_use]
187    pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
188        self.algos.insert(CompressionAlgo::Gzip, level);
189        self
190    }
191    /// Disable gzip compression.
192    #[cfg(feature = "gzip")]
193    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
194    #[inline]
195    #[must_use]
196    pub fn disable_gzip(mut self) -> Self {
197        self.algos.shift_remove(&CompressionAlgo::Gzip);
198        self
199    }
200    /// Enable zstd compression.
201    #[cfg(feature = "zstd")]
202    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
203    #[inline]
204    #[must_use]
205    pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
206        self.algos.insert(CompressionAlgo::Zstd, level);
207        self
208    }
209    /// Disable zstd compression.
210    #[cfg(feature = "zstd")]
211    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
212    #[inline]
213    #[must_use]
214    pub fn disable_zstd(mut self) -> Self {
215        self.algos.shift_remove(&CompressionAlgo::Zstd);
216        self
217    }
218    /// Enable brotli compression.
219    #[cfg(feature = "brotli")]
220    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
221    #[inline]
222    #[must_use]
223    pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
224        self.algos.insert(CompressionAlgo::Brotli, level);
225        self
226    }
227    /// Disable brotli compression.
228    #[cfg(feature = "brotli")]
229    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
230    #[inline]
231    #[must_use]
232    pub fn disable_brotli(mut self) -> Self {
233        self.algos.shift_remove(&CompressionAlgo::Brotli);
234        self
235    }
236
237    /// Enable deflate compression.
238    #[cfg(feature = "deflate")]
239    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
240    #[inline]
241    #[must_use]
242    pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
243        self.algos.insert(CompressionAlgo::Deflate, level);
244        self
245    }
246
247    /// Disable deflate compression.
248    #[cfg(feature = "deflate")]
249    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
250    #[inline]
251    #[must_use]
252    pub fn disable_deflate(mut self) -> Self {
253        self.algos.shift_remove(&CompressionAlgo::Deflate);
254        self
255    }
256
257    /// Sets minimum compression size, if body is less than this value, no compression
258    /// default is 1kb
259    #[inline]
260    #[must_use]
261    pub fn min_length(mut self, size: usize) -> Self {
262        self.min_length = size;
263        self
264    }
265    /// Sets `Compression` with force_priority.
266    #[inline]
267    #[must_use]
268    pub fn force_priority(mut self, force_priority: bool) -> Self {
269        self.force_priority = force_priority;
270        self
271    }
272
273    /// Sets `Compression` with content types list.
274    #[inline]
275    #[must_use]
276    pub fn content_types(mut self, content_types: &[Mime]) -> Self {
277        self.content_types = content_types.to_vec();
278        self
279    }
280
281    fn negotiate(
282        &self,
283        req: &Request,
284        res: &Response,
285    ) -> Option<(CompressionAlgo, CompressionLevel)> {
286        if req.headers().contains_key(&CONTENT_ENCODING) {
287            return None;
288        }
289
290        if !self.content_types.is_empty() {
291            let content_type = res
292                .headers()
293                .get(CONTENT_TYPE)
294                .and_then(|v| v.to_str().ok())
295                .unwrap_or_default();
296            if content_type.is_empty() {
297                return None;
298            }
299            if let Ok(content_type) = content_type.parse::<Mime>() {
300                if !self.content_types.iter().any(|citem| {
301                    citem.type_() == content_type.type_()
302                        && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
303                }) {
304                    return None;
305                }
306            } else {
307                return None;
308            }
309        }
310        let header = req
311            .headers()
312            .get(ACCEPT_ENCODING)
313            .and_then(|v| v.to_str().ok())?;
314
315        let accept_algos = http::parse_accept_encoding(header)
316            .into_iter()
317            .filter_map(|(algo, level)| {
318                if let Ok(algo) = algo.parse::<CompressionAlgo>() {
319                    Some((algo, level))
320                } else {
321                    None
322                }
323            })
324            .collect::<Vec<_>>();
325        if self.force_priority {
326            let accept_algos = accept_algos
327                .into_iter()
328                .map(|(algo, _)| algo)
329                .collect::<Vec<_>>();
330            self.algos
331                .iter()
332                .find(|(algo, _level)| accept_algos.contains(algo))
333                .map(|(algo, level)| (*algo, *level))
334        } else {
335            accept_algos
336                .into_iter()
337                .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
338        }
339    }
340}
341
342#[async_trait]
343impl Handler for Compression {
344    async fn handle(
345        &self,
346        req: &mut Request,
347        depot: &mut Depot,
348        res: &mut Response,
349        ctrl: &mut FlowCtrl,
350    ) {
351        ctrl.call_next(req, depot, res).await;
352        if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
353            return;
354        }
355
356        if let Some(StatusCode::SWITCHING_PROTOCOLS | StatusCode::NO_CONTENT) = res.status_code {
357            return;
358        }
359
360        match res.take_body() {
361            ResBody::None => {
362                return;
363            }
364            ResBody::Once(bytes) => {
365                if self.min_length > 0 && bytes.len() < self.min_length {
366                    res.body(ResBody::Once(bytes));
367                    return;
368                }
369                if let Some((algo, level)) = self.negotiate(req, res) {
370                    res.stream(EncodeStream::new(algo, level, Some(bytes)));
371                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
372                } else {
373                    res.body(ResBody::Once(bytes));
374                    return;
375                }
376            }
377            ResBody::Chunks(chunks) => {
378                if self.min_length > 0 {
379                    let len: usize = chunks.iter().map(|c| c.len()).sum();
380                    if len < self.min_length {
381                        res.body(ResBody::Chunks(chunks));
382                        return;
383                    }
384                }
385                if let Some((algo, level)) = self.negotiate(req, res) {
386                    res.stream(EncodeStream::new(algo, level, chunks));
387                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
388                } else {
389                    res.body(ResBody::Chunks(chunks));
390                    return;
391                }
392            }
393            ResBody::Hyper(body) => {
394                if let Some((algo, level)) = self.negotiate(req, res) {
395                    res.stream(EncodeStream::new(algo, level, body));
396                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
397                } else {
398                    res.body(ResBody::Hyper(body));
399                    return;
400                }
401            }
402            ResBody::Stream(body) => {
403                let body = body.into_inner();
404                if let Some((algo, level)) = self.negotiate(req, res) {
405                    res.stream(EncodeStream::new(algo, level, body));
406                    res.headers_mut().append(CONTENT_ENCODING, algo.into());
407                } else {
408                    res.body(ResBody::stream(body));
409                    return;
410                }
411            }
412            body => {
413                res.body(body);
414                return;
415            }
416        }
417        res.headers_mut().remove(CONTENT_LENGTH);
418    }
419}
420
// Every test is gated on the cargo feature(s) whose encoding or builder method
// it uses, so the suite still compiles and passes when only a subset of the
// algorithm features is enabled.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Plain-text endpoint used by most tests.
    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    /// HTML endpoint used by the content-type tests.
    #[handler]
    async fn hello_html(res: &mut Response) {
        res.render(Text::Html("<html><body>hello</body></html>"));
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_gzip() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "brotli")]
    #[tokio::test]
    async fn test_brotli() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_deflate() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "deflate", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "deflate");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "zstd")]
    #[tokio::test]
    async fn test_zstd() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "zstd", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "zstd");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    // "hello" is 5 bytes, below the 10-byte minimum, so it must stay uncompressed.
    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_not_compress() {
        let comp_handler = Compression::new().min_length(10);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_min_length_should_compress() {
        let comp_handler = Compression::new().min_length(1);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_should_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::TEXT_HTML]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_some());
    }

    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_content_types_not_compress() {
        let comp_handler = Compression::new()
            .min_length(1)
            .content_types(&[mime::APPLICATION_JSON]);
        let router =
            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));

        let res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip", true)
            .send(router)
            .await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    // The client lists gzip first, but the server enabled brotli first and
    // forces its own priority, so brotli must win.
    #[cfg(all(feature = "brotli", feature = "gzip"))]
    #[tokio::test]
    async fn test_force_priority() {
        let comp_handler = Compression::new()
            .disable_all()
            .enable_brotli(CompressionLevel::Default)
            .enable_gzip(CompressionLevel::Default)
            .min_length(1)
            .force_priority(true);
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));

        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, "gzip, br", true)
            .send(router)
            .await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }
}