1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3use std::fmt::{self, Display, Formatter};
8use std::str::FromStr;
9
10use indexmap::IndexMap;
11
12use salvo_core::http::body::ResBody;
13use salvo_core::http::header::{
14 ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE, HeaderValue,
15};
16use salvo_core::http::{self, Mime, StatusCode, mime};
17use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
18
19mod encoder;
20mod stream;
21use encoder::Encoder;
22use stream::EncodeStream;
23
/// Compression effort requested from the selected algorithm.
///
/// How each level maps onto a concrete algorithm setting is left to the encoder backend
/// (presumably the `encoder` module), not decided by this type.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Favour encoding speed over output size.
    Fastest,
    /// Favour the smallest output over encoding speed.
    Minsize,
    /// Use the algorithm's own default setting. This is also the [`Default`] value of this type.
    #[default]
    Default,
    /// An explicit, algorithm-specific level; its valid range and interpretation depend on the
    /// chosen algorithm.
    Precise(u32),
}
41
/// A content coding this middleware can apply to response bodies.
///
/// Each variant exists only when the Cargo feature of the same name is enabled.
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
#[non_exhaustive]
pub enum CompressionAlgo {
    /// Brotli, sent as `br` in `Content-Encoding`.
    #[cfg(feature = "brotli")]
    #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
    Brotli,

    /// Deflate, sent as `deflate` in `Content-Encoding`.
    #[cfg(feature = "deflate")]
    #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
    Deflate,

    /// Gzip, sent as `gzip` in `Content-Encoding`.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,

    /// Zstandard, sent as `zstd` in `Content-Encoding`.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}

impl FromStr for CompressionAlgo {
    type Err = String;

    /// Parses a content-coding token such as `gzip`, `br` or `zstd`.
    ///
    /// Matching ignores ASCII case, because RFC 9110 §8.4.1 defines content codings as
    /// case-insensitive. `brotli` is accepted as an alias of `br`. `x-gzip` is accepted as an
    /// alias of `gzip`, as RFC 9110 §8.4.1.3 asks recipients to do.
    ///
    /// # Errors
    ///
    /// Returns `unknown compression algorithm: <input>` (echoing the input unchanged) when the
    /// token is not recognised or its feature is not compiled in.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let token = s.to_ascii_lowercase();
        match token.as_str() {
            #[cfg(feature = "brotli")]
            "br" | "brotli" => Ok(CompressionAlgo::Brotli),

            #[cfg(feature = "deflate")]
            "deflate" => Ok(CompressionAlgo::Deflate),

            #[cfg(feature = "gzip")]
            "gzip" | "x-gzip" => Ok(CompressionAlgo::Gzip),

            #[cfg(feature = "zstd")]
            "zstd" => Ok(CompressionAlgo::Zstd),

            _ => Err(format!("unknown compression algorithm: {s}")),
        }
    }
}

