Skip to main content

salvo_compression/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3//! Compression middleware for the Salvo web framework.
4//!
5//! This middleware automatically compresses HTTP responses using various algorithms,
6//! reducing bandwidth usage and improving load times for clients.
7//!
8//! # Supported Algorithms
9//!
10//! | Algorithm | Feature | Content-Encoding |
11//! |-----------|---------|------------------|
12//! | Gzip | `gzip` | `gzip` |
13//! | Brotli | `brotli` | `br` |
14//! | Deflate | `deflate` | `deflate` |
15//! | Zstd | `zstd` | `zstd` |
16//!
17//! # Example
18//!
19//! ```ignore
20//! use salvo_compression::{Compression, CompressionLevel};
21//! use salvo_core::prelude::*;
22//!
23//! let compression = Compression::new()
24//!     .enable_gzip(CompressionLevel::Default)
//!     .min_length(1024);  // Leave bodies smaller than 1024 bytes uncompressed
26//!
27//! let router = Router::new()
28//!     .hoop(compression)
29//!     .get(my_handler);
30//! ```
31//!
32//! # Algorithm Negotiation
33//!
34//! The middleware negotiates the compression algorithm based on the client's
35//! `Accept-Encoding` header. By default, it respects the client's preference order.
36//! Use `force_priority(true)` to use the server's configured priority instead.
37//!
38//! # Compression Levels
39//!
40//! - [`CompressionLevel::Fastest`]: Fastest compression, larger output
41//! - [`CompressionLevel::Default`]: Balanced compression (recommended)
42//! - [`CompressionLevel::Minsize`]: Best compression, slower
43//! - `CompressionLevel::Precise(u32)`: Fine-grained control
44//!
45//! # Default Content Types
46//!
47//! By default, the middleware compresses:
48//! - `text/*` (HTML, CSS, plain text, etc.)
49//! - `application/javascript`
50//! - `application/json`
51//! - `application/xml`, `application/rss+xml`
52//! - `application/wasm`
53//! - `image/svg+xml`
54//!
55//! Use `.content_types()` to customize which MIME types are compressed.
56//!
57//! # Minimum Length
58//!
59//! Small responses may not benefit from compression. Use `.min_length(bytes)`
60//! to skip compression for responses smaller than the specified size.
61//!
62//! Read more: <https://salvo.rs>
63
64use std::fmt::{self, Display, Formatter};
65use std::str::FromStr;
66
67use indexmap::IndexMap;
68use salvo_core::http::body::ResBody;
69use salvo_core::http::header::{
70    ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue,
71};
72use salvo_core::http::{self, Mime, StatusCode, mime};
73use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
74
75mod encoder;
76mod stream;
77use encoder::Encoder;
78use stream::EncodeStream;
79
/// How much effort an encoder should spend when compressing a body.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Favor speed over ratio; the output is usually larger.
    Fastest,
    /// Favor ratio over speed; the output is usually the smallest.
    Minsize,
    /// Use the level the selected algorithm defines as its default.
    #[default]
    Default,
    /// An explicit, algorithm-specific quality value.
    ///
    /// How the number is interpreted depends on the algorithm chosen and on the
    /// implementation backing it. Values above the algorithm's maximum are
    /// implicitly clamped to that maximum.
    Precise(u32),
}
97
/// Compression algorithms that can be applied to a response body.
///
/// Each variant is only compiled in when the Cargo feature named in its
/// `cfg` attribute is enabled.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionAlgo {
    /// Compress with the Brotli algorithm.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Compress with the Deflate algorithm.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Compress with the Gzip algorithm.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Compress with the Zstandard algorithm.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
