tosic_http/middleware/compression.rs

1//! Compression middleware
2
3use crate::body::message_body::MessageBody;
4use crate::body::BoxBody;
5use crate::error::ServerError;
6use crate::prelude::{Error, HttpPayload, HttpRequest, HttpResponse};
7use flate2::write::{DeflateEncoder, GzEncoder};
8use flate2::Compression;
9use std::future::Future;
10use std::io::Write;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14use tracing::warn;
15
/// A content coding this middleware knows how to apply to a response body.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionType {
    /// Gzip compression (advertised as `gzip`).
    Gzip,
    /// Deflate compression (advertised as `deflate`).
    Deflate,
}
24
/// The compression layer to be used: a tower layer that carries no configuration, so every
/// instance is equivalent.
#[derive(Clone, Debug, Default)]
pub struct CompressionLayer;

impl CompressionLayer {
    /// Creates a compression layer (same value as `CompressionLayer::default()`).
    pub fn new() -> Self {
        CompressionLayer
    }
}
41
42impl<S: Clone> Layer<S> for CompressionLayer {
43    type Service = CompressionMiddleware<S>;
44
45    fn layer(&self, service: S) -> Self::Service {
46        CompressionMiddleware { inner: service }
47    }
48}
49
/// Compression middleware: the service produced by `CompressionLayer`. It forwards requests to
/// `inner` and compresses the responses that come back.
///
/// The struct itself places no bounds on `S`; the `Layer`/`Service` impls state what they need
/// (and the derives only require `S: Clone` / `S: Debug` for their own impls).
#[derive(Clone, Debug)]
pub struct CompressionMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
55
56impl<S> Service<(HttpRequest, HttpPayload)> for CompressionMiddleware<S>
57where
58    S: Service<(HttpRequest, HttpPayload), Response = HttpResponse, Error = Error>
59        + Clone
60        + Send
61        + Sync
62        + 'static,
63    S::Future: Send + 'static,
64{
65    type Response = HttpResponse;
66    type Error = Error;
67    type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Error>> + Send>>;
68
69    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        self.inner.poll_ready(cx)
71    }
72
73    fn call(&mut self, req: (HttpRequest, HttpPayload)) -> Self::Future {
74        let mut inner = self.inner.clone();
75        let (request, payload) = req;
76
77        let accept_encoding = request.headers().get("Accept-Encoding").cloned();
78
79        Box::pin(async move {
80            let mut response = inner.call((request, payload)).await?;
81
82            let supported_encodings = vec![CompressionType::Gzip, CompressionType::Deflate];
83
84            if let Some(encoding_header) = accept_encoding {
85                if let Ok(encoding_str) = encoding_header.to_str() {
86                    let client_encodings = parse_accept_encoding(encoding_str);
87
88                    if let Some(best_encoding) =
89                        negotiate_encoding(&client_encodings, &supported_encodings)
90                    {
91                        response = compress_response(response, best_encoding).await?;
92
93                        let encoding_value = match best_encoding {
94                            CompressionType::Gzip => "gzip",
95                            CompressionType::Deflate => "deflate",
96                        };
97                        response
98                            .headers_mut()
99                            .insert("Content-Encoding", encoding_value.parse().unwrap());
100                    } else {
101                        warn!("No common encoding found between client and server");
102                    }
103                }
104            }
105
106            response
107                .headers_mut()
108                .insert("Vary", "Accept-Encoding".parse().unwrap());
109
110            Ok(response)
111        })
112    }
113}
114
115/// Helper function to compress the response
116async fn compress_response(
117    mut response: HttpResponse,
118    compression_type: CompressionType,
119) -> Result<HttpResponse, ServerError> {
120    // Read the body
121    let body_bytes = response
122        .body
123        .clone()
124        .try_into_bytes()
125        .expect("Unable to read body");
126
127    let compressed_body = match compression_type {
128        CompressionType::Gzip => {
129            let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
130            encoder.write_all(&body_bytes)?;
131            encoder.finish()?
132        }
133        CompressionType::Deflate => {
134            let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
135            encoder.write_all(&body_bytes)?;
136            encoder.finish()?
137        }
138    };
139
140    let body = BoxBody::new(compressed_body);
141
142    response = response.set_body(body);
143
144    Ok(response)
145}
146
147/// Helper function to parse the Accept-Encoding header
148fn parse_accept_encoding(header_value: &str) -> Vec<(CompressionType, f32)> {
149    let mut encodings = Vec::new();
150
151    for part in header_value.split(',') {
152        let part = part.trim();
153        let mut tokens = part.split(';');
154
155        if let Some(encoding_str) = tokens.next() {
156            let quality = tokens
157                .find_map(|token| {
158                    if token.trim().starts_with("q=") {
159                        token.trim()[2..].parse::<f32>().ok()
160                    } else {
161                        None
162                    }
163                })
164                .unwrap_or(1.0);
165
166            let encoding = match encoding_str {
167                "gzip" => Some(CompressionType::Gzip),
168                "deflate" => Some(CompressionType::Deflate),
169                "*" => Some(CompressionType::Gzip),
170                _ => None,
171            };
172
173            if let Some(enc) = encoding {
174                encodings.push((enc, quality));
175            }
176        }
177    }
178
179    encodings.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
180
181    encodings
182}
183
184/// Helper function to negotiate the best encoding
185fn negotiate_encoding(
186    client_encodings: &[(CompressionType, f32)],
187    server_encodings: &[CompressionType],
188) -> Option<CompressionType> {
189    for (encoding, _) in client_encodings {
190        if server_encodings.contains(encoding) {
191            return Some(*encoding);
192        }
193    }
194    None
195}