1pub mod multipart;
7pub mod urlencoded;
8pub mod validation;
9
10use axum::{
11 body::Body,
12 extract::{FromRequest, Multipart, Request},
13 http::StatusCode,
14 middleware::Next,
15 response::{IntoResponse, Response},
16};
17use serde_json::json;
18use std::collections::HashMap;
19use std::sync::Arc;
20
/// Metadata attached to a single `(method, path)` route entry.
///
/// Read by `validate_content_type_middleware` for `POST`, `PUT` and `PATCH` requests.
#[derive(Debug, Clone)]
pub struct RouteInfo {
    /// When set, the middleware runs `validation::validate_json_content_type`
    /// on the request headers before any other content-type handling.
    pub expects_json_body: bool,
}

/// Shared, read-only table mapping `(HTTP method, request path)` to [`RouteInfo`].
///
/// The middleware fetches it from the request extensions. It builds the key from
/// `Method::as_str()` and `Uri::path()` of the incoming request.
pub type RouteRegistry = Arc<std::collections::HashMap<(String, String), RouteInfo>>;
30
31#[cfg(not(tarpaulin_include))]
82pub async fn validate_content_type_middleware(request: Request, next: Next) -> Result<Response, Response> {
83 use axum::body::to_bytes;
84 use axum::http::Request as HttpRequest;
85
86 let (parts, body) = request.into_parts();
87 let headers = &parts.headers;
88
89 let route_info = parts.extensions.get::<RouteRegistry>().and_then(|registry| {
90 let method = parts.method.as_str();
91 let path = parts.uri.path();
92 registry.get(&(method.to_string(), path.to_string())).cloned()
93 });
94
95 let method = &parts.method;
96 if method == axum::http::Method::POST || method == axum::http::Method::PUT || method == axum::http::Method::PATCH {
97 if let Some(info) = &route_info
98 && info.expects_json_body
99 {
100 validation::validate_json_content_type(headers)?;
101 }
102
103 validation::validate_content_type_headers(headers, 0)?;
104
105 let (final_parts, final_body) = if let Some(content_type) = headers.get(axum::http::header::CONTENT_TYPE) {
106 if let Ok(content_type_str) = content_type.to_str() {
107 let parsed_mime = content_type_str.parse::<mime::Mime>().ok();
108
109 let is_multipart = parsed_mime
110 .as_ref()
111 .map(|mime| mime.type_() == mime::MULTIPART && mime.subtype() == "form-data")
112 .unwrap_or(false);
113
114 let is_form_urlencoded = parsed_mime
115 .as_ref()
116 .map(|mime| mime.type_() == mime::APPLICATION && mime.subtype() == "x-www-form-urlencoded")
117 .unwrap_or(false);
118
119 if is_multipart {
120 let mut response_headers = parts.headers.clone();
121
122 let request = HttpRequest::from_parts(parts, body);
123 let multipart = match Multipart::from_request(request, &()).await {
124 Ok(mp) => mp,
125 Err(e) => {
126 let error_body = json!({
127 "error": format!("Failed to parse multipart data: {}", e)
128 });
129 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
130 }
131 };
132
133 let json_body = match multipart::parse_multipart_to_json(multipart).await {
134 Ok(json) => json,
135 Err(e) => {
136 let error_body = json!({
137 "error": format!("Failed to process multipart data: {}", e)
138 });
139 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
140 }
141 };
142
143 let json_bytes = match serde_json::to_vec(&json_body) {
144 Ok(bytes) => bytes,
145 Err(e) => {
146 let error_body = json!({
147 "error": format!("Failed to serialize multipart data to JSON: {}", e)
148 });
149 return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
150 }
151 };
152
153 response_headers.insert(
154 axum::http::header::CONTENT_TYPE,
155 axum::http::HeaderValue::from_static("application/json"),
156 );
157
158 let mut new_request = axum::http::Request::new(Body::from(json_bytes));
159 *new_request.headers_mut() = response_headers;
160
161 return Ok(next.run(new_request).await);
162 } else if is_form_urlencoded {
163 let body_bytes = match to_bytes(body, usize::MAX).await {
164 Ok(bytes) => bytes,
165 Err(_) => {
166 let error_body = json!({
167 "error": "Failed to read request body"
168 });
169 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
170 }
171 };
172
173 validation::validate_content_length(headers, body_bytes.len())?;
174
175 let json_body = if body_bytes.is_empty() {
176 serde_json::json!({})
177 } else {
178 match urlencoded::parse_urlencoded_to_json(&body_bytes) {
179 Ok(json_body) => json_body,
180 Err(e) => {
181 let error_body = json!({
182 "error": format!("Failed to parse URL-encoded form data: {}", e)
183 });
184 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
185 }
186 }
187 };
188
189 let json_bytes = match serde_json::to_vec(&json_body) {
190 Ok(bytes) => bytes,
191 Err(e) => {
192 let error_body = json!({
193 "error": format!("Failed to serialize URL-encoded form data to JSON: {}", e)
194 });
195 return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
196 }
197 };
198
199 let mut new_parts = parts;
200 new_parts.headers.insert(
201 axum::http::header::CONTENT_TYPE,
202 axum::http::HeaderValue::from_static("application/json"),
203 );
204
205 (new_parts, Body::from(json_bytes))
206 } else {
207 let body_bytes = match to_bytes(body, usize::MAX).await {
208 Ok(bytes) => bytes,
209 Err(_) => {
210 let error_body = json!({
211 "error": "Failed to read request body"
212 });
213 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
214 }
215 };
216
217 validation::validate_content_length(headers, body_bytes.len())?;
218
219 let is_json = parsed_mime
220 .as_ref()
221 .map(validation::is_json_content_type)
222 .unwrap_or(false);
223
224 if is_json
225 && !body_bytes.is_empty()
226 && serde_json::from_slice::<serde_json::Value>(&body_bytes).is_err()
227 {
228 let error_body = json!({
229 "detail": "Invalid request format"
230 });
231 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
232 }
233
234 (parts, Body::from(body_bytes))
235 }
236 } else {
237 let body_bytes = match to_bytes(body, usize::MAX).await {
238 Ok(bytes) => bytes,
239 Err(_) => {
240 let error_body = json!({
241 "error": "Failed to read request body"
242 });
243 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
244 }
245 };
246
247 validation::validate_content_length(headers, body_bytes.len())?;
248
249 (parts, Body::from(body_bytes))
250 }
251 } else {
252 let body_bytes = match to_bytes(body, usize::MAX).await {
253 Ok(bytes) => bytes,
254 Err(_) => {
255 let error_body = json!({
256 "error": "Failed to read request body"
257 });
258 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
259 }
260 };
261
262 validation::validate_content_length(headers, body_bytes.len())?;
263
264 (parts, Body::from(body_bytes))
265 };
266
267 let request = HttpRequest::from_parts(final_parts, final_body);
268 Ok(next.run(request).await)
269 } else {
270 validation::validate_content_type_headers(headers, 0)?;
271
272 let request = HttpRequest::from_parts(parts, body);
273 Ok(next.run(request).await)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    #[test]
    fn test_route_info_creation() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_true() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_false() {
        let info = RouteInfo {
            expects_json_body: false,
        };
        assert!(!info.expects_json_body);
    }

    #[test]
    fn test_route_registry_empty() {
        let registry: RouteRegistry = Arc::new(std::collections::HashMap::new());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_route_registry_single_entry() {
        let mut map = std::collections::HashMap::new();
        map.insert(
            ("POST".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        let registry: RouteRegistry = Arc::new(map);

        let key = ("POST".to_string(), "/api/users".to_string());
        assert!(registry.contains_key(&key));
        assert!(registry[&key].expects_json_body);
    }

    #[test]
    fn test_route_registry_multiple_entries() {
        let mut map = std::collections::HashMap::new();
        map.insert(
            ("POST".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        map.insert(
            ("GET".to_string(), "/api/users".to_string()),
            RouteInfo {
                expects_json_body: false,
            },
        );
        map.insert(
            ("PUT".to_string(), "/api/users/{id}".to_string()),
            RouteInfo {
                expects_json_body: true,
            },
        );
        let registry: RouteRegistry = Arc::new(map);

        assert_eq!(registry.len(), 3);
    }

    #[test]
    fn test_route_registry_lookup_missing_route() {
        let map = std::collections::HashMap::new();
        let registry: RouteRegistry = Arc::new(map);

        let key = ("POST".to_string(), "/api/users".to_string());
        assert!(!registry.contains_key(&key));
    }

    #[test]
    fn test_request_with_zero_content_length() {
        // Only checks that a fresh header map has no Content-Length entry;
        // it does not exercise the validator.
        let headers = axum::http::HeaderMap::new();
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_none());
    }

    #[test]
    fn test_request_with_very_large_content_length() {
        let mut headers = axum::http::HeaderMap::new();
        let large_size = usize::MAX - 1;
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_str(&large_size.to_string()).unwrap(),
        );
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
    }

    #[test]
    fn test_request_body_smaller_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("1000"),
        );
        let result = super::validation::validate_content_length(&headers, 500);
        assert!(
            result.is_err(),
            "Should reject when actual body is smaller than declared"
        );
    }

    #[test]
    fn test_request_body_larger_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("500"),
        );
        let result = super::validation::validate_content_length(&headers, 1000);
        assert!(
            result.is_err(),
            "Should reject when actual body is larger than declared"
        );
    }

    #[test]
    fn test_get_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::GET)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::GET);
    }

    #[test]
    fn test_delete_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::DELETE)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::DELETE);
    }

    #[test]
    fn test_post_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::POST);
    }

    #[test]
    fn test_put_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PUT)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PUT);
    }

    #[test]
    fn test_patch_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PATCH)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PATCH);
    }

    #[test]
    fn test_content_type_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("application/json"),
        );

        // String lookups are normalized, so every casing finds the same entry.
        for name in ["content-type", "Content-Type", "CONTENT-TYPE"] {
            assert_eq!(
                headers.get(name).map(|value| value.as_bytes()),
                Some(&b"application/json"[..]),
                "lookup by {name:?} should succeed"
            );
        }
    }

    #[test]
    fn test_content_length_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("100"),
        );

        for name in ["content-length", "Content-Length", "CONTENT-LENGTH"] {
            assert_eq!(
                headers.get(name).map(|value| value.as_bytes()),
                Some(&b"100"[..]),
                "lookup by {name:?} should succeed"
            );
        }
    }

    #[test]
    fn test_custom_headers_normalized_to_lowercase() {
        let mut headers = axum::http::HeaderMap::new();
        let custom_header: axum::http::HeaderName = "X-Custom-Header".parse().unwrap();
        // `HeaderName` lowercases on parse: the original casing is NOT preserved.
        assert_eq!(custom_header.as_str(), "x-custom-header");
        headers.insert(custom_header, axum::http::HeaderValue::from_static("value"));

        assert!(headers.get("x-custom-header").is_some());
        assert!(headers.get("X-CUSTOM-HEADER").is_some());
    }

    #[test]
    fn test_multipart_boundary_minimal() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=x"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Minimal boundary should be accepted");
    }

    #[test]
    fn test_multipart_boundary_with_numbers() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=boundary123456"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_multipart_boundary_with_special_chars() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static(
                "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
            ),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Boundary with dashes should be accepted");
    }

    #[test]
    fn test_multipart_empty_boundary() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary="),
        );

        // Only checks that the validator does not panic on an empty boundary.
        // NOTE(review): the expected accept/reject outcome is not pinned here.
        let _result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_invalid_json_body_detection() {
        let invalid_json = r#"{"invalid": json without quotes}"#;

        let result = serde_json::from_str::<serde_json::Value>(invalid_json);
        assert!(result.is_err(), "Invalid JSON should fail parsing");
    }

    #[test]
    fn test_valid_json_parsing() {
        let valid_json = r#"{"key": "value"}"#;
        let result = serde_json::from_str::<serde_json::Value>(valid_json);
        assert!(result.is_ok(), "Valid JSON should parse successfully");
    }

    #[test]
    fn test_empty_json_object() {
        let empty_json = "{}";
        let result = serde_json::from_str::<serde_json::Value>(empty_json);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(value.is_object());
        assert_eq!(value.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_form_data_mime_type() {
        let mime = "multipart/form-data; boundary=xyz".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::MULTIPART);
        assert_eq!(mime.subtype(), "form-data");
    }

    #[test]
    fn test_form_urlencoded_mime_type() {
        let mime = "application/x-www-form-urlencoded".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), "x-www-form-urlencoded");
    }

    #[test]
    fn test_json_mime_type() {
        let mime = "application/json".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), mime::JSON);
    }

    #[test]
    fn test_text_plain_mime_type() {
        let mime = "text/plain".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::TEXT);
        assert_eq!(mime.subtype(), "plain");
    }
}