122
123impl FromStr for CompressionAlgo {
124    type Err = String;
125
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        match s {
128            #[cfg(feature = "brotli")]
129            "br" => Ok(Self::Brotli),
130            #[cfg(feature = "brotli")]
131            "brotli" => Ok(Self::Brotli),
132
133            #[cfg(feature = "deflate")]
134            "deflate" => Ok(Self::Deflate),
135
136            #[cfg(feature = "gzip")]
137            "gzip" => Ok(Self::Gzip),
138
139            #[cfg(feature = "zstd")]
140            "zstd" => Ok(Self::Zstd),
141            _ => Err(format!("unknown compression algorithm: {s}")),
142        }
143    }
144}
145
146impl Display for CompressionAlgo {
147    #[allow(unreachable_patterns)]
148    #[allow(unused_variables)]
149    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
150        match self {
151            #[cfg(feature = "brotli")]
152            Self::Brotli => write!(f, "br"),
153            #[cfg(feature = "deflate")]
154            Self::Deflate => write!(f, "deflate"),
155            #[cfg(feature = "gzip")]
156            Self::Gzip => write!(f, "gzip"),
157            #[cfg(feature = "zstd")]
158            Self::Zstd => write!(f, "zstd"),
159            _ => unreachable!(),
160        }
161    }
162}
163
164impl From<CompressionAlgo> for HeaderValue {
165    #[inline]
166    fn from(algo: CompressionAlgo) -> Self {
167        match algo {
168            #[cfg(feature = "brotli")]
169            CompressionAlgo::Brotli => Self::from_static("br"),
170            #[cfg(feature = "deflate")]
171            CompressionAlgo::Deflate => Self::from_static("deflate"),
172            #[cfg(feature = "gzip")]
173            CompressionAlgo::Gzip => Self::from_static("gzip"),
174            #[cfg(feature = "zstd")]
175            CompressionAlgo::Zstd => Self::from_static("zstd"),
176        }
177    }
178}
179
180/// Compression
181#[derive(Clone, Debug)]
182#[non_exhaustive]
183pub struct Compression {
184    /// Compression algorithms to use.
185    pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
186    /// Content types to compress.
187    pub content_types: Vec<Mime>,
188    /// Sets minimum compression size, if body is less than this value, no compression.
189    pub min_length: usize,
190    /// Ignore request algorithms order in `Accept-Encoding` header and always server's config.
191    pub force_priority: bool,
192}
193
194impl Default for Compression {
195    fn default() -> Self {
196        #[allow(unused_mut)]
197        let mut algos = IndexMap::new();
198        #[cfg(feature = "zstd")]
199        algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
200        #[cfg(feature = "gzip")]
201        algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
202        #[cfg(feature = "deflate")]
203        algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
204        #[cfg(feature = "brotli")]
205        algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
206        Self {
207            algos,
208            content_types: vec![
209                mime::TEXT_STAR,
210                mime::APPLICATION_JAVASCRIPT,
211                mime::APPLICATION_JSON,
212                mime::IMAGE_SVG,
213                "application/wasm".parse().expect("invalid mime type"),
214                "application/xml".parse().expect("invalid mime type"),
215                "application/rss+xml".parse().expect("invalid mime type"),
216            ],
217            min_length: 0,
218            force_priority: false,
219        }
220    }
221}
222
223impl Compression {
224    /// Create a new `Compression`.
225    #[inline]
226    #[must_use]
227    pub fn new() -> Self {
228        Default::default()
229    }
230
231    /// Remove all compression algorithms.
232    #[inline]
233    #[must_use]
234    pub fn disable_all(mut self) -> Self {
235        self.algos.clear();
236        self
237    }
238
239    /// Sets `Compression` with algos.
240    #[cfg(feature = "gzip")]
241    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
242    #[inline]
243    #[must_use]
244    pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
245        self.algos.insert(CompressionAlgo::Gzip, level);
246        self
247    }
248    /// Disable gzip compression.
249    #[cfg(feature = "gzip")]
250    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
251    #[inline]
252    #[must_use]
253    pub fn disable_gzip(mut self) -> Self {
254        self.algos.shift_remove(&CompressionAlgo::Gzip);
255        self
256    }
257    /// Enable zstd compression.
258    #[cfg(feature = "zstd")]
259    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
260    #[inline]
261    #[must_use]
262    pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
263        self.algos.insert(CompressionAlgo::Zstd, level);
264        self
265    }
266    /// Disable zstd compression.
267    #[cfg(feature = "zstd")]
268    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
269    #[inline]
270    #[must_use]
271    pub fn disable_zstd(mut self) -> Self {
272        self.algos.shift_remove(&CompressionAlgo::Zstd);
273        self
274    }
275    /// Enable brotli compression.
276    #[cfg(feature = "brotli")]
277    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
278    #[inline]
279    #[must_use]
280    pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
281        self.algos.insert(CompressionAlgo::Brotli, level);
282        self
283    }
284    /// Disable brotli compression.
285    #[cfg(feature = "brotli")]
286    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
287    #[inline]
288    #[must_use]
289    pub fn disable_brotli(mut self) -> Self {
290        self.algos.shift_remove(&CompressionAlgo::Brotli);
291        self
292    }
293
294    /// Enable deflate compression.
295    #[cfg(feature = "deflate")]
296    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
297    #[inline]
298    #[must_use]
299    pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
300        self.algos.insert(CompressionAlgo::Deflate, level);
301        self
302    }
303
304    /// Disable deflate compression.
305    #[cfg(feature = "deflate")]
306    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
307    #[inline]
308    #[must_use]
309    pub fn disable_deflate(mut self) -> Self {
310        self.algos.shift_remove(&CompressionAlgo::Deflate);
311        self
312    }
313
314    /// Sets minimum compression size, if body is less than this value, no compression
315    /// default is 1kb
316    #[inline]
317    #[must_use]
318    pub fn min_length(mut self, size: usize) -> Self {
319        self.min_length = size;
320        self
321    }
322    /// Sets `Compression` with force_priority.
323    #[inline]
324    #[must_use]
325    pub fn force_priority(mut self, force_priority: bool) -> Self {
326        self.force_priority = force_priority;
327        self
328    }
329
330    /// Sets `Compression` with content types list.
331    #[inline]
332    #[must_use]
333    pub fn content_types(mut self, content_types: &[Mime]) -> Self {
334        self.content_types = content_types.to_vec();
335        self
336    }
337
338    fn negotiate(
339        &self,
340        req: &Request,
341        res: &Response,
342    ) -> Option<(CompressionAlgo, CompressionLevel)> {
343        if req.headers().contains_key(&CONTENT_ENCODING) {
344            return None;
345        }
346
347        if !self.content_types.is_empty() {
348            let content_type = res
349                .headers()
350                .get(CONTENT_TYPE)
351                .and_then(|v| v.to_str().ok())
352                .unwrap_or_default();
353            if content_type.is_empty() {
354                return None;
355            }
356            if let Ok(content_type) = content_type.parse::<Mime>() {
357                if !self.content_types.iter().any(|citem| {
358                    citem.type_() == content_type.type_()
359                        && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
360                }) {
361                    return None;
362                }
363            } else {
364                return None;
365            }
366        }
367        let header = req
368            .headers()
369            .get(ACCEPT_ENCODING)
370            .and_then(|v| v.to_str().ok())?;
371
372        let accept_list = http::parse_accept_encoding(header);
373
374        let wildcard_q = accept_list.iter().find(|(a, _)| a == "*").map(|(_, q)| *q);
375
376        // Algorithms accept q > 0 and sorted by q-value descending.
377        let accept_algos = accept_list
378            .iter()
379            .filter(|(_, q)| *q > 0)
380            .filter_map(|(algo, q)| algo.parse::<CompressionAlgo>().ok().map(|a| (a, *q)))
381            .collect::<Vec<_>>();
382
383        // Algorithms to explicitly rejected when q = 0
384        let rejected = accept_list
385            .iter()
386            .filter(|(_, q)| *q == 0)
387            .filter_map(|(algo, _)| algo.parse::<CompressionAlgo>().ok())
388            .collect::<Vec<_>>();
389
390        if self.force_priority {
391            // Server preference: pick the highest-priority server algo the client accepts.
392            self.algos
393                .iter()
394                .find(|(algo, _)| {
395                    if rejected.contains(algo) {
396                        return false;
397                    }
398                    accept_algos.iter().any(|(a, _)| a == *algo)
399                        || wildcard_q.is_some_and(|q| q > 0)
400                })
401                .map(|(algo, level)| (*algo, *level))
402        } else {
403            // Client preference: pick the highest q-value algo the server supports.
404            let result = accept_algos
405                .iter()
406                .find_map(|(algo, _)| self.algos.get(algo).map(|level| (*algo, *level)));
407
408            if result.is_some() {
409                return result;
410            }
411
412            // Wildcard `*`: use the server's top algo that is not explicitly rejected.
413            if wildcard_q.is_some_and(|q| q > 0) {
414                self.algos
415                    .iter()
416                    .find(|(algo, _)| !rejected.contains(algo))
417                    .map(|(algo, level)| (*algo, *level))
418            } else {
419                None
420            }
421        }
422    }
423}
424
425#[async_trait]
426impl Handler for Compression {
427    async fn handle(
428        &self,
429        req: &mut Request,
430        depot: &mut Depot,
431        res: &mut Response,
432        ctrl: &mut FlowCtrl,
433    ) {
434        ctrl.call_next(req, depot, res).await;
435        if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
436            return;
437        }
438
439        if let Some(StatusCode::SWITCHING_PROTOCOLS | StatusCode::NO_CONTENT) = res.status_code {
440            return;
441        }
442
443        match res.take_body() {
444            ResBody::None => {
445                return;
446            }
447            ResBody::Once(bytes) => {
448                if self.min_length > 0 && bytes.len() < self.min_length {
449                    res.body(ResBody::Once(bytes));
450                    return;
451                }
452                if let Some((algo, level)) = self.negotiate(req, res) {
453                    res.stream(EncodeStream::new(algo, level, Some(bytes)));
454                    res.headers_mut().insert(CONTENT_ENCODING, algo.into());
455                } else {
456                    res.body(ResBody::Once(bytes));
457                    return;
458                }
459            }
460            ResBody::Chunks(chunks) => {
461                if self.min_length > 0 {
462                    let len: usize = chunks.iter().map(|c| c.len()).sum();
463                    if len < self.min_length {
464                        res.body(ResBody::Chunks(chunks));
465                        return;
466                    }
467                }
468                if let Some((algo, level)) = self.negotiate(req, res) {
469                    res.stream(EncodeStream::new(algo, level, chunks));
470                    res.headers_mut().insert(CONTENT_ENCODING, algo.into());
471                } else {
472                    res.body(ResBody::Chunks(chunks));
473                    return;
474                }
475            }
476            ResBody::Hyper(body) => {
477                if let Some((algo, level)) = self.negotiate(req, res) {
478                    res.stream(EncodeStream::new(algo, level, body));
479                    res.headers_mut().insert(CONTENT_ENCODING, algo.into());
480                } else {
481                    res.body(ResBody::Hyper(body));
482                    return;
483                }
484            }
485            ResBody::Stream(body) => {
486                let body = body.into_inner();
487                if let Some((algo, level)) = self.negotiate(req, res) {
488                    res.stream(EncodeStream::new(algo, level, body));
489                    res.headers_mut().insert(CONTENT_ENCODING, algo.into());
490                } else {
491                    res.body(ResBody::stream(body));
492                    return;
493                }
494            }
495            body => {
496                res.body(body);
497                return;
498            }
499        }
500        res.headers_mut().remove(CONTENT_LENGTH);
501    }
502}
503
504#[cfg(test)]
505mod tests {
506    use salvo_core::prelude::*;
507    use salvo_core::test::{ResponseExt, TestClient};
508
509    use super::*;
510
    // Test handler that returns the static string `hello` as the response.
    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }
