tosic_http/middleware/
compression.rs1use crate::body::message_body::MessageBody;
4use crate::body::BoxBody;
5use crate::error::ServerError;
6use crate::prelude::{Error, HttpPayload, HttpRequest, HttpResponse};
7use flate2::write::{DeflateEncoder, GzEncoder};
8use flate2::Compression;
9use std::future::Future;
10use std::io::Write;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14use tracing::warn;
15
/// HTTP content codings this middleware knows how to apply to a response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// The `gzip` content coding.
    Gzip,
    /// The `deflate` content coding.
    Deflate,
}
24
/// Layer that wraps a service in a `CompressionMiddleware`, which compresses
/// responses according to the request's `Accept-Encoding` header.
#[derive(Clone, Debug, Default)]
pub struct CompressionLayer;

impl CompressionLayer {
    /// Creates the layer; identical to `CompressionLayer::default()`.
    pub fn new() -> Self {
        CompressionLayer
    }
}
41
42impl<S: Clone> Layer<S> for CompressionLayer {
43 type Service = CompressionMiddleware<S>;
44
45 fn layer(&self, service: S) -> Self::Service {
46 CompressionMiddleware { inner: service }
47 }
48}
49
/// Service produced by `CompressionLayer`: forwards each request to `inner`
/// and compresses the response it returns.
///
/// The struct itself places no bounds on `S`; the `Layer` and `Service` impls
/// state the bounds they actually need, so those bounds are not repeated here.
#[derive(Clone, Debug)]
pub struct CompressionMiddleware<S> {
    /// The wrapped service that produces the uncompressed response.
    inner: S,
}
55
56impl<S> Service<(HttpRequest, HttpPayload)> for CompressionMiddleware<S>
57where
58 S: Service<(HttpRequest, HttpPayload), Response = HttpResponse, Error = Error>
59 + Clone
60 + Send
61 + Sync
62 + 'static,
63 S::Future: Send + 'static,
64{
65 type Response = HttpResponse;
66 type Error = Error;
67 type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Error>> + Send>>;
68
69 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70 self.inner.poll_ready(cx)
71 }
72
73 fn call(&mut self, req: (HttpRequest, HttpPayload)) -> Self::Future {
74 let mut inner = self.inner.clone();
75 let (request, payload) = req;
76
77 let accept_encoding = request.headers().get("Accept-Encoding").cloned();
78
79 Box::pin(async move {
80 let mut response = inner.call((request, payload)).await?;
81
82 let supported_encodings = vec![CompressionType::Gzip, CompressionType::Deflate];
83
84 if let Some(encoding_header) = accept_encoding {
85 if let Ok(encoding_str) = encoding_header.to_str() {
86 let client_encodings = parse_accept_encoding(encoding_str);
87
88 if let Some(best_encoding) =
89 negotiate_encoding(&client_encodings, &supported_encodings)
90 {
91 response = compress_response(response, best_encoding).await?;
92
93 let encoding_value = match best_encoding {
94 CompressionType::Gzip => "gzip",
95 CompressionType::Deflate => "deflate",
96 };
97 response
98 .headers_mut()
99 .insert("Content-Encoding", encoding_value.parse().unwrap());
100 } else {
101 warn!("No common encoding found between client and server");
102 }
103 }
104 }
105
106 response
107 .headers_mut()
108 .insert("Vary", "Accept-Encoding".parse().unwrap());
109
110 Ok(response)
111 })
112 }
113}
114
115async fn compress_response(
117 mut response: HttpResponse,
118 compression_type: CompressionType,
119) -> Result<HttpResponse, ServerError> {
120 let body_bytes = response
122 .body
123 .clone()
124 .try_into_bytes()
125 .expect("Unable to read body");
126
127 let compressed_body = match compression_type {
128 CompressionType::Gzip => {
129 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
130 encoder.write_all(&body_bytes)?;
131 encoder.finish()?
132 }
133 CompressionType::Deflate => {
134 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
135 encoder.write_all(&body_bytes)?;
136 encoder.finish()?
137 }
138 };
139
140 let body = BoxBody::new(compressed_body);
141
142 response = response.set_body(body);
143
144 Ok(response)
145}
146
147fn parse_accept_encoding(header_value: &str) -> Vec<(CompressionType, f32)> {
149 let mut encodings = Vec::new();
150
151 for part in header_value.split(',') {
152 let part = part.trim();
153 let mut tokens = part.split(';');
154
155 if let Some(encoding_str) = tokens.next() {
156 let quality = tokens
157 .find_map(|token| {
158 if token.trim().starts_with("q=") {
159 token.trim()[2..].parse::<f32>().ok()
160 } else {
161 None
162 }
163 })
164 .unwrap_or(1.0);
165
166 let encoding = match encoding_str {
167 "gzip" => Some(CompressionType::Gzip),
168 "deflate" => Some(CompressionType::Deflate),
169 "*" => Some(CompressionType::Gzip),
170 _ => None,
171 };
172
173 if let Some(enc) = encoding {
174 encodings.push((enc, quality));
175 }
176 }
177 }
178
179 encodings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
180
181 encodings
182}
183
184fn negotiate_encoding(
186 client_encodings: &[(CompressionType, f32)],
187 server_encodings: &[CompressionType],
188) -> Option<CompressionType> {
189 for (encoding, _) in client_encodings {
190 if server_encodings.contains(encoding) {
191 return Some(*encoding);
192 }
193 }
194 None
195}