impl Display for CompressionAlgo {
    /// Writes the canonical lowercase `Content-Encoding` token for the algorithm.
    // With no compression feature enabled the enum is empty: the wildcard arm is then the only
    // arm and `f` goes unused, hence the two allows.
    #[allow(unreachable_patterns)]
    #[allow(unused_variables)]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            #[cfg(feature = "brotli")]
            CompressionAlgo::Brotli => f.write_str("br"),
            #[cfg(feature = "deflate")]
            CompressionAlgo::Deflate => f.write_str("deflate"),
            #[cfg(feature = "gzip")]
            CompressionAlgo::Gzip => f.write_str("gzip"),
            #[cfg(feature = "zstd")]
            CompressionAlgo::Zstd => f.write_str("zstd"),
            _ => unreachable!(),
        }
    }
}
107
108impl From<CompressionAlgo> for HeaderValue {
109 #[inline]
110 fn from(algo: CompressionAlgo) -> Self {
111 match algo {
112 #[cfg(feature = "brotli")]
113 CompressionAlgo::Brotli => HeaderValue::from_static("br"),
114 #[cfg(feature = "deflate")]
115 CompressionAlgo::Deflate => HeaderValue::from_static("deflate"),
116 #[cfg(feature = "gzip")]
117 CompressionAlgo::Gzip => HeaderValue::from_static("gzip"),
118 #[cfg(feature = "zstd")]
119 CompressionAlgo::Zstd => HeaderValue::from_static("zstd"),
120 }
121 }
122}
123
124#[derive(Clone, Debug)]
126#[non_exhaustive]
127pub struct Compression {
128 pub algos: IndexMap<CompressionAlgo, CompressionLevel>,
130 pub content_types: Vec<Mime>,
132 pub min_length: usize,
134 pub force_priority: bool,
136}
137
138impl Default for Compression {
139 fn default() -> Self {
140 #[allow(unused_mut)]
141 let mut algos = IndexMap::new();
142 #[cfg(feature = "zstd")]
143 algos.insert(CompressionAlgo::Zstd, CompressionLevel::Default);
144 #[cfg(feature = "gzip")]
145 algos.insert(CompressionAlgo::Gzip, CompressionLevel::Default);
146 #[cfg(feature = "deflate")]
147 algos.insert(CompressionAlgo::Deflate, CompressionLevel::Default);
148 #[cfg(feature = "brotli")]
149 algos.insert(CompressionAlgo::Brotli, CompressionLevel::Default);
150 Self {
151 algos,
152 content_types: vec![
153 mime::TEXT_STAR,
154 mime::APPLICATION_JAVASCRIPT,
155 mime::APPLICATION_JSON,
156 mime::IMAGE_SVG,
157 "application/wasm".parse().expect("invalid mime type"),
158 "application/xml".parse().expect("invalid mime type"),
159 "application/rss+xml".parse().expect("invalid mime type"),
160 ],
161 min_length: 0,
162 force_priority: false,
163 }
164 }
165}
166
167impl Compression {
168 #[inline]
170 pub fn new() -> Self {
171 Default::default()
172 }
173
174 #[inline]
176 pub fn disable_all(mut self) -> Self {
177 self.algos.clear();
178 self
179 }
180
181 #[cfg(feature = "gzip")]
183 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
184 #[inline]
185 pub fn enable_gzip(mut self, level: CompressionLevel) -> Self {
186 self.algos.insert(CompressionAlgo::Gzip, level);
187 self
188 }
189 #[cfg(feature = "gzip")]
191 #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
192 #[inline]
193 pub fn disable_gzip(mut self) -> Self {
194 self.algos.shift_remove(&CompressionAlgo::Gzip);
195 self
196 }
197 #[cfg(feature = "zstd")]
199 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
200 #[inline]
201 pub fn enable_zstd(mut self, level: CompressionLevel) -> Self {
202 self.algos.insert(CompressionAlgo::Zstd, level);
203 self
204 }
205 #[cfg(feature = "zstd")]
207 #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
208 #[inline]
209 pub fn disable_zstd(mut self) -> Self {
210 self.algos.shift_remove(&CompressionAlgo::Zstd);
211 self
212 }
213 #[cfg(feature = "brotli")]
215 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
216 #[inline]
217 pub fn enable_brotli(mut self, level: CompressionLevel) -> Self {
218 self.algos.insert(CompressionAlgo::Brotli, level);
219 self
220 }
221 #[cfg(feature = "brotli")]
223 #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))]
224 #[inline]
225 pub fn disable_brotli(mut self) -> Self {
226 self.algos.shift_remove(&CompressionAlgo::Brotli);
227 self
228 }
229
230 #[cfg(feature = "deflate")]
232 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
233 #[inline]
234 pub fn enable_deflate(mut self, level: CompressionLevel) -> Self {
235 self.algos.insert(CompressionAlgo::Deflate, level);
236 self
237 }
238
239 #[cfg(feature = "deflate")]
241 #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))]
242 #[inline]
243 pub fn disable_deflate(mut self) -> Self {
244 self.algos.shift_remove(&CompressionAlgo::Deflate);
245 self
246 }
247
248 #[inline]
251 pub fn min_length(mut self, size: usize) -> Self {
252 self.min_length = size;
253 self
254 }
255 #[inline]
257 pub fn force_priority(mut self, force_priority: bool) -> Self {
258 self.force_priority = force_priority;
259 self
260 }
261
262 #[inline]
264 pub fn content_types(mut self, content_types: &[Mime]) -> Self {
265 self.content_types = content_types.to_vec();
266 self
267 }
268
269 fn negotiate(
270 &self,
271 req: &Request,
272 res: &Response,
273 ) -> Option<(CompressionAlgo, CompressionLevel)> {
274 if req.headers().contains_key(&CONTENT_ENCODING) {
275 return None;
276 }
277
278 if !self.content_types.is_empty() {
279 let content_type = res
280 .headers()
281 .get(CONTENT_TYPE)
282 .and_then(|v| v.to_str().ok())
283 .unwrap_or_default();
284 if content_type.is_empty() {
285 return None;
286 }
287 if let Ok(content_type) = content_type.parse::<Mime>() {
288 if !self.content_types.iter().any(|citem| {
289 citem.type_() == content_type.type_()
290 && (citem.subtype() == "*" || citem.subtype() == content_type.subtype())
291 }) {
292 return None;
293 }
294 } else {
295 return None;
296 }
297 }
298 let header = req
299 .headers()
300 .get(ACCEPT_ENCODING)
301 .and_then(|v| v.to_str().ok())?;
302
303 let accept_algos = http::parse_accept_encoding(header)
304 .into_iter()
305 .filter_map(|(algo, level)| {
306 if let Ok(algo) = algo.parse::<CompressionAlgo>() {
307 Some((algo, level))
308 } else {
309 None
310 }
311 })
312 .collect::<Vec<_>>();
313 if self.force_priority {
314 let accept_algos = accept_algos
315 .into_iter()
316 .map(|(algo, _)| algo)
317 .collect::<Vec<_>>();
318 self.algos
319 .iter()
320 .find(|(algo, _level)| accept_algos.contains(algo))
321 .map(|(algo, level)| (*algo, *level))
322 } else {
323 accept_algos
324 .into_iter()
325 .find_map(|(algo, _)| self.algos.get(&algo).map(|level| (algo, *level)))
326 }
327 }
328}
329
330#[async_trait]
331impl Handler for Compression {
332 async fn handle(
333 &self,
334 req: &mut Request,
335 depot: &mut Depot,
336 res: &mut Response,
337 ctrl: &mut FlowCtrl,
338 ) {
339 ctrl.call_next(req, depot, res).await;
340 if ctrl.is_ceased() || res.headers().contains_key(CONTENT_ENCODING) {
341 return;
342 }
343
344 if let Some(code) = res.status_code {
345 if code == StatusCode::SWITCHING_PROTOCOLS || code == StatusCode::NO_CONTENT {
346 return;
347 }
348 }
349
350 match res.take_body() {
351 ResBody::None => {
352 return;
353 }
354 ResBody::Once(bytes) => {
355 if self.min_length > 0 && bytes.len() < self.min_length {
356 res.body(ResBody::Once(bytes));
357 return;
358 }
359 match self.negotiate(req, res) {
360 Some((algo, level)) => {
361 res.stream(EncodeStream::new(algo, level, Some(bytes)));
362 res.headers_mut().append(CONTENT_ENCODING, algo.into());
363 }
364 None => {
365 res.body(ResBody::Once(bytes));
366 return;
367 }
368 }
369 }
370 ResBody::Chunks(chunks) => {
371 if self.min_length > 0 {
372 let len: usize = chunks.iter().map(|c| c.len()).sum();
373 if len < self.min_length {
374 res.body(ResBody::Chunks(chunks));
375 return;
376 }
377 }
378 match self.negotiate(req, res) {
379 Some((algo, level)) => {
380 res.stream(EncodeStream::new(algo, level, chunks));
381 res.headers_mut().append(CONTENT_ENCODING, algo.into());
382 }
383 None => {
384 res.body(ResBody::Chunks(chunks));
385 return;
386 }
387 }
388 }
389 ResBody::Hyper(body) => match self.negotiate(req, res) {
390 Some((algo, level)) => {
391 res.stream(EncodeStream::new(algo, level, body));
392 res.headers_mut().append(CONTENT_ENCODING, algo.into());
393 }
394 None => {
395 res.body(ResBody::Hyper(body));
396 return;
397 }
398 },
399 ResBody::Stream(body) => {
400 let body = body.into_inner();
401 match self.negotiate(req, res) {
402 Some((algo, level)) => {
403 res.stream(EncodeStream::new(algo, level, body));
404 res.headers_mut().append(CONTENT_ENCODING, algo.into());
405 }
406 None => {
407 res.body(ResBody::stream(body));
408 return;
409 }
410 }
411 }
412 body => {
413 res.body(body);
414 return;
415 }
416 }
417 res.headers_mut().remove(CONTENT_LENGTH);
418 }
419}
420
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    #[handler]
    async fn hello() -> &'static str {
        "hello"
    }

    /// Sends `GET /hello` through `comp_handler` with the given `Accept-Encoding` value.
    async fn fetch(comp_handler: Compression, accept_encoding: &'static str) -> Response {
        let router = Router::with_hoop(comp_handler).push(Router::with_path("hello").get(hello));
        TestClient::get("http://127.0.0.1:5801/hello")
            .add_header(ACCEPT_ENCODING, accept_encoding, true)
            .send(router)
            .await
    }

    // Each codec test is gated on its feature: without the feature the algorithm does not exist
    // and the response is (correctly) sent uncompressed, which would otherwise fail the test.
    #[cfg(feature = "gzip")]
    #[tokio::test]
    async fn test_gzip() {
        let mut res = fetch(Compression::new().min_length(1), "gzip").await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "gzip");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "brotli")]
    #[tokio::test]
    async fn test_brotli() {
        let mut res = fetch(Compression::new().min_length(1), "br").await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "br");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_deflate() {
        let mut res = fetch(Compression::new().min_length(1), "deflate").await;
        assert_eq!(res.headers().get(CONTENT_ENCODING).unwrap(), "deflate");
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[tokio::test]
    async fn test_unsupported_encoding_passes_through() {
        let mut res = fetch(Compression::new().min_length(1), "identity").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }

    #[tokio::test]
    async fn test_below_min_length_passes_through() {
        // "hello" is 5 bytes, well below the 1 KiB threshold.
        let mut res = fetch(Compression::new().min_length(1024), "gzip, br, deflate, zstd").await;
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
        let content = res.take_string().await.unwrap();
        assert_eq!(content, "hello");
    }
}