1use crate::core::{Middleware, Next};
4use crate::types::{Request, Response};
5use async_trait::async_trait;
6use std::collections::HashMap;
7
/// Chooses the media type a response should use, based on the request's `Accept` header.
///
/// The [`Middleware`] implementation stores the choice in the request extensions under
/// the key `"negotiated_content_type"`.
#[derive(Debug)]
pub struct ContentNegotiationMiddleware {
    /// Media types this server can produce, each with a server-side weight. The weight
    /// only breaks ties when one client range (e.g. `text/*`) matches several types.
    supported_types: HashMap<String, f32>,
    /// Type used for `*/*`, for a missing `Accept` header, and when nothing matches.
    default_type: String,
}

impl ContentNegotiationMiddleware {
    /// Creates a negotiator for JSON (weight 1.0), HTML (0.9), plain text (0.8) and
    /// XML (0.7), defaulting to `application/json`.
    pub fn new() -> Self {
        let mut supported_types = HashMap::new();
        supported_types.insert("application/json".to_string(), 1.0);
        supported_types.insert("text/html".to_string(), 0.9);
        supported_types.insert("text/plain".to_string(), 0.8);
        supported_types.insert("application/xml".to_string(), 0.7);

        Self {
            supported_types,
            default_type: "application/json".to_string(),
        }
    }

    /// Adds (or re-weights) a producible media type.
    ///
    /// The type is lower-cased because `Accept` entries are lower-cased before matching;
    /// a mixed-case registration could otherwise never match.
    pub fn support_type(mut self, mime_type: String, quality: f32) -> Self {
        self.supported_types.insert(mime_type.to_lowercase(), quality);
        self
    }

    /// Sets the media type used for `*/*`, a missing `Accept` header, or no match.
    pub fn default_type(mut self, mime_type: String) -> Self {
        self.default_type = mime_type;
        self
    }

    /// Parses an `Accept` header into `(media_range, quality)` pairs, highest quality first.
    ///
    /// Each range is trimmed, lower-cased and stripped of parameters other than `q`;
    /// empty entries are skipped. A missing or unparsable `q` counts as `1.0`, and values
    /// are clamped to `0.0..=1.0`. Entries of equal quality keep the client's order.
    fn parse_accept_header(&self, accept_header: &str) -> Vec<(String, f32)> {
        let mut types: Vec<(String, f32)> = accept_header
            .split(',')
            .filter_map(|entry| {
                let mut params = entry.split(';');
                let media_range = params.next()?.trim().to_lowercase();
                if media_range.is_empty() {
                    return None;
                }
                // Tolerates optional whitespace (`text/html ; q=0.5`) and extra
                // parameters (`text/html;level=1;q=0.5`).
                let quality = params
                    .filter_map(|param| param.split_once('='))
                    .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
                    .map_or(1.0, |(_, value)| Self::parse_quality(value));
                Some((media_range, quality))
            })
            .collect();

        // `sort_by` is stable, so ties preserve the order the client listed them in.
        types.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        types
    }

    /// Parses a `q` value: unparsable or non-finite input counts as `1.0`, everything
    /// else is clamped to the valid `0.0..=1.0` range.
    fn parse_quality(raw: &str) -> f32 {
        match raw.trim().parse::<f32>() {
            Ok(q) if q.is_finite() => q.clamp(0.0, 1.0),
            _ => 1.0,
        }
    }

    /// Picks the response media type for an `Accept` header.
    ///
    /// Ranges are tried from highest to lowest client quality, and ranges with `q=0` are
    /// never selected. `*/*` yields `default_type` unless the client refused it
    /// explicitly. Wildcards such as `text/*` pick the matching supported type with the
    /// highest server weight; ties go to the alphabetically first type, so the result is
    /// deterministic. A `q=0` entry only excludes a type from *less* specific ranges, so
    /// `text/plain` still wins over `text/*;q=0`. Falls back to `default_type` when
    /// nothing matches.
    fn negotiate_content_type(&self, accept_header: &str) -> String {
        let accepted = self.parse_accept_header(accept_header);
        let refused: Vec<&str> = accepted
            .iter()
            .filter(|(_, quality)| *quality <= 0.0)
            .map(|(range, _)| range.as_str())
            .collect();

        for (range, quality) in &accepted {
            // Sorted by descending quality: every remaining range is unacceptable.
            if *quality <= 0.0 {
                break;
            }

            if range == "*/*" && !Self::is_refused(&refused, range, &self.default_type) {
                return self.default_type.clone();
            }

            if let Some(found) = self.best_supported_match(range, &refused) {
                return found;
            }
        }

        self.default_type.clone()
    }

