1pub mod multipart;
7pub mod urlencoded;
8pub mod validation;
9
10use axum::{
11 body::Body,
12 extract::State,
13 extract::{FromRequest, Multipart, Request},
14 http::StatusCode,
15 middleware::Next,
16 response::{IntoResponse, Response},
17};
18use serde_json::json;
19use std::cell::RefCell;
20use std::num::NonZeroUsize;
21
22thread_local! {
23 static URLENCODED_JSON_CACHE: RefCell<lru::LruCache<bytes::Bytes, bytes::Bytes>> =
24 RefCell::new(lru::LruCache::new(NonZeroUsize::new(256).expect("non-zero cache size")));
25}
26
/// Per-route configuration handed to `validate_content_type_middleware` as axum state.
#[derive(Clone, Debug)]
pub struct RouteInfo {
    /// When `true`, POST/PUT/PATCH requests must pass the JSON content-type check, and
    /// non-empty `application/json` bodies are syntactically validated before the handler runs.
    pub expects_json_body: bool,
}
33
34#[derive(Debug, Clone)]
39pub struct PreReadBody(pub bytes::Bytes);
40
41#[derive(Debug, Clone)]
43pub struct PreParsedJson(pub serde_json::Value);
44
45#[cfg(not(tarpaulin_include))]
96pub async fn validate_content_type_middleware(
97 State(route_info): State<RouteInfo>,
98 request: Request,
99 next: Next,
100) -> Result<Response, Response> {
101 use axum::body::to_bytes;
102 use axum::http::Request as HttpRequest;
103
104 let (mut parts, body) = request.into_parts();
105 let headers = &parts.headers;
106
107 let method = &parts.method;
108 if method == axum::http::Method::POST || method == axum::http::Method::PUT || method == axum::http::Method::PATCH {
109 if route_info.expects_json_body {
110 validation::validate_json_content_type(headers)?;
111 }
112
113 let content_kind = validation::validate_content_type_headers_and_classify(headers, 0)?;
114
115 let mut parsed_json: Option<serde_json::Value> = None;
116 let out_bytes: bytes::Bytes = match content_kind {
117 Some(validation::ContentTypeKind::Multipart) => {
118 let body_bytes = match to_bytes(body, usize::MAX).await {
119 Ok(bytes) => bytes,
120 Err(_) => {
121 let error_body = json!({
122 "error": "Failed to read request body"
123 });
124 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
125 }
126 };
127
128 if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
129 validation::validate_content_length(headers, body_bytes.len())?;
130 }
131
132 let mut parse_request = HttpRequest::new(Body::from(body_bytes));
133 *parse_request.headers_mut() = parts.headers.clone();
134
135 let multipart = match Multipart::from_request(parse_request, &()).await {
136 Ok(mp) => mp,
137 Err(e) => {
138 let error_body = json!({
139 "error": format!("Failed to parse multipart data: {}", e)
140 });
141 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
142 }
143 };
144
145 let json_body = match multipart::parse_multipart_to_json(multipart).await {
146 Ok(json) => json,
147 Err(e) => {
148 let error_body = json!({
149 "error": format!("Failed to process multipart data: {}", e)
150 });
151 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
152 }
153 };
154
155 let json_bytes = match serde_json::to_vec(&json_body) {
156 Ok(bytes) => bytes,
157 Err(e) => {
158 let error_body = json!({
159 "error": format!("Failed to serialize multipart data to JSON: {}", e)
160 });
161 return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
162 }
163 };
164
165 parsed_json = Some(json_body);
166 parts.headers.insert(
167 axum::http::header::CONTENT_TYPE,
168 axum::http::HeaderValue::from_static("application/json"),
169 );
170 if let Ok(value) = axum::http::HeaderValue::from_str(&json_bytes.len().to_string()) {
171 parts.headers.insert(axum::http::header::CONTENT_LENGTH, value);
172 }
173 bytes::Bytes::from(json_bytes)
174 }
175 Some(validation::ContentTypeKind::FormUrlencoded) => {
176 let body_bytes = match to_bytes(body, usize::MAX).await {
177 Ok(bytes) => bytes,
178 Err(_) => {
179 let error_body = json!({
180 "error": "Failed to read request body"
181 });
182 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
183 }
184 };
185
186 if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
187 validation::validate_content_length(headers, body_bytes.len())?;
188 }
189
190 parts.headers.insert(
191 axum::http::header::CONTENT_TYPE,
192 axum::http::HeaderValue::from_static("application/json"),
193 );
194 if let Some(cached) = URLENCODED_JSON_CACHE.with(|cache| cache.borrow_mut().get(&body_bytes).cloned()) {
195 cached
196 } else {
197 let json_body = if body_bytes.is_empty() {
198 serde_json::json!({})
199 } else {
200 match urlencoded::parse_urlencoded_to_json(&body_bytes) {
201 Ok(json_body) => json_body,
202 Err(e) => {
203 let error_body = json!({
204 "error": format!("Failed to parse URL-encoded form data: {}", e)
205 });
206 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
207 }
208 }
209 };
210
211 let json_bytes = match serde_json::to_vec(&json_body) {
212 Ok(bytes) => bytes,
213 Err(e) => {
214 let error_body = json!({
215 "error": format!("Failed to serialize URL-encoded form data to JSON: {}", e)
216 });
217 return Err((StatusCode::INTERNAL_SERVER_ERROR, axum::Json(error_body)).into_response());
218 }
219 };
220
221 let json_bytes = bytes::Bytes::from(json_bytes);
222 parsed_json = Some(json_body);
223 URLENCODED_JSON_CACHE.with(|cache| {
224 cache.borrow_mut().put(body_bytes.clone(), json_bytes.clone());
225 });
226 json_bytes
227 }
228 }
229 Some(validation::ContentTypeKind::Json) | Some(validation::ContentTypeKind::Other) => {
230 let body_bytes = match to_bytes(body, usize::MAX).await {
231 Ok(bytes) => bytes,
232 Err(_) => {
233 let error_body = json!({
234 "error": "Failed to read request body"
235 });
236 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
237 }
238 };
239
240 if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
241 validation::validate_content_length(headers, body_bytes.len())?;
242 }
243
244 let should_validate_json =
245 route_info.expects_json_body && matches!(content_kind, Some(validation::ContentTypeKind::Json));
246 if should_validate_json && !body_bytes.is_empty() {
247 match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
248 Ok(value) => parsed_json = Some(value),
249 Err(_) => {
250 let error_body = json!({
251 "detail": "Invalid request format"
252 });
253 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
254 }
255 }
256 }
257
258 body_bytes
259 }
260 None => {
261 let body_bytes = match to_bytes(body, usize::MAX).await {
262 Ok(bytes) => bytes,
263 Err(_) => {
264 let error_body = json!({
265 "error": "Failed to read request body"
266 });
267 return Err((StatusCode::BAD_REQUEST, axum::Json(error_body)).into_response());
268 }
269 };
270
271 if headers.get(axum::http::header::CONTENT_LENGTH).is_some() {
272 validation::validate_content_length(headers, body_bytes.len())?;
273 }
274 body_bytes
275 }
276 };
277
278 parts.extensions.insert(PreReadBody(out_bytes));
279 if let Some(parsed) = parsed_json {
280 parts.extensions.insert(PreParsedJson(parsed));
281 }
282
283 let request = HttpRequest::from_parts(parts, Body::empty());
284 Ok(next.run(request).await)
285 } else {
286 validation::validate_content_type_headers(headers, 0)?;
287
288 let request = HttpRequest::from_parts(parts, body);
289 Ok(next.run(request).await)
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    #[test]
    fn test_route_info_creation() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_true() {
        let info = RouteInfo {
            expects_json_body: true,
        };
        assert!(info.expects_json_body);
    }

    #[test]
    fn test_route_info_expects_json_body_false() {
        let info = RouteInfo {
            expects_json_body: false,
        };
        assert!(!info.expects_json_body);
    }

    #[test]
    fn test_request_with_zero_content_length() {
        let headers = axum::http::HeaderMap::new();
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_none());
    }

    #[test]
    fn test_request_with_very_large_content_length() {
        let mut headers = axum::http::HeaderMap::new();
        let large_size = usize::MAX - 1;
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_str(&large_size.to_string()).unwrap(),
        );
        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
    }

    #[test]
    fn test_request_body_smaller_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("1000"),
        );
        let result = super::validation::validate_content_length(&headers, 500);
        assert!(
            result.is_err(),
            "Should reject when actual body is smaller than declared"
        );
    }

    #[test]
    fn test_request_body_larger_than_declared_length() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("500"),
        );
        let result = super::validation::validate_content_length(&headers, 1000);
        assert!(
            result.is_err(),
            "Should reject when actual body is larger than declared"
        );
    }

    #[test]
    fn test_get_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::GET)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::GET);
    }

    #[test]
    fn test_delete_request_no_body_validation() {
        let request = Request::builder()
            .method(axum::http::Method::DELETE)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::DELETE);
    }

    #[test]
    fn test_post_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/api/users")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::POST);
    }

    #[test]
    fn test_put_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PUT)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PUT);
    }

    #[test]
    fn test_patch_request_requires_validation() {
        let request = Request::builder()
            .method(axum::http::Method::PATCH)
            .uri("/api/users/1")
            .body(Body::empty())
            .unwrap();

        let (parts, _body) = request.into_parts();
        assert_eq!(parts.method, axum::http::Method::PATCH);
    }

    #[test]
    fn test_content_type_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("application/json"),
        );

        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_content_length_header_case_insensitive() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_LENGTH,
            axum::http::HeaderValue::from_static("100"),
        );

        assert!(headers.get(axum::http::header::CONTENT_LENGTH).is_some());
    }

    #[test]
    fn test_custom_headers_case_preserved() {
        let mut headers = axum::http::HeaderMap::new();
        let custom_header: axum::http::HeaderName = "X-Custom-Header".parse().unwrap();
        headers.insert(custom_header.clone(), axum::http::HeaderValue::from_static("value"));

        assert!(headers.get(&custom_header).is_some());
    }

    #[test]
    fn test_multipart_boundary_minimal() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=x"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Minimal boundary should be accepted");
    }

    #[test]
    fn test_multipart_boundary_with_numbers() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary=boundary123456"),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_multipart_boundary_with_special_chars() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static(
                "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
            ),
        );

        let result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(result.is_ok(), "Boundary with dashes should be accepted");
    }

    #[test]
    fn test_multipart_empty_boundary() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            axum::http::header::CONTENT_TYPE,
            axum::http::HeaderValue::from_static("multipart/form-data; boundary="),
        );

        // Only checks that validation does not panic on an empty boundary; whether it is
        // accepted or rejected is not asserted here.
        let _result = super::validation::validate_content_type_headers(&headers, 0);
        assert!(headers.get(axum::http::header::CONTENT_TYPE).is_some());
    }

    #[test]
    fn test_invalid_json_body_detection() {
        let invalid_json = r#"{"invalid": json without quotes}"#;

        let result = serde_json::from_str::<serde_json::Value>(invalid_json);
        assert!(result.is_err(), "Invalid JSON should fail parsing");
    }

    #[test]
    fn test_valid_json_parsing() {
        let valid_json = r#"{"key": "value"}"#;
        let result = serde_json::from_str::<serde_json::Value>(valid_json);
        assert!(result.is_ok(), "Valid JSON should parse successfully");
    }

    #[test]
    fn test_empty_json_object() {
        let empty_json = "{}";
        let result = serde_json::from_str::<serde_json::Value>(empty_json);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert!(value.is_object());
        assert_eq!(value.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_form_data_mime_type() {
        let mime = "multipart/form-data; boundary=xyz".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::MULTIPART);
        assert_eq!(mime.subtype(), "form-data");
    }

    #[test]
    fn test_form_urlencoded_mime_type() {
        let mime = "application/x-www-form-urlencoded".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), "x-www-form-urlencoded");
    }

    #[test]
    fn test_json_mime_type() {
        let mime = "application/json".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::APPLICATION);
        assert_eq!(mime.subtype(), mime::JSON);
    }

    #[test]
    fn test_text_plain_mime_type() {
        let mime = "text/plain".parse::<mime::Mime>().unwrap();
        assert_eq!(mime.type_(), mime::TEXT);
        assert_eq!(mime.subtype(), "plain");
    }
}
560}