viz_core/middleware/
compression.rs1use std::str::FromStr;
4
5use async_compression::tokio::bufread;
6use tokio_util::io::{ReaderStream, StreamReader};
7
8use crate::{
9 header::{HeaderValue, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH},
10 Body, Handler, IntoResponse, Request, Response, Result, Transform,
11};
12
/// Configuration for the response-compression middleware.
///
/// Applying it to a handler through [`Transform`] yields a
/// [`CompressionMiddleware`] wrapping that handler. It carries no state, so
/// it is cheap to copy and has a trivial [`Default`].
#[derive(Clone, Copy, Debug, Default)]
pub struct Config;
16
17impl<H> Transform<H> for Config
18where
19 H: Clone,
20{
21 type Output = CompressionMiddleware<H>;
22
23 fn transform(&self, h: H) -> Self::Output {
24 CompressionMiddleware { h }
25 }
26}
27
/// Middleware that compresses the wrapped handler's response body according
/// to the request's `Accept-Encoding` header.
#[derive(Debug, Clone)]
pub struct CompressionMiddleware<H> {
    /// The wrapped (inner) handler.
    h: H,
}
33
34#[crate::async_trait]
35impl<H, O> Handler<Request> for CompressionMiddleware<H>
36where
37 H: Handler<Request, Output = Result<O>>,
38 O: IntoResponse,
39{
40 type Output = Result<Response>;
41
42 async fn call(&self, req: Request) -> Self::Output {
43 let accept_encoding = req
44 .headers()
45 .get(ACCEPT_ENCODING)
46 .map(HeaderValue::to_str)
47 .and_then(Result::ok)
48 .and_then(parse_accept_encoding);
49
50 let raw = self.h.call(req).await?;
51
52 Ok(match accept_encoding {
53 Some(algo) => Compress::new(raw, algo).into_response(),
54 None => raw.into_response(),
55 })
56 }
57}
58
59#[derive(Debug)]
62pub struct Compress<T> {
63 inner: T,
64 algo: ContentCoding,
65}
66
67impl<T> Compress<T> {
68 pub const fn new(inner: T, algo: ContentCoding) -> Self {
70 Self { inner, algo }
71 }
72}
73
74impl<T: IntoResponse> IntoResponse for Compress<T> {
75 fn into_response(self) -> Response {
76 let mut res = self.inner.into_response();
77
78 match self.algo {
79 ContentCoding::Gzip | ContentCoding::Deflate | ContentCoding::Brotli => {
80 res = res.map(|body| {
81 let body = StreamReader::new(body);
82 if self.algo == ContentCoding::Gzip {
83 Body::from_stream(ReaderStream::new(bufread::GzipEncoder::new(body)))
84 } else if self.algo == ContentCoding::Deflate {
85 Body::from_stream(ReaderStream::new(bufread::DeflateEncoder::new(body)))
86 } else {
87 Body::from_stream(ReaderStream::new(bufread::BrotliEncoder::new(body)))
88 }
89 });
90 res.headers_mut()
91 .append(CONTENT_ENCODING, HeaderValue::from_static(self.algo.into()));
92 res.headers_mut().remove(CONTENT_LENGTH);
93 res
94 }
95 ContentCoding::Any => res,
96 }
97 }
98}
99
/// Content codings this middleware understands, as named in the
/// `Accept-Encoding` and `Content-Encoding` headers.
///
/// The enum has no payload, so it is `Copy`: callers can keep using a value
/// after converting it into its header token.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ContentCoding {
    /// The `gzip` token.
    Gzip,
    /// The `deflate` token.
    Deflate,
    /// The `br` (Brotli) token.
    Brotli,
    /// The `*` wildcard. `Compress` leaves the body unencoded for it.
    Any,
}
114
115impl FromStr for ContentCoding {
116 type Err = ();
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 if s.eq_ignore_ascii_case("deflate") {
120 Ok(Self::Deflate)
121 } else if s.eq_ignore_ascii_case("gzip") {
122 Ok(Self::Gzip)
123 } else if s.eq_ignore_ascii_case("br") {
124 Ok(Self::Brotli)
125 } else if s == "*" {
126 Ok(Self::Any)
127 } else {
128 Err(())
129 }
130 }
131}
132
133impl From<ContentCoding> for &'static str {
134 fn from(cc: ContentCoding) -> Self {
135 match cc {
136 ContentCoding::Gzip => "gzip",
137 ContentCoding::Deflate => "deflate",
138 ContentCoding::Brotli => "br",
139 ContentCoding::Any => "*",
140 }
141 }
142}
143
144#[allow(clippy::cast_sign_loss)]
145#[allow(clippy::cast_possible_truncation)]
146fn parse_accept_encoding(s: &str) -> Option<ContentCoding> {
147 s.split(',')
148 .map(str::trim)
149 .filter_map(|v| match v.split_once(";q=") {
150 None => v.parse::<ContentCoding>().ok().map(|c| (c, 100)),
151 Some((c, q)) => Some((
152 c.parse::<ContentCoding>().ok()?,
153 q.parse::<f32>()
154 .ok()
155 .filter(|v| *v >= 0. && *v <= 1.)
156 .map(|v| (v * 100.) as u8)?,
157 )),
158 })
159 .max_by_key(|(_, q)| *q)
160 .map(|(c, _)| c)
161}