1use crate::core::{Middleware, Next};
4use crate::types::{Request, Response};
5use async_trait::async_trait;
6use std::collections::HashMap;
7
/// Chooses a response media type from the request's `Accept` header.
///
/// The server registers the media types it can produce, each with a
/// preference weight. Negotiation walks the client's media ranges in
/// descending client preference and returns the first one the server can
/// satisfy, falling back to a configurable default type.
#[derive(Debug)]
pub struct ContentNegotiationMiddleware {
    /// Producible media types (stored lowercase) mapped to the server's
    /// preference; higher values win when a `type/*` range matches several.
    supported_types: HashMap<String, f32>,
    /// Media type chosen for `*/*` and when no acceptable supported type exists.
    default_type: String,
}

impl ContentNegotiationMiddleware {
    /// Creates a negotiator supporting JSON (most preferred), HTML, plain text
    /// and XML, with `application/json` as the default type.
    pub fn new() -> Self {
        let mut supported_types = HashMap::new();
        supported_types.insert("application/json".to_string(), 1.0);
        supported_types.insert("text/html".to_string(), 0.9);
        supported_types.insert("text/plain".to_string(), 0.8);
        supported_types.insert("application/xml".to_string(), 0.7);

        Self {
            supported_types,
            default_type: "application/json".to_string(),
        }
    }

    /// Adds (or re-weights) a producible media type.
    ///
    /// `quality` is the server-side preference used to pick between several
    /// supported types matched by one `type/*` range. The media type is
    /// normalised to lowercase because media types compare case-insensitively
    /// and the `Accept` parser lowercases client values.
    pub fn support_type(mut self, mime_type: String, quality: f32) -> Self {
        self.supported_types
            .insert(mime_type.trim().to_ascii_lowercase(), quality);
        self
    }

    /// Sets the media type returned for `*/*` and when negotiation finds no match.
    pub fn default_type(mut self, mime_type: String) -> Self {
        self.default_type = mime_type;
        self
    }

    /// Parses an `Accept` header into `(media_range, q)` pairs.
    ///
    /// Media ranges are trimmed, lowercased and stripped of parameters
    /// (`text/html; level=1; q=0.5` -> `("text/html", 0.5)`). Whitespace around
    /// `;` and `=` is tolerated. A missing, unparseable or non-finite `q` counts
    /// as `1.0`; finite values are clamped to `0.0..=1.0`. Empty list elements
    /// are skipped. The result is sorted by descending `q`; the sort is stable,
    /// so equally weighted ranges keep their header order.
    fn parse_accept_header(&self, accept_header: &str) -> Vec<(String, f32)> {
        let mut types: Vec<(String, f32)> = accept_header
            .split(',')
            .filter_map(|entry| {
                let mut pieces = entry.split(';');
                // `split` always yields at least one piece.
                let media_range = pieces.next()?.trim().to_ascii_lowercase();
                if media_range.is_empty() {
                    return None;
                }
                let quality = pieces
                    .filter_map(|param| param.split_once('='))
                    .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
                    .and_then(|(_, value)| value.trim().parse::<f32>().ok())
                    // "nan"/"inf" parse successfully but are not valid weights.
                    .filter(|q| q.is_finite())
                    .map_or(1.0, |q| q.clamp(0.0, 1.0));
                Some((media_range, quality))
            })
            .collect();

        types.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        types
    }

    /// Picks the media type to produce for the given `Accept` header value.
    ///
    /// Ranges are tried in descending client preference and those with
    /// `q = 0` ("not acceptable") are ignored. For each remaining range:
    /// `*/*` (or a bare `*`) yields the default type; an exact supported type
    /// is returned as-is; `type/*` yields the supported type under that
    /// top-level type with the highest server preference (ties broken
    /// alphabetically, so the result does not depend on `HashMap` order).
    /// Returns the default type when no range can be satisfied.
    fn negotiate_content_type(&self, accept_header: &str) -> String {
        for (accepted_type, quality) in self.parse_accept_header(accept_header) {
            if quality <= 0.0 {
                continue;
            }

            if accepted_type == "*/*" || accepted_type == "*" {
                return self.default_type.clone();
            }

            if self.supported_types.contains_key(&accepted_type) {
                return accepted_type;
            }

            if let Some(major) = accepted_type.strip_suffix("/*") {
                // Keep the '/' so "text/*" cannot match e.g. "textual/foo".
                let prefix = format!("{}/", major);
                let best = self
                    .supported_types
                    .iter()
                    .filter(|entry| entry.0.starts_with(prefix.as_str()))
                    .max_by(|a, b| {
                        a.1.partial_cmp(b.1)
                            .unwrap_or(std::cmp::Ordering::Equal)
                            // Reverse name order so the alphabetically smaller
                            // type wins a tie in server preference.
                            .then_with(|| b.0.cmp(a.0))
                    });
                if let Some((supported, _)) = best {
                    return supported.clone();
                }
            }
        }

        self.default_type.clone()
    }
}

