spikard_http/middleware/
mod.rs

1//! HTTP middleware for request validation
2//!
3//! Provides middleware stack setup, JSON schema validation, multipart/form-data parsing,
4//! and URL-encoded form data handling.
5
6pub mod multipart;
7pub mod urlencoded;
8pub mod validation;
9
10use axum::{
11    body::Body,
12    extract::{FromRequest, Multipart, Request},
13    http::StatusCode,
14    middleware::Next,
15    response::{IntoResponse, Response},
16};
17use serde_json::json;
18use std::collections::HashMap;
19use std::sync::Arc;
20
/// Per-route metadata that the content-type middleware consults when
/// deciding how strictly to validate a request body.
#[derive(Clone, Debug)]
pub struct RouteInfo {
    /// Set to `true` when the route's handler consumes a JSON request body.
    pub expects_json_body: bool,
}

/// Shared, read-only table of [`RouteInfo`] entries.
///
/// Keys are `(HTTP method, path)` string pairs, e.g. `("POST", "/api/users")`.
pub type RouteRegistry = Arc<HashMap<(String, String), RouteInfo>>;
30
31/// Middleware to validate Content-Type headers and related requirements
32///
33/// This middleware performs comprehensive request body validation and transformation:
34///
35/// - **Content-Type Validation:** Ensures the request's Content-Type header matches the
36///   expected format for the route (if configured).
37///
38/// - **Multipart Form Data:** Automatically parses `multipart/form-data` requests and
39///   transforms them into JSON format for uniform downstream processing.
40///
41/// - **URL-Encoded Forms:** Parses `application/x-www-form-urlencoded` requests and
42///   converts them to JSON.
43///
44/// - **JSON Validation:** Validates JSON request bodies for well-formedness (when the
45///   Content-Type is `application/json`).
46///
47/// - **Content-Length:** Validates that the Content-Length header is present and
48///   reasonable for POST, PUT, and PATCH requests.
49///
50/// # Behavior
51///
52/// For request methods POST, PUT, and PATCH:
53/// 1. Checks if the route expects a JSON body (via `RouteRegistry`)
54/// 2. Validates Content-Type headers based on route configuration
55/// 3. Parses the request body according to Content-Type:
56///    - `multipart/form-data` → JSON (form fields as object properties)
57///    - `application/x-www-form-urlencoded` → JSON (URL parameters as object)
58///    - `application/json` → Validates JSON syntax
59/// 4. Transforms the request to have `Content-Type: application/json`
60/// 5. Passes the transformed request to the next middleware
61///
62/// For GET, DELETE, and other methods: passes through with minimal validation.
63///
64/// # Errors
65///
66/// Returns HTTP error responses for:
67/// - `400 Bad Request` - Failed to read request body, invalid JSON, malformed forms, invalid Content-Length
68/// - `500 Internal Server Error` - Failed to serialize transformed body
69///
70/// # Examples
71///
72/// ```rust
73/// use axum::{middleware::Next, extract::Request};
74/// use spikard_http::middleware::validate_content_type_middleware;
75///
76/// // This is typically used as middleware in an Axum router:
77/// // router.layer(axum::middleware::from_fn(validate_content_type_middleware))
78/// ```
79///
80/// Coverage: Tested via integration tests (multipart and form parsing tested end-to-end)
81#[cfg(not(tarpaulin_include))]
82pub async fn validate_content_type_middleware(request: Request, next: Next) -> Result<Response, Response> {
83    use axum::body::to_bytes;
84    use axum::http::Request as HttpRequest;
85
86    let (parts, body) = request.into_parts();
87    let headers = &parts.headers;
88
89    let route_info = parts.extensions.get::<RouteRegistry>().and_then(|registry| {
90        let method = parts.method.as_str();
91        let path = parts.uri.path();
92        registry.get(&(method.to_string(), path.to_string())).cloned()
93    });
94
95    let method = &parts.method;
96    if method == axum::http::Method::POST || method == axum::http::Method::PUT || method == axum::http::Method::PATCH {
97        if let Some(info) = &route_info
98            && info.expects_json_body
99        {
100            validation::validate_json_content_type(headers)?;
101        }
102
103        validation::validate_content_type_headers(headers, 0)?;
104
105        let (final_parts, final_body) = if let Some(content_type) = headers.get(axum::http::header::CONTENT_TYPE) {
106            if let Ok(content_type_str) = content_type.to_str() {
107                let parsed_mime = content_type_str.parse::<mime::Mime>().ok();
108
109                let is_multipart = parsed_mime
110                    .as_ref()
111                    .map(|mime| mime.type_() == mime::MULTIPART && mime.subtype() == "form-data")
112                    .unwrap_or(false);
113
114                let is_form_urlencoded = parsed_mime
115                    .as_ref()
116                    .map(|mime| mime.type_() == mime::APPLICATION && mime.subtype() == "x-www-form-urlencoded")
117                    .unwrap_or(false);
118
119                if is_multipart {
120                    let mut response_headers = parts.headers.clone();
121
122                    let request = HttpRequest::from_parts(parts, body);
123                    let multipart = match Multipart::from_request(request, &()).await {
124                        Ok(mp) => mp,
125                        Err(e) => {
126                            let error_body = json!({
127                                "error": format!("Failed to parse multipart data: {}", e)
128                            });
129                            return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
130                        }
131                    };
132
133                    let json_body = match multipart::parse_multipart_to_json(multipart).await {
134                        Ok(json) => json,
135                        Err(e) => {
136                            let error_body = json!({
137                                "error": format!("Failed to process multipart data: {}", e)
138                            });
139                            return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
140                        }
141                    };
142
143                    let json_bytes = match serde_json::to_vec(&json_body) {
144                        Ok(bytes) => bytes,
145                        Err(e) => {
146                            let error_body = json!({
147                                "error": format!("Failed to serialize multipart data to JSON: {}", e)
148                            });
149                            return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
150                        }
151                    };
152
153                    response_headers.insert(
154                        axum::http::header::CONTENT_TYPE,
155                        axum::http::HeaderValue::from_static("application/json"),
156                    );
157
158                    let mut new_request = axum::http::Request::new(Body::from(json_bytes));
159                    *new_request.headers_mut() = response_headers;
160
161                    return Ok(next.run(new_request).await);
162                } else if is_form_urlencoded {
163                    let body_bytes = match to_bytes(body, usize::MAX).await {
164                        Ok(bytes) => bytes,
165                        Err(_) => {
166                            let error_body = json!({
167                                "error": "Failed to read request body"
168                            });
169                            return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
170                        }
171                    };
172
173                    validation::validate_content_length(headers, body_bytes.len())?;
174
175                    let json_body = if body_bytes.is_empty() {
176                        serde_json::json!({})
177                    } else {
178                        match urlencoded::parse_urlencoded_to_json(&body_bytes) {
179                            Ok(json_body) => json_body,
180                            Err(e) => {
181                                let error_body = json!({
182                                    "error": format!("Failed to parse URL-encoded form data: {}", e)
183                                });
184                                return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
185                            }
186                        }
187                    };
188
189                    let json_bytes = match serde_json::to_vec(&json_body) {
190                        Ok(bytes) => bytes,
191                        Err(e) => {
192                            let error_body = json!({
193                                "error": format!("Failed to serialize URL-encoded form data to JSON: {}", e)
194                            });
195                            return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
196                        }
197                    };
198
199                    let mut new_parts = parts;
200                    new_parts.headers.insert(
201                        axum::http::header::CONTENT_TYPE,
202                        axum::http::HeaderValue::from_static("application/json"),
203                    );
204
205                    (new_parts, Body::from(json_bytes))
206                } else {
207                    let body_bytes = match to_bytes(body, usize::MAX).await {
208                        Ok(bytes) => bytes,
209                        Err(_) => {
210                            let error_body = json!({
211                                "error": "Failed to read request body"
212                            });
213                            return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
214                        }
215                    };
216
217                    validation::validate_content_length(headers, body_bytes.len())?;
218
219                    let is_json = parsed_mime
220                        .as_ref()
221                        .map(validation::is_json_content_type)
222                        .unwrap_or(false);
223
224                    if is_json
225                        && !body_bytes.is_empty()
226                        && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
227                    {
228                        let error_body = json!({
229                            "detail": "Invalid request format"
230                        });
231                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
232                    }
233
234                    (parts, Body::from(body_bytes))
235                }
236            } else {
237                let body_bytes = match to_bytes(body, usize::MAX).await {
238                    Ok(bytes) => bytes,
239                    Err(_) => {
240                        let error_body = json!({
241                            "error": "Failed to read request body"
242                        });
243                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
244                    }
245                };
246
247                validation::validate_content_length(headers, body_bytes.len())?;
248
249                (parts, Body::from(body_bytes))
250            }
251        } else {
252            let body_bytes = match to_bytes(body, usize::MAX).await {
253                Ok(bytes) => bytes,
254                Err(_) => {
255                    let error_body = json!({
256                        "error": "Failed to read request body"
257                    });
258                    return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
259                }
260            };
261
262            validation::validate_content_length(headers, body_bytes.len())?;
263
264            (parts, Body::from(body_bytes))
265        };
266
267        let request = HttpRequest::from_parts(final_parts, final_body);
268        Ok(next.run(request).await)
269    } else {
270        validation::validate_content_type_headers(headers, 0)?;
271
272        let request = HttpRequest::from_parts(parts, body);
273        Ok(next.run(request).await)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    #[test]
    fn test_route_info_creation() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_true() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_false() {
        let info = RouteInfo {
            expects_json_body: false,
        };
        assert!(!info.expects_json_body);
    }

    #[test]
    fn test_route_registry_empty() {
        let registry: RouteRegistry = Arc::new(std::collections::HashMap::new());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_route_registry_single_entry() {
        let mut map = std::collections::HashMap::new();
        map.insert(
            ("POST".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        let registry: RouteRegistry = Arc::new(map);

        let key = ("POST".to_string(), "/api/users".to_string());
        assert!(registry.contains_key(&key));
        assert!(registry[&key].expects_json_body);
    }

    #[test]
    fn test_route_registry_multiple_entries() {
        let mut map = std::collections::HashMap::new();
        map.insert(
            ("POST".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        map.insert(
            ("GET".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: false,
            },
        );
        map.insert(
            ("PUT".to_string(), "/api/users/{id}".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        let registry: RouteRegistry = Arc::new(map);

        assert_eq!(registry.len(), 3);
    }

    #[test]
    fn test_route_registry_lookup_missing_route() {
        let map = std::collections::HashMap::new();
        let registry: RouteRegistry = Arc::new(map);

        let key = ("POST".to_string(), "/api/users".to_string());
        assert!(!registry.contains_key(&key));
    }

    /// A header map with no Content-Length entry reports it as absent.
    #[test]
    fn test_request_with_zero_content_length() {
        let headers = axum::http::HeaderMap::new();
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_none());
    }

    #[test]
    fn test_request_with_very_large_content_length() {
        let mut headers = axum::http::HeaderMap::new();
        let large_size = usize::MAX - 1;
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_str(&large_size.to_string()).unwrap(),
        );
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
    }

    #[test]
    fn test_request_body_smaller_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("1000"),
        );
        let result = super::validation::validate_content_length(&headers, 500);
        assert!(
            result.is_err(),
            "Should reject when actual body is smaller than declared"
        );
    }

    #[test]
    fn test_request_body_larger_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("500"),
        );
        let result = super::validation::validate_content_length(&headers, 1000);
        assert!(
            result.is_err(),
            "Should reject when actual body is larger than declared"
        );
    }

    #[test]
    fn test_get_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::GET)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::GET);
    }

    #[test]
    fn test_delete_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::DELETE)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::DELETE);
    }

    #[test]
    fn test_post_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::POST);
    }

    #[test]
    fn test_put_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PUT)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PUT);
    }

    #[test]
    fn test_patch_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PATCH)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PATCH);
    }

    /// Content-Type lookups succeed regardless of the case used in the name.
    #[test]
    fn test_content_type_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("application/json"),
        );

        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
        assert!(headers.get("Content-Type").is_some());
        assert!(headers.get("CONTENT-TYPE").is_some());
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }

    /// Content-Length lookups succeed regardless of the case used in the name.
    #[test]
    fn test_content_length_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("100"),
        );

        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
        assert!(headers.get("Content-Length").is_some());
        assert!(headers.get("CONTENT-LENGTH").is_some());
        assert_eq!(headers.get("content-length").unwrap(), "100");
    }

    /// `http` normalizes header names to lowercase on parse. The original casing
    /// is not preserved, but lookups in any casing still succeed.
    #[test]
    fn test_custom_header_names_are_normalized_to_lowercase() {
        let mut headers = axum::http::HeaderMap::new();
        let custom_header: axum::http::HeaderName = "X-Custom-Header".parse().unwrap();
        assert_eq!(custom_header.as_str(), "x-custom-header");

        headers.insert(custom_header.clone(), axum::http::HeaderValue::from_static("value"));

        assert!(headers.get(&custom_header).is_some());
        assert!(headers.get("X-Custom-Header").is_some());
        assert!(headers.get("x-custom-header").is_some());
    }

    #[test]
    fn test_multipart_boundary_minimal() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=x"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Minimal boundary should be accepted");
    }

    #[test]
    fn test_multipart_boundary_with_numbers() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=boundary123456"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_multipart_boundary_with_special_chars() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Boundary with dashes should be accepted");
    }

    /// Smoke test: validating an empty multipart boundary must not panic.
    // NOTE(review): the accept/reject outcome is intentionally not asserted here;
    // pin it once the expected behavior of `validate_content_type_headers` for an
    // empty boundary is decided.
    #[test]
    fn test_multipart_empty_boundary() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary="),
        );

        let _result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_invalid_json_body_detection() {
        let invalid_json = r#"{"invalid": json without quotes}"#;

        let result = serde_json::from_str::<serde_json::Value>(invalid_json);
        assert!(result.is_err(), "Invalid JSON should fail parsing");
    }

    #[test]
    fn test_valid_json_parsing() {
        let valid_json = r#"{"key": "value"}"#;
        let result = serde_json::from_str::<serde_json::Value>(valid_json);
        assert!(result.is_ok(), "Valid JSON should parse successfully");
    }

    #[test]
    fn test_empty_json_object() {
        let empty_json = "{}";
        let result = serde_json::from_str::<serde_json::Value>(empty_json);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(value.is_object());
        assert_eq!(value.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_form_data_mime_type() {
        let mime = "multipart/form-data; boundary=xyz".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::MULTIPART);
        assert_eq!(mime.subtype(), "form-data");
    }

    #[test]
    fn test_form_urlencoded_mime_type() {
        let mime = "application/x-www-form-urlencoded".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), "x-www-form-urlencoded");
    }

    #[test]
    fn test_json_mime_type() {
        let mime = "application/json".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), mime::JSON);
    }

    #[test]
    fn test_text_plain_mime_type() {
        let mime = "text/plain".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::TEXT);
        assert_eq!(mime.subtype(), "plain");
    }
}