Skip to main content

spikard_http/middleware/
mod.rs

1//! HTTP middleware for request validation
2//!
3//! Provides middleware stack setup, JSON schema validation, multipart/form-data parsing,
4//! and URL-encoded form data handling.
5
6pub mod multipart;
7pub mod urlencoded;
8pub mod validation;
9
10use axum::{
11    body::Body,
12    extract::State,
13    extract::{FromRequest, Multipart, Request},
14    http::StatusCode,
15    middleware::Next,
16    response::{IntoResponse, Response},
17};
18use serde_json::json;
19use std::cell::RefCell;
20use std::num::NonZeroUsize;
21
22thread_local! {
23    static URLENCODED_JSON_CACHE: RefCell<lru::LruCache<bytes::Bytes, bytes::Bytes>> =
24        RefCell::new(lru::LruCache::new(NonZeroUsize::new(256).expect("non-zero cache size")));
25}
26
/// Per-route settings consulted by the content-type validation middleware.
#[derive(Clone, Debug)]
pub struct RouteInfo {
    /// `true` when the route takes a JSON request body. This turns on the JSON
    /// Content-Type check and JSON syntax validation in the middleware.
    pub expects_json_body: bool,
}
33
34/// Request extension carrying the already-collected request body.
35///
36/// This avoids double-reading the body stream: middleware can read once for
37/// content-length / syntax checks, and request extraction can reuse the bytes.
38#[derive(Debug, Clone)]
39pub struct PreReadBody(pub bytes::Bytes);
40
41/// Request extension carrying a pre-parsed JSON body.
42#[derive(Debug, Clone)]
43pub struct PreParsedJson(pub serde_json::Value);
44
45/// Middleware to validate Content-Type headers and related requirements
46///
47/// This middleware performs comprehensive request body validation and transformation:
48///
49/// - **Content-Type Validation:** Ensures the request's Content-Type header matches the
50///   expected format for the route (if configured).
51///
52/// - **Multipart Form Data:** Automatically parses `multipart/form-data` requests and
53///   transforms them into JSON format for uniform downstream processing.
54///
55/// - **URL-Encoded Forms:** Parses `application/x-www-form-urlencoded` requests and
56///   converts them to JSON.
57///
58/// - **JSON Validation:** Validates JSON request bodies for well-formedness (when the
59///   Content-Type is `application/json`).
60///
61/// - **Content-Length:** Validates that the Content-Length header is present and
62///   reasonable for POST, PUT, and PATCH requests.
63///
64/// # Behavior
65///
66/// For request methods POST, PUT, and PATCH:
67/// 1. Checks if the route expects a JSON body (via route state)
68/// 2. Validates Content-Type headers based on route configuration
69/// 3. Parses the request body according to Content-Type:
70///    - `multipart/form-data` → JSON (form fields as object properties)
71///    - `application/x-www-form-urlencoded` → JSON (URL parameters as object)
72///    - `application/json` → Validates JSON syntax
73/// 4. Transforms the request to have `Content-Type: application/json`
74/// 5. Passes the transformed request to the next middleware
75///
76/// For GET, DELETE, and other methods: passes through with minimal validation.
77///
78/// # Errors
79///
80/// Returns HTTP error responses for:
81/// - `400 Bad Request` - Failed to read request body, invalid JSON, malformed forms, invalid Content-Length
82/// - `500 Internal Server Error` - Failed to serialize transformed body
83///
84/// # Examples
85///
86/// ```rust
87/// use axum::{middleware::Next, extract::Request};
88/// use spikard_http::middleware::validate_content_type_middleware;
89///
90/// // This is typically used as middleware in an Axum router:
91/// // router.layer(axum::middleware::from_fn(validate_content_type_middleware))
92/// ```
93///
94/// Coverage: Tested via integration tests (multipart and form parsing tested end-to-end)
95#[cfg(not(tarpaulin_include))]
96pub async fn validate_content_type_middleware(
97    State(route_info): State<RouteInfo>,
98    request: Request,
99    next: Next,
100) -> Result<Response, Response> {
101    use axum::body::to_bytes;
102    use axum::http::Request as HttpRequest;
103
104    let (mut parts, body) = request.into_parts();
105    let headers = &parts.headers;
106
107    let method = &parts.method;
108    if method == axum::http::Method::POST || method == axum::http::Method::PUT || method == axum::http::Method::PATCH {
109        if route_info.expects_json_body {
110            validation::validate_json_content_type(headers)?;
111        }
112
113        let content_kind = validation::validate_content_type_headers_and_classify(headers, 0)?;
114
115        let mut parsed_json: Option<serde_json::Value> = None;
116        let out_bytes: bytes::Bytes = match content_kind {
117            Some(validation::ContentTypeKind::Multipart) => {
118                let body_bytes = match to_bytes(body, usize::MAX).await {
119                    Ok(bytes) => bytes,
120                    Err(_) => {
121                        let error_body = json!({
122                            "error": "Failed to read request body"
123                        });
124                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
125                    }
126                };
127
128                if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
129                    validation::validate_content_length(headers, body_bytes.len())?;
130                }
131
132                let mut parse_request = HttpRequest::new(Body::from(body_bytes));
133                *parse_request.headers_mut() = parts.headers.clone();
134
135                let multipart = match Multipart::from_request(parse_request, &()).await {
136                    Ok(mp) => mp,
137                    Err(e) => {
138                        let error_body = json!({
139                            "error": format!("Failed to parse multipart data: {}", e)
140                        });
141                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
142                    }
143                };
144
145                let json_body = match multipart::parse_multipart_to_json(multipart).await {
146                    Ok(json) => json,
147                    Err(e) => {
148                        let error_body = json!({
149                            "error": format!("Failed to process multipart data: {}", e)
150                        });
151                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
152                    }
153                };
154
155                let json_bytes = match serde_json::to_vec(&json_body) {
156                    Ok(bytes) => bytes,
157                    Err(e) => {
158                        let error_body = json!({
159                            "error": format!("Failed to serialize multipart data to JSON: {}", e)
160                        });
161                        return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
162                    }
163                };
164
165                parsed_json = Some(json_body);
166                parts.headers.insert(
167                    axum::http::header::CONTENT_TYPE,
168                    axum::http::HeaderValue::from_static("application/json"),
169                );
170                if let Ok(value) = axum::http::HeaderValue::from_str(&json_bytes.len().to_string()) {
171                    parts.headers.insert(axum::http::header::CONTENT_LENGTH, value);
172                }
173                bytes::Bytes::from(json_bytes)
174            }
175            Some(validation::ContentTypeKind::FormUrlencoded) => {
176                let body_bytes = match to_bytes(body, usize::MAX).await {
177                    Ok(bytes) => bytes,
178                    Err(_) => {
179                        let error_body = json!({
180                            "error": "Failed to read request body"
181                        });
182                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
183                    }
184                };
185
186                if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
187                    validation::validate_content_length(headers, body_bytes.len())?;
188                }
189
190                parts.headers.insert(
191                    axum::http::header::CONTENT_TYPE,
192                    axum::http::HeaderValue::from_static("application/json"),
193                );
194                if let Some(cached) = URLENCODED_JSON_CACHE.with(|cache| cache.borrow_mut().get(&body_bytes).cloned()) {
195                    cached
196                } else {
197                    let json_body = if body_bytes.is_empty() {
198                        serde_json::json!({})
199                    } else {
200                        match urlencoded::parse_urlencoded_to_json(&body_bytes) {
201                            Ok(json_body) => json_body,
202                            Err(e) => {
203                                let error_body = json!({
204                                    "error": format!("Failed to parse URL-encoded form data: {}", e)
205                                });
206                                return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
207                            }
208                        }
209                    };
210
211                    let json_bytes = match serde_json::to_vec(&json_body) {
212                        Ok(bytes) => bytes,
213                        Err(e) => {
214                            let error_body = json!({
215                                "error": format!("Failed to serialize URL-encoded form data to JSON: {}", e)
216                            });
217                            return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
218                        }
219                    };
220
221                    let json_bytes = bytes::Bytes::from(json_bytes);
222                    parsed_json = Some(json_body);
223                    URLENCODED_JSON_CACHE.with(|cache| {
224                        cache.borrow_mut().put(body_bytes.clone(), json_bytes.clone());
225                    });
226                    json_bytes
227                }
228            }
229            Some(validation::ContentTypeKind::Json) | Some(validation::ContentTypeKind::Other) => {
230                let body_bytes = match to_bytes(body, usize::MAX).await {
231                    Ok(bytes) => bytes,
232                    Err(_) => {
233                        let error_body = json!({
234                            "error": "Failed to read request body"
235                        });
236                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
237                    }
238                };
239
240                if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
241                    validation::validate_content_length(headers, body_bytes.len())?;
242                }
243
244                let should_validate_json =
245                    route_info.expects_json_body && matches!(content_kind, Some(validation::ContentTypeKind::Json));
246                if should_validate_json && !body_bytes.is_empty() {
247                    match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
248                        Ok(value) => parsed_json = Some(value),
249                        Err(_) => {
250                            let error_body = json!({
251                                "detail": "Invalid request format"
252                            });
253                            return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
254                        }
255                    }
256                }
257
258                body_bytes
259            }
260            None => {
261                let body_bytes = match to_bytes(body, usize::MAX).await {
262                    Ok(bytes) => bytes,
263                    Err(_) => {
264                        let error_body = json!({
265                            "error": "Failed to read request body"
266                        });
267                        return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
268                    }
269                };
270
271                if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
272                    validation::validate_content_length(headers, body_bytes.len())?;
273                }
274                body_bytes
275            }
276        };
277
278        parts.extensions.insert(PreReadBody(out_bytes));
279        if let Some(parsed) = parsed_json {
280            parts.extensions.insert(PreParsedJson(parsed));
281        }
282
283        let request = HttpRequest::from_parts(parts, Body::empty());
284        Ok(next.run(request).await)
285    } else {
286        validation::validate_content_type_headers(headers, 0)?;
287
288        let request = HttpRequest::from_parts(parts, body);
289        Ok(next.run(request).await)
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    #[test]
    fn test_route_info_creation() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_true() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_false() {
        let info = RouteInfo {
            expects_json_body: false,
        };
        assert!(!info.expects_json_body);
    }

    /// Only checks that an empty header map has no Content-Length; the middleware
    /// skips the Content-Length check in that case.
    #[test]
    fn test_request_with_zero_content_length() {
        let headers = axum::http::HeaderMap::new();
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_none());
    }

    /// Only checks that a very large Content-Length value is representable as a header.
    #[test]
    fn test_request_with_very_large_content_length() {
        let mut headers = axum::http::HeaderMap::new();
        let large_size = usize::MAX - 1;
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_str(&large_size.to_string()).unwrap(),
        );
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
    }

    #[test]
    fn test_request_body_smaller_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("1000"),
        );
        let result = super::validation::validate_content_length(&headers, 500);
        assert!(
            result.is_err(),
            "Should reject when actual body is smaller than declared"
        );
    }

    #[test]
    fn test_request_body_larger_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("500"),
        );
        let result = super::validation::validate_content_length(&headers, 1000);
        assert!(
            result.is_err(),
            "Should reject when actual body is larger than declared"
        );
    }

    #[test]
    fn test_get_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::GET)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::GET);
    }

    #[test]
    fn test_delete_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::DELETE)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::DELETE);
    }

    #[test]
    fn test_post_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::POST);
    }

    #[test]
    fn test_put_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PUT)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PUT);
    }

    #[test]
    fn test_patch_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PATCH)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PATCH);
    }

    #[test]
    fn test_content_type_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("application/json"),
        );

        // Lookups by name must succeed regardless of the caller's casing.
        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
        assert!(headers.get("Content-Type").is_some());
        assert!(headers.get("CONTENT-TYPE").is_some());
    }

    #[test]
    fn test_content_length_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("100"),
        );

        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
        assert!(headers.get("Content-Length").is_some());
        assert_eq!(headers.get("CONTENT-LENGTH").unwrap(), "100");
    }

    #[test]
    fn test_custom_header_names_are_normalized_to_lowercase() {
        let mut headers = axum::http::HeaderMap::new();
        let custom_header: axum::http::HeaderName = "X-Custom-Header".parse().unwrap();
        headers.insert(custom_header.clone(), axum::http::HeaderValue::from_static("value"));

        // `HeaderName` parsing lowercases the name; lookups stay case-insensitive.
        assert_eq!(custom_header.as_str(), "x-custom-header");
        assert!(headers.get(&custom_header).is_some());
        assert!(headers.get("X-CUSTOM-HEADER").is_some());
    }

    #[test]
    fn test_multipart_boundary_minimal() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=x"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Minimal boundary should be accepted");
    }

    #[test]
    fn test_multipart_boundary_with_numbers() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=boundary123456"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_multipart_boundary_with_special_chars() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Boundary with dashes should be accepted");
    }

    /// Smoke test: validating an empty boundary must not panic. The outcome is not
    /// asserted here. NOTE(review): pin accept/reject once the intended policy is settled.
    #[test]
    fn test_multipart_empty_boundary() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary="),
        );

        let _result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_invalid_json_body_detection() {
        let invalid_json = r#"{"invalid": json without quotes}"#;
        let result = serde_json::from_str::<serde_json::Value>(invalid_json);
        assert!(result.is_err(), "Invalid JSON should fail parsing");
    }

    #[test]
    fn test_valid_json_parsing() {
        let valid_json = r#"{"key": "value"}"#;
        let result = serde_json::from_str::<serde_json::Value>(valid_json);
        assert!(result.is_ok(), "Valid JSON should parse successfully");
    }

    #[test]
    fn test_empty_json_object() {
        let empty_json = "{}";
        let result = serde_json::from_str::<serde_json::Value>(empty_json);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(value.is_object());
        assert_eq!(value.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_form_data_mime_type() {
        let mime = "multipart/form-data; boundary=xyz".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::MULTIPART);
        assert_eq!(mime.subtype(), "form-data");
    }

    #[test]
    fn test_form_urlencoded_mime_type() {
        let mime = "application/x-www-form-urlencoded".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), "x-www-form-urlencoded");
    }

    #[test]
    fn test_json_mime_type() {
        let mime = "application/json".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), mime::JSON);
    }

    #[test]
    fn test_text_plain_mime_type() {
        let mime = "text/plain".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::TEXT);
        assert_eq!(mime.subtype(), "plain");
    }
}