515
516    #[tokio::test]
517    async fn test_gzip() {
518        let comp_handler = Compression::new().min_length(1);
519        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
520
521        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
522            .add_header(ACCEPT_ENCODING, "gzip", true)
523            .send(router)
524            .await;
525        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
526        let content = res.take_string().await.unwrap();
527        assert_eq!(content, "hello");
528    }
529
530    #[tokio::test]
531    async fn test_brotli() {
532        let comp_handler = Compression::new().min_length(1);
533        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
534
535        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
536            .add_header(ACCEPT_ENCODING, "br", true)
537            .send(router)
538            .await;
539        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
540        let content = res.take_string().await.unwrap();
541        assert_eq!(content, "hello");
542    }
543
544    #[tokio::test]
545    async fn test_deflate() {
546        let comp_handler = Compression::new().min_length(1);
547        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
548
549        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
550            .add_header(ACCEPT_ENCODING, "deflate", true)
551            .send(router)
552            .await;
553        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "deflate");
554        let content = res.take_string().await.unwrap();
555        assert_eq!(content, "hello");
556    }
557
558    #[tokio::test]
559    async fn test_zstd() {
560        let comp_handler = Compression::new().min_length(1);
561        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
562
563        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
564            .add_header(ACCEPT_ENCODING, "zstd", true)
565            .send(router)
566            .await;
567        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "zstd");
568        let content = res.take_string().await.unwrap();
569        assert_eq!(content, "hello");
570    }
571
572    #[tokio::test]
573    async fn test_min_length_not_compress() {
574        let comp_handler = Compression::new().min_length(10);
575        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
576
577        let res = TestClient::get("http://127.0.0.1:5801/hello")
578            .add_header(ACCEPT_ENCODING, "gzip", true)
579            .send(router)
580            .await;
581        assert!(res.headers().get(CONTENT_ENCODING).is_none());
582    }
583
584    #[tokio::test]
585    async fn test_min_length_should_compress() {
586        let comp_handler = Compression::new().min_length(1);
587        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
588
589        let res = TestClient::get("http://127.0.0.1:5801/hello")
590            .add_header(ACCEPT_ENCODING, "gzip", true)
591            .send(router)
592            .await;
593        assert!(res.headers().get(CONTENT_ENCODING).is_some());
594    }
595
596    #[handler]
597    async fn hello_html(res: &mut Response) {
598        res.render(Text::Html("<html><body>hello</body></html>"));
599    }
600    #[tokio::test]
601    async fn test_content_types_should_compress() {
602        let comp_handler = Compression::new()
603            .min_length(1)
604            .content_types(&[mime::TEXT_HTML]);
605        let router =
606            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));
607
608        let res = TestClient::get("http://127.0.0.1:5801/hello")
609            .add_header(ACCEPT_ENCODING, "gzip", true)
610            .send(router)
611            .await;
612        assert!(res.headers().get(CONTENT_ENCODING).is_some());
613    }
614
615    #[tokio::test]
616    async fn test_content_types_not_compress() {
617        let comp_handler = Compression::new()
618            .min_length(1)
619            .content_types(&[mime::APPLICATION_JSON]);
620        let router =
621            Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello_html));
622
623        let res = TestClient::get("http://127.0.0.1:5801/hello")
624            .add_header(ACCEPT_ENCODING, "gzip", true)
625            .send(router)
626            .await;
627        assert!(res.headers().get(CONTENT_ENCODING).is_none());
628    }
629
630    #[tokio::test]
631    async fn test_q_value_preference() {
632        // Client prefers br (q=1.0) over gzip (q=0.5)
633        let comp_handler = Compression::new().min_length(1);
634        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
635
636        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
637            .add_header(ACCEPT_ENCODING, "gzip;q=0.5, br;q=1.0", true)
638            .send(router)
639            .await;
640        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
641        let content = res.take_string().await.unwrap();
642        assert_eq!(content, "hello");
643    }
644
645    #[tokio::test]
646    async fn test_q_value_zero_rejects_algo() {
647        // gzip is explicitly rejected (q=0), only br is acceptable
648        let comp_handler = Compression::new().min_length(1);
649        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
650
651        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
652            .add_header(ACCEPT_ENCODING, "gzip;q=0, br", true)
653            .send(router)
654            .await;
655        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
656        let content = res.take_string().await.unwrap();
657        assert_eq!(content, "hello");
658    }
659
660    #[tokio::test]
661    async fn test_identity_only_no_compression() {
662        // identity means no encoding; server must not compress
663        let comp_handler = Compression::new().min_length(1);
664        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
665
666        let res = TestClient::get("http://127.0.0.1:5801/hello")
667            .add_header(ACCEPT_ENCODING, "identity", true)
668            .send(router)
669            .await;
670        assert!(res.headers().get(CONTENT_ENCODING).is_none());
671    }
672
673    #[tokio::test]
674    async fn test_wildcard_uses_server_algo() {
675        // `*` means accept any encoding; server picks its preferred algo
676        let comp_handler = Compression::new()
677            .disable_all()
678            .enable_gzip(CompressionLevel::Default)
679            .min_length(1);
680        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
681
682        let res = TestClient::get("http://127.0.0.1:5801/hello")
683            .add_header(ACCEPT_ENCODING, "*", true)
684            .send(router)
685            .await;
686        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
687    }
688
689    #[tokio::test]
690    async fn test_wildcard_excludes_rejected_algo() {
691        // `*` but gzip;q=0 — server must not use gzip, falls back to next algo
692        let comp_handler = Compression::new()
693            .disable_all()
694            .enable_gzip(CompressionLevel::Default)
695            .enable_brotli(CompressionLevel::Default)
696            .min_length(1);
697        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
698
699        let res = TestClient::get("http://127.0.0.1:5801/hello")
700            .add_header(ACCEPT_ENCODING, "*, gzip;q=0", true)
701            .send(router)
702            .await;
703        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
704    }
705
706    #[tokio::test]
707    async fn test_single_content_encoding_header() {
708        // Ensure only one Content-Encoding header is set (no duplicates via append)
709        let comp_handler = Compression::new().min_length(1);
710        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
711
712        let res = TestClient::get("http://127.0.0.1:5801/hello")
713            .add_header(ACCEPT_ENCODING, "gzip", true)
714            .send(router)
715            .await;
716        let count = res.headers().get_all(CONTENT_ENCODING).iter().count();
717        assert_eq!(count, 1, "must have exactly one Content-Encoding header");
718    }
719
720    #[tokio::test]
721    async fn test_force_priority() {
722        let comp_handler = Compression::new()
723            .disable_all()
724            .enable_brotli(CompressionLevel::Default)
725            .enable_gzip(CompressionLevel::Default)
726            .min_length(1)
727            .force_priority(true);
728        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
729
730        let mut res = TestClient::get("http://127.0.0.1:5801/hello")
731            .add_header(ACCEPT_ENCODING, "gzip, br", true)
732            .send(router)
733            .await;
734        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
735        let content = res.take_string().await.unwrap();
736        assert_eq!(content, "hello");
737    }
738
739    // Tests for CompressionLevel
740    #[test]
741    fn test_compression_level_default() {
742        let level: CompressionLevel = Default::default();
743        assert_eq!(level, CompressionLevel::Default);
744    }
745
746    #[test]
747    fn test_compression_level_fastest() {
748        let level = CompressionLevel::Fastest;
749        assert_eq!(level, CompressionLevel::Fastest);
750    }
751
752    #[test]
753    fn test_compression_level_minsize() {
754        let level = CompressionLevel::Minsize;
755        assert_eq!(level, CompressionLevel::Minsize);
756    }
757
758    #[test]
759    fn test_compression_level_precise() {
760        let level = CompressionLevel::Precise(5);
761        assert_eq!(level, CompressionLevel::Precise(5));
762    }
763
764    #[test]
765    fn test_compression_level_clone() {
766        let level = CompressionLevel::Fastest;
767        let cloned = level;
768        assert_eq!(level, cloned);
769    }
770
771    #[test]
772    fn test_compression_level_copy() {
773        let level = CompressionLevel::Default;
774        let copied = level;
775        assert_eq!(level, copied);
776    }
777
778    #[test]
779    fn test_compression_level_debug() {
780        let level = CompressionLevel::Fastest;
781        let debug_str = format!("{:?}", level);
782        assert!(debug_str.contains("Fastest"));
783    }
784
785    // Tests for CompressionAlgo
786    #[cfg(feature = "gzip")]
787    #[test]
788    fn test_compression_algo_gzip_from_str() {
789        let algo: CompressionAlgo = "gzip".parse().unwrap();
790        assert_eq!(algo, CompressionAlgo::Gzip);
791    }
792
793    #[cfg(feature = "brotli")]
794    #[test]
795    fn test_compression_algo_brotli_from_str() {
796        let algo: CompressionAlgo = "br".parse().unwrap();
797        assert_eq!(algo, CompressionAlgo::Brotli);
798
799        let algo: CompressionAlgo = "brotli".parse().unwrap();
800        assert_eq!(algo, CompressionAlgo::Brotli);
801    }
802
803    #[cfg(feature = "deflate")]
804    #[test]
805    fn test_compression_algo_deflate_from_str() {
806        let algo: CompressionAlgo = "deflate".parse().unwrap();
807        assert_eq!(algo, CompressionAlgo::Deflate);
808    }
809
810    #[cfg(feature = "zstd")]
811    #[test]
812    fn test_compression_algo_zstd_from_str() {
813        let algo: CompressionAlgo = "zstd".parse().unwrap();
814        assert_eq!(algo, CompressionAlgo::Zstd);
815    }
816
817    #[test]
818    fn test_compression_algo_unknown_from_str() {
819        let result: Result<CompressionAlgo, _> = "unknown".parse();
820        assert!(result.is_err());
821        assert!(
822            result
823                .unwrap_err()
824                .contains("unknown compression algorithm")
825        );
826    }
827
828    #[cfg(feature = "gzip")]
829    #[test]
830    fn test_compression_algo_gzip_display() {
831        let algo = CompressionAlgo::Gzip;
832        assert_eq!(format!("{}", algo), "gzip");
833    }
834
835    #[cfg(feature = "brotli")]
836    #[test]
837    fn test_compression_algo_brotli_display() {
838        let algo = CompressionAlgo::Brotli;
839        assert_eq!(format!("{}", algo), "br");
840    }
841
842    #[cfg(feature = "deflate")]
843    #[test]
844    fn test_compression_algo_deflate_display() {
845        let algo = CompressionAlgo::Deflate;
846        assert_eq!(format!("{}", algo), "deflate");
847    }
848
849    #[cfg(feature = "zstd")]
850    #[test]
851    fn test_compression_algo_zstd_display() {
852        let algo = CompressionAlgo::Zstd;
853        assert_eq!(format!("{}", algo), "zstd");
854    }
855
856    #[cfg(feature = "gzip")]
857    #[test]
858    fn test_compression_algo_into_header_value() {
859        let algo = CompressionAlgo::Gzip;
860        let header: HeaderValue = algo.into();
861        assert_eq!(header, "gzip");
862    }
863
864    #[test]
865    fn test_compression_algo_debug() {
866        #[cfg(feature = "gzip")]
867        {
868            let algo = CompressionAlgo::Gzip;
869            let debug_str = format!("{:?}", algo);
870            assert!(debug_str.contains("Gzip"));
871        }
872    }
873
874    #[test]
875    fn test_compression_algo_clone() {
876        #[cfg(feature = "gzip")]
877        {
878            let algo = CompressionAlgo::Gzip;
879            let cloned = algo;
880            assert_eq!(algo, cloned);
881        }
882    }
883
884    #[test]
885    fn test_compression_algo_hash() {
886        use std::collections::HashSet;
887        #[cfg(feature = "gzip")]
888        {
889            let mut set = HashSet::new();
890            set.insert(CompressionAlgo::Gzip);
891            assert!(set.contains(&CompressionAlgo::Gzip));
892        }
893    }
894
895    // Tests for Compression struct
896    #[test]
897    fn test_compression_new() {
898        let comp = Compression::new();
899        assert!(!comp.algos.is_empty());
900        assert!(!comp.content_types.is_empty());
901        assert_eq!(comp.min_length, 0);
902        assert!(!comp.force_priority);
903    }
904
905    #[test]
906    fn test_compression_default() {
907        let comp = Compression::default();
908        assert!(!comp.algos.is_empty());
909    }
910
911    #[test]
912    fn test_compression_disable_all() {
913        let comp = Compression::new().disable_all();
914        assert!(comp.algos.is_empty());
915    }
916
917    #[cfg(feature = "gzip")]
918    #[test]
919    fn test_compression_enable_gzip() {
920        let comp = Compression::new()
921            .disable_all()
922            .enable_gzip(CompressionLevel::Fastest);
923        assert!(comp.algos.contains_key(&CompressionAlgo::Gzip));
924        assert_eq!(
925            comp.algos.get(&CompressionAlgo::Gzip),
926            Some(&CompressionLevel::Fastest)
927        );
928    }
929
930    #[cfg(feature = "gzip")]
931    #[test]
932    fn test_compression_disable_gzip() {
933        let comp = Compression::new().disable_gzip();
934        assert!(!comp.algos.contains_key(&CompressionAlgo::Gzip));
935    }
936
937    #[cfg(feature = "brotli")]
938    #[test]
939    fn test_compression_enable_brotli() {
940        let comp = Compression::new()
941            .disable_all()
942            .enable_brotli(CompressionLevel::Minsize);
943        assert!(comp.algos.contains_key(&CompressionAlgo::Brotli));
944    }
945
946    #[cfg(feature = "brotli")]
947    #[test]
948    fn test_compression_disable_brotli() {
949        let comp = Compression::new().disable_brotli();
950        assert!(!comp.algos.contains_key(&CompressionAlgo::Brotli));
951    }
952
953    #[cfg(feature = "zstd")]
954    #[test]
955    fn test_compression_enable_zstd() {
956        let comp = Compression::new()
957            .disable_all()
958            .enable_zstd(CompressionLevel::Default);
959        assert!(comp.algos.contains_key(&CompressionAlgo::Zstd));
960    }
961
962    #[cfg(feature = "zstd")]
963    #[test]
964    fn test_compression_disable_zstd() {
965        let comp = Compression::new().disable_zstd();
966        assert!(!comp.algos.contains_key(&CompressionAlgo::Zstd));
967    }
968
969    #[cfg(feature = "deflate")]
970    #[test]
971    fn test_compression_enable_deflate() {
972        let comp = Compression::new()
973            .disable_all()
974            .enable_deflate(CompressionLevel::Default);
975        assert!(comp.algos.contains_key(&CompressionAlgo::Deflate));
976    }
977
978    #[cfg(feature = "deflate")]
979    #[test]
980    fn test_compression_disable_deflate() {
981        let comp = Compression::new().disable_deflate();
982        assert!(!comp.algos.contains_key(&CompressionAlgo::Deflate));
983    }
984
985    #[test]
986    fn test_compression_min_length() {
987        let comp = Compression::new().min_length(1024);
988        assert_eq!(comp.min_length, 1024);
989    }
990
991    #[test]
992    fn test_compression_force_priority() {
993        let comp = Compression::new().force_priority(true);
994        assert!(comp.force_priority);
995    }
996
997    #[test]
998    fn test_compression_content_types() {
999        let comp = Compression::new().content_types(&[mime::TEXT_PLAIN, mime::TEXT_HTML]);
1000        assert_eq!(comp.content_types.len(), 2);
1001        assert!(comp.content_types.contains(&mime::TEXT_PLAIN));
1002        assert!(comp.content_types.contains(&mime::TEXT_HTML));
1003    }
1004
1005    #[test]
1006    fn test_compression_debug() {
1007        let comp = Compression::new();
1008        let debug_str = format!("{:?}", comp);
1009        assert!(debug_str.contains("Compression"));
1010        assert!(debug_str.contains("algos"));
1011        assert!(debug_str.contains("content_types"));
1012    }
1013
1014    #[test]
1015    fn test_compression_clone() {
1016        let comp = Compression::new().min_length(100);
1017        let cloned = comp.clone();
1018        assert_eq!(comp.min_length, cloned.min_length);
1019        assert_eq!(comp.algos.len(), cloned.algos.len());
1020    }
1021
1022    // Tests for no compression scenarios
1023    #[tokio::test]
1024    async fn test_no_accept_encoding_header() {
1025        let comp_handler = Compression::new().min_length(1);
1026        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
1027
1028        let res = TestClient::get("http://127.0.0.1:5801/hello")
1029            .send(router)
1030            .await;
1031        assert!(res.headers().get(CONTENT_ENCODING).is_none());
1032    }
1033
1034    #[tokio::test]
1035    async fn test_unsupported_encoding() {
1036        let comp_handler = Compression::new().min_length(1);
1037        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
1038
1039        let res = TestClient::get("http://127.0.0.1:5801/hello")
1040            .add_header(ACCEPT_ENCODING, "unknown", true)
1041            .send(router)
1042            .await;
1043        assert!(res.headers().get(CONTENT_ENCODING).is_none());
1044    }
1045
1046    #[tokio::test]
1047    async fn test_empty_response() {
1048        #[handler]
1049        async fn empty() {}
1050
1051        let comp_handler = Compression::new();
1052        let router = Router::with_hoop(comp_handler).push(Router::with_path("empty").get(empty));
1053
1054        let res = TestClient::get("http://127.0.0.1:5801/empty")
1055            .add_header(ACCEPT_ENCODING, "gzip", true)
1056            .send(router)
1057            .await;
1058        assert!(res.headers().get(CONTENT_ENCODING).is_none());
1059    }
1060
1061    #[tokio::test]
1062    async fn test_chained_configuration() {
1063        #[cfg(all(feature = "gzip", feature = "brotli"))]
1064        {
1065            let comp_handler = Compression::new()
1066                .disable_all()
1067                .enable_gzip(CompressionLevel::Fastest)
1068                .enable_brotli(CompressionLevel::Default)
1069                .min_length(1)
1070                .force_priority(false)
1071                .content_types(&[mime::TEXT_PLAIN]);
1072
1073            assert_eq!(comp_handler.algos.len(), 2);
1074            assert_eq!(comp_handler.min_length, 1);
1075            assert!(!comp_handler.force_priority);
1076            assert_eq!(comp_handler.content_types.len(), 1);
1077        }
1078    }
1079}