    /// Returns the supported type matching `range` with the highest server weight,
    /// skipping types refused by a more specific `q=0` entry.
    fn best_supported_match(&self, range: &str, refused: &[&str]) -> Option<String> {
        self.supported_types
            .iter()
            .filter(|entry| {
                Self::range_matches(range, entry.0) && !Self::is_refused(refused, range, entry.0)
            })
            .max_by(|a, b| {
                a.1.partial_cmp(b.1)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    // Reverse name order so the alphabetically first type wins a tie.
                    .then_with(|| b.0.cmp(a.0))
            })
            .map(|(mime_type, _)| mime_type.clone())
    }

    /// Whether some `q=0` entry that is more specific than `range` covers `mime_type`.
    fn is_refused(refused: &[&str], range: &str, mime_type: &str) -> bool {
        refused.iter().any(|refusal| {
            Self::specificity(refusal) > Self::specificity(range)
                && Self::range_matches(refusal, mime_type)
        })
    }

    /// Ranks a media range: `*/*` is 0, `type/*` is 1, a concrete type is 2.
    fn specificity(range: &str) -> u8 {
        if range == "*/*" {
            0
        } else if range.ends_with("/*") {
            1
        } else {
            2
        }
    }

    /// Whether media `range` (possibly a wildcard) covers `mime_type`, ignoring ASCII case.
    ///
    /// `type/*` compares the whole major type, so `text/*` does not match `textual/x`.
    fn range_matches(range: &str, mime_type: &str) -> bool {
        if range == "*/*" {
            return true;
        }
        match range.strip_suffix("/*") {
            Some(major) => mime_type
                .split_once('/')
                .map_or(false, |(mime_major, _)| mime_major.eq_ignore_ascii_case(major)),
            None => range.eq_ignore_ascii_case(mime_type),
        }
    }
}

impl Default for ContentNegotiationMiddleware {
    /// Same as [`ContentNegotiationMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
96
97#[async_trait]
98impl Middleware for ContentNegotiationMiddleware {
99    async fn call(&self, mut request: Request, next: Next) -> crate::Result<Response> {
100        if let Some(accept_header) = request.headers.get("accept") {
102            let best_type = self.negotiate_content_type(accept_header);
103            request
104                .extensions
105                .insert("negotiated_content_type".to_string(), best_type);
106        } else {
107            request.extensions.insert(
108                "negotiated_content_type".to_string(),
109                self.default_type.clone(),
110            );
111        }
112
113        let mut response = next.run(request).await?;
114
115        response
117            .headers
118            .insert("vary".to_string(), "accept, accept-encoding".to_string());
119
120        Ok(response)
121    }
122}
123
/// Decides which responses are worth marking for gzip `content-encoding`.
#[derive(Debug)]
pub struct CompressionMiddleware {
    /// Master switch; when `false`, nothing is considered compressible.
    enabled: bool,
    /// Smallest body size, in bytes, that is considered worth compressing.
    min_size: usize,
}

impl CompressionMiddleware {
    /// Media-type prefixes excluded from compression (formats that are typically
    /// already compressed internally).
    const EXCLUDED_TYPE_PREFIXES: &'static [&'static str] = &[
        "image/",
        "video/",
        "audio/",
        "application/zip",
        "application/gzip",
        "application/x-rar",
        "application/pdf",
    ];

    /// Creates an enabled middleware with a 1024-byte minimum size.
    pub fn new() -> Self {
        Self {
            enabled: true,
            min_size: 1024,
        }
    }

    /// Sets the minimum body size (in bytes) eligible for compression.
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Enables or disables compression entirely.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Whether a body of `content_type` and `content_length` bytes should be compressed.
    ///
    /// Requires the middleware to be enabled and `content_length >= min_size`, and the
    /// media type must not start with an excluded prefix. Media types are
    /// case-insensitive, so the comparison lower-cases first: `Image/PNG` is excluded
    /// just like `image/png`.
    fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
        if !self.enabled || content_length < self.min_size {
            return false;
        }

        let content_type = content_type.trim_start().to_ascii_lowercase();
        !Self::EXCLUDED_TYPE_PREFIXES
            .iter()
            .any(|&prefix| content_type.starts_with(prefix))
    }
}

impl Default for CompressionMiddleware {
    /// Same as [`CompressionMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
180
181#[async_trait]
182impl Middleware for CompressionMiddleware {
183    async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
184        let mut response = next.run(request).await?;
185
186        if self.enabled {
187            let content_type = response
189                .headers
190                .get("content-type")
191                .cloned()
192                .unwrap_or_else(|| "text/plain".to_string());
193
194            let content_length = 2048; if self.should_compress(&content_type, content_length) {
198                response
200                    .headers
201                    .insert("content-encoding".to_string(), "gzip".to_string());
202                response
203                    .headers
204                    .insert("vary".to_string(), "accept-encoding".to_string());
205            }
206        }
207
208        Ok(response)
209    }
210}
211
/// Middleware for HTTP byte ranges: advertises `accept-ranges` and can parse
/// single-range `Range` headers.
#[derive(Debug, Default)]
pub struct RangeMiddleware;

impl RangeMiddleware {
    /// Creates the middleware; equivalent to `RangeMiddleware::default()`.
    pub fn new() -> Self {
        Self
    }

    /// Parses a single-range `Range` header (`bytes=0-499`, `bytes=500-`, `bytes=-500`)
    /// against a representation of `content_length` bytes.
    ///
    /// Returns inclusive `(start, end)` byte offsets, with `end` clamped to the last byte.
    /// Returns `None` in any of these cases:
    /// - the header is malformed (including signs or spaces in positions);
    /// - the unit is not `bytes` (compared case-insensitively);
    /// - the header lists several ranges;
    /// - the range is unsatisfiable, which includes every range over an empty
    ///   representation and `bytes=-0`.
    // Only exercised by tests until the middleware actually serves partial content.
    #[allow(dead_code)]
    fn parse_range(&self, range_header: &str, content_length: usize) -> Option<(usize, usize)> {
        // Accepts a non-empty run of ASCII digits only; `str::parse::<usize>` alone would
        // also accept a leading `+`, which the range grammar does not allow.
        fn parse_position(digits: &str) -> Option<usize> {
            if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
                return None;
            }
            digits.parse().ok()
        }

        let (unit, range_spec) = range_header.split_once('=')?;
        // Range unit names are case-insensitive.
        if !unit.eq_ignore_ascii_case("bytes") {
            return None;
        }
        let (first, last) = range_spec.split_once('-')?;
        // An empty representation has no satisfiable byte range.
        let final_byte = content_length.checked_sub(1)?;

        let (start, end) = if first.is_empty() {
            // Suffix form `-N`: the final N bytes (the whole body if N exceeds it).
            let suffix_length = parse_position(last)?;
            if suffix_length == 0 {
                return None;
            }
            (content_length.saturating_sub(suffix_length), final_byte)
        } else {
            let start = parse_position(first)?;
            let end = if last.is_empty() {
                final_byte
            } else {
                parse_position(last)?.min(final_byte)
            };
            (start, end)
        };

        if start <= end && start < content_length {
            Some((start, end))
        } else {
            None
        }
    }
}
264
265#[async_trait]
266impl Middleware for RangeMiddleware {
267    async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
268        let mut response = next.run(request).await?;
269
270        response
272            .headers
273            .insert("accept-ranges".to_string(), "bytes".to_string());
274
275        Ok(response)
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// The parser returns ranges ordered by descending quality.
    #[test]
    fn test_parse_accept_header() {
        let negotiator = ContentNegotiationMiddleware::new();
        let parsed = negotiator.parse_accept_header("text/html,application/xml;q=0.9,*/*;q=0.8");

        let expected: Vec<(String, f32)> = vec![
            ("text/html".to_string(), 1.0),
            ("application/xml".to_string(), 0.9),
            ("*/*".to_string(), 0.8),
        ];
        assert_eq!(parsed, expected);
    }

    /// The highest-quality supported type wins; `*/*` falls back to the default.
    #[test]
    fn test_content_type_negotiation() {
        let negotiator = ContentNegotiationMiddleware::new();
        let cases = [
            ("application/json,text/html;q=0.9", "application/json"),
            ("text/html,application/json;q=0.9", "text/html"),
            ("*/*", "application/json"),
        ];

        for &(accept, expected) in cases.iter() {
            assert_eq!(negotiator.negotiate_content_type(accept), expected);
        }
    }

    /// Small bodies and image types are not compressed.
    #[test]
    fn test_compression_should_compress() {
        let compressor = CompressionMiddleware::new();
        let cases = [
            ("text/html", 2048, true),
            ("text/html", 512, false),
            ("image/jpeg", 2048, false),
        ];

        for &(content_type, length, expected) in cases.iter() {
            assert_eq!(compressor.should_compress(content_type, length), expected);
        }
    }

    /// Explicit, open-ended and suffix ranges parse; garbage does not.
    #[test]
    fn test_range_parsing() {
        let ranges = RangeMiddleware::new();
        let cases = [
            ("bytes=0-499", Some((0, 499))),
            ("bytes=500-", Some((500, 999))),
            ("bytes=-500", Some((500, 999))),
            ("bytes=invalid", None),
        ];

        for &(header, expected) in cases.iter() {
            assert_eq!(ranges.parse_range(header, 1000), expected);
        }
    }
}