impl Default for ContentNegotiationMiddleware {
    /// Same as [`ContentNegotiationMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
96
97#[async_trait]
98impl Middleware for ContentNegotiationMiddleware {
99 async fn call(&self, mut request: Request, next: Next) -> crate::Result<Response> {
100 if let Some(accept_header) = request.headers.get("accept") {
102 let best_type = self.negotiate_content_type(accept_header);
103 request
104 .extensions
105 .insert("negotiated_content_type".to_string(), best_type);
106 } else {
107 request.extensions.insert(
108 "negotiated_content_type".to_string(),
109 self.default_type.clone(),
110 );
111 }
112
113 let mut response = next.run(request).await?;
114
115 response
117 .headers
118 .insert("vary".to_string(), "accept, accept-encoding".to_string());
119
120 Ok(response)
121 }
122}
123
/// Decides whether responses are eligible for gzip compression.
#[derive(Debug)]
pub struct CompressionMiddleware {
    /// Master switch; when `false`, `should_compress` always returns `false`.
    enabled: bool,
    /// Bodies shorter than this many bytes are left uncompressed.
    min_size: usize,
}

impl CompressionMiddleware {
    /// Media-type prefixes whose payloads are already compressed. Gzipping
    /// them spends CPU without meaningfully shrinking the body.
    const ALREADY_COMPRESSED: &'static [&'static str] = &[
        "image/",
        "video/",
        "audio/",
        "application/zip",
        "application/gzip",
        "application/x-gzip",
        "application/x-rar",
        "application/x-7z-compressed",
        "application/x-bzip2",
        "application/zstd",
        "application/pdf",
        "font/woff",
    ];

    /// Text-based types that sit under an otherwise excluded prefix and do
    /// compress well.
    const COMPRESSIBLE_EXCEPTIONS: &'static [&'static str] = &["image/svg+xml"];

    /// Creates an enabled middleware with a 1024-byte minimum body size.
    pub fn new() -> Self {
        Self {
            enabled: true,
            min_size: 1024,
        }
    }

    /// Sets the minimum body size (in bytes) eligible for compression.
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Turns compression on or off.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns `true` when a body of `content_length` bytes with the given
    /// `content_type` should be compressed.
    ///
    /// Compression is refused when the middleware is disabled, when the body
    /// is smaller than `min_size`, or when the type is already compressed.
    /// The exceptions in `COMPRESSIBLE_EXCEPTIONS` are always allowed.
    fn should_compress(&self, content_type: &str, content_length: usize) -> bool {
        if !self.enabled || content_length < self.min_size {
            return false;
        }

        // Media types are case-insensitive (RFC 9110 §8.3.1). Parameters such as
        // "; charset=utf-8" do not matter because matching is by prefix.
        let content_type = content_type.trim_start().to_ascii_lowercase();

        if Self::COMPRESSIBLE_EXCEPTIONS
            .iter()
            .any(|&allowed| content_type.starts_with(allowed))
        {
            return true;
        }

        !Self::ALREADY_COMPRESSED
            .iter()
            .any(|&excluded| content_type.starts_with(excluded))
    }
}

impl Default for CompressionMiddleware {
    /// Same as [`CompressionMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
180
181#[async_trait]
182impl Middleware for CompressionMiddleware {
183 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
184 let mut response = next.run(request).await?;
185
186 if self.enabled {
187 let content_type = response
189 .headers
190 .get("content-type")
191 .cloned()
192 .unwrap_or_else(|| "text/plain".to_string());
193
194 let content_length = 2048; if self.should_compress(&content_type, content_length) {
198 response
200 .headers
201 .insert("content-encoding".to_string(), "gzip".to_string());
202 response
203 .headers
204 .insert("vary".to_string(), "accept-encoding".to_string());
205 }
206 }
207
208 Ok(response)
209 }
210}
211
/// Middleware that advertises byte-range support and can parse `Range` headers.
#[derive(Debug, Default)]
pub struct RangeMiddleware;

impl RangeMiddleware {
    /// Creates the middleware (equivalent to `RangeMiddleware::default()`).
    pub fn new() -> Self {
        Self
    }

    /// Parses a single-range `Range` header against a representation of
    /// `content_length` bytes.
    ///
    /// Supports `bytes=first-last`, `bytes=first-` (open-ended) and
    /// `bytes=-suffix` (final `suffix` bytes). The unit name is matched
    /// case-insensitively, and whitespace around the range spec is tolerated.
    /// The returned `(start, end)` offsets are inclusive, and `end` is clamped
    /// to the last byte.
    ///
    /// Returns `None` in these cases:
    /// - the unit is not `bytes`;
    /// - a position is not plain ASCII digits (signs are rejected);
    /// - several ranges are given;
    /// - the range is unsatisfiable (`start > end` or `start >= content_length`).
    // Not called by this middleware's `call` yet; exercised by the unit tests.
    #[allow(dead_code)]
    fn parse_range(&self, range_header: &str, content_length: usize) -> Option<(usize, usize)> {
        /// Parses a non-empty run of ASCII digits. `usize::from_str` alone would
        /// also accept a leading '+', which the byte-range grammar forbids.
        fn digits(text: &str) -> Option<usize> {
            if text.is_empty() || !text.bytes().all(|b| b.is_ascii_digit()) {
                return None;
            }
            text.parse().ok()
        }

        let (unit, spec) = range_header.trim().split_once('=')?;
        if !unit.trim().eq_ignore_ascii_case("bytes") {
            return None;
        }

        let (first, last) = spec.trim().split_once('-')?;
        let (first, last) = (first.trim(), last.trim());
        let last_byte = content_length.saturating_sub(1);

        let (start, end) = if first.is_empty() {
            // Suffix form: the final `suffix` bytes (the whole body if larger).
            let suffix = digits(last)?;
            (content_length.saturating_sub(suffix), last_byte)
        } else {
            let start = digits(first)?;
            let end = if last.is_empty() {
                last_byte
            } else {
                digits(last)?.min(last_byte)
            };
            (start, end)
        };

        (start <= end && start < content_length).then(|| (start, end))
    }
}
264
265#[async_trait]
266impl Middleware for RangeMiddleware {
267 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
268 let mut response = next.run(request).await?;
269
270 response
272 .headers
273 .insert("accept-ranges".to_string(), "bytes".to_string());
274
275 Ok(response)
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_accept_header() {
        let middleware = ContentNegotiationMiddleware::new();
        let types = middleware.parse_accept_header("text/html,application/xml;q=0.9,*/*;q=0.8");

        assert_eq!(types.len(), 3);
        assert_eq!(types[0], ("text/html".to_string(), 1.0));
        assert_eq!(types[1], ("application/xml".to_string(), 0.9));
        assert_eq!(types[2], ("*/*".to_string(), 0.8));
    }

    #[test]
    fn test_content_type_negotiation() {
        let middleware = ContentNegotiationMiddleware::new();

        let result = middleware.negotiate_content_type("application/json,text/html;q=0.9");
        assert_eq!(result, "application/json");

        let result = middleware.negotiate_content_type("text/html,application/json;q=0.9");
        assert_eq!(result, "text/html");

        let result = middleware.negotiate_content_type("*/*");
        assert_eq!(result, "application/json");
    }

    #[test]
    fn test_negotiation_fallback_and_case() {
        let middleware = ContentNegotiationMiddleware::new();

        // Supported types other than the default are honoured.
        assert_eq!(middleware.negotiate_content_type("text/plain"), "text/plain");
        // Client media ranges are compared case-insensitively.
        assert_eq!(middleware.negotiate_content_type("APPLICATION/XML"), "application/xml");
        // Unsupported types fall back to the default.
        assert_eq!(middleware.negotiate_content_type("image/png"), "application/json");

        // A custom default is used for */*.
        let custom = ContentNegotiationMiddleware::new().default_type("text/plain".to_string());
        assert_eq!(custom.negotiate_content_type("*/*"), "text/plain");
    }

    #[test]
    fn test_compression_should_compress() {
        let middleware = CompressionMiddleware::new();

        assert!(middleware.should_compress("text/html", 2048));
        assert!(!middleware.should_compress("text/html", 512));
        assert!(!middleware.should_compress("image/jpeg", 2048));
    }

    #[test]
    fn test_compression_configuration() {
        assert!(!CompressionMiddleware::new()
            .enabled(false)
            .should_compress("text/html", 4096));
        assert!(CompressionMiddleware::new()
            .min_size(0)
            .should_compress("text/plain", 0));
        assert!(!CompressionMiddleware::new()
            .min_size(4096)
            .should_compress("text/plain", 2048));
    }

    #[test]
    fn test_range_parsing() {
        let middleware = RangeMiddleware::new();

        assert_eq!(middleware.parse_range("bytes=0-499", 1000), Some((0, 499)));
        assert_eq!(middleware.parse_range("bytes=500-", 1000), Some((500, 999)));
        assert_eq!(middleware.parse_range("bytes=-500", 1000), Some((500, 999)));
        assert_eq!(middleware.parse_range("bytes=invalid", 1000), None);
    }

    #[test]
    fn test_range_edge_cases() {
        let middleware = RangeMiddleware::new();

        // End offsets beyond the representation are clamped.
        assert_eq!(middleware.parse_range("bytes=0-5000", 1000), Some((0, 999)));
        // An oversized suffix selects the whole representation.
        assert_eq!(middleware.parse_range("bytes=-5000", 1000), Some((0, 999)));
        // Unsatisfiable or malformed ranges.
        assert_eq!(middleware.parse_range("bytes=500-400", 1000), None);
        assert_eq!(middleware.parse_range("bytes=1000-", 1000), None);
        assert_eq!(middleware.parse_range("bytes=-0", 1000), None);
        assert_eq!(middleware.parse_range("bytes=0-0", 0), None);
        assert_eq!(middleware.parse_range("items=0-5", 1000), None);
    }
}