rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
9//! | Extractor | Description | Consumes Body |
10//! |-----------|-------------|---------------|
11//! | [`Json<T>`] | Parse JSON request body | Yes |
12//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
13//! | [`Query<T>`] | Parse query string parameters | No |
14//! | [`Path<T>`] | Extract path parameters | No |
15//! | [`State<T>`] | Access shared application state | No |
16//! | [`Body`] | Raw request body bytes | Yes |
17//! | [`Headers`] | Access all request headers | No |
//! | [`HeaderValue`] | A single header, read with `HeaderValue::extract(req, name)` | No |
19//! | [`Extension<T>`] | Access middleware-injected data | No |
20//! | [`ClientIp`] | Extract client IP address | No |
21//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
51//! # Extractor Order
52//!
53//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`)
54//! must come last since they consume the request body. Non-body extractors can be
55//! in any order.
56
57use crate::error::{ApiError, Result};
58use crate::request::Request;
59use crate::response::IntoResponse;
60use bytes::Bytes;
61use http::{header, StatusCode};
62use http_body_util::Full;
63use serde::de::DeserializeOwned;
64use serde::Serialize;
65use std::future::Future;
66use std::ops::{Deref, DerefMut};
67use std::str::FromStr;
68
69/// Trait for extracting data from request parts (headers, path, query)
70///
71/// This is used for extractors that don't need the request body.
72pub trait FromRequestParts: Sized {
73    /// Extract from request parts
74    fn from_request_parts(req: &Request) -> Result<Self>;
75}
76
77/// Trait for extracting data from the full request (including body)
78///
79/// This is used for extractors that consume the request body.
80pub trait FromRequest: Sized {
81    /// Extract from the full request
82    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
83}
84
85// Blanket impl: FromRequestParts -> FromRequest
86impl<T: FromRequestParts> FromRequest for T {
87    async fn from_request(req: &mut Request) -> Result<Self> {
88        T::from_request_parts(req)
89    }
90}
91
92/// JSON body extractor
93///
94/// Parses the request body as JSON and deserializes into type `T`.
95/// Also works as a response type when T: Serialize.
96///
97/// # Example
98///
99/// ```rust,ignore
100/// #[derive(Deserialize)]
101/// struct CreateUser {
102///     name: String,
103///     email: String,
104/// }
105///
106/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
107///     // body is already deserialized
108/// }
109/// ```
110#[derive(Debug, Clone, Copy, Default)]
111pub struct Json<T>(pub T);
112
113impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
114    async fn from_request(req: &mut Request) -> Result<Self> {
115        let body = req
116            .take_body()
117            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
118
119        let value: T = serde_json::from_slice(&body)?;
120        Ok(Json(value))
121    }
122}
123
124impl<T> Deref for Json<T> {
125    type Target = T;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132impl<T> DerefMut for Json<T> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.0
135    }
136}
137
138impl<T> From<T> for Json<T> {
139    fn from(value: T) -> Self {
140        Json(value)
141    }
142}
143
144// IntoResponse for Json - allows using Json<T> as a return type
145impl<T: Serialize> IntoResponse for Json<T> {
146    fn into_response(self) -> crate::response::Response {
147        match serde_json::to_vec(&self.0) {
148            Ok(body) => http::Response::builder()
149                .status(StatusCode::OK)
150                .header(header::CONTENT_TYPE, "application/json")
151                .body(Full::new(Bytes::from(body)))
152                .unwrap(),
153            Err(err) => {
154                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
155            }
156        }
157    }
158}
159
160/// Validated JSON body extractor
161///
162/// Parses the request body as JSON, deserializes into type `T`, and validates
163/// using the `Validate` trait. Returns a 422 Unprocessable Entity error with
164/// detailed field-level validation errors if validation fails.
165///
166/// # Example
167///
168/// ```rust,ignore
169/// use rustapi_rs::prelude::*;
170/// use validator::Validate;
171///
172/// #[derive(Deserialize, Validate)]
173/// struct CreateUser {
174///     #[validate(email)]
175///     email: String,
176///     #[validate(length(min = 8))]
177///     password: String,
178/// }
179///
180/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
181///     // body is already validated!
182///     // If email is invalid or password too short, a 422 error is returned automatically
183/// }
184/// ```
185#[derive(Debug, Clone, Copy, Default)]
186pub struct ValidatedJson<T>(pub T);
187
188impl<T> ValidatedJson<T> {
189    /// Create a new ValidatedJson wrapper
190    pub fn new(value: T) -> Self {
191        Self(value)
192    }
193
194    /// Get the inner value
195    pub fn into_inner(self) -> T {
196        self.0
197    }
198}
199
200impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
201    async fn from_request(req: &mut Request) -> Result<Self> {
202        // First, deserialize the JSON body
203        let body = req
204            .take_body()
205            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
206
207        let value: T = serde_json::from_slice(&body)?;
208
209        // Then, validate it
210        if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
211            // Convert validation error to API error with 422 status
212            return Err(validation_error.into());
213        }
214
215        Ok(ValidatedJson(value))
216    }
217}
218
219impl<T> Deref for ValidatedJson<T> {
220    type Target = T;
221
222    fn deref(&self) -> &Self::Target {
223        &self.0
224    }
225}
226
227impl<T> DerefMut for ValidatedJson<T> {
228    fn deref_mut(&mut self) -> &mut Self::Target {
229        &mut self.0
230    }
231}
232
233impl<T> From<T> for ValidatedJson<T> {
234    fn from(value: T) -> Self {
235        ValidatedJson(value)
236    }
237}
238
239impl<T: Serialize> IntoResponse for ValidatedJson<T> {
240    fn into_response(self) -> crate::response::Response {
241        Json(self.0).into_response()
242    }
243}
244
245/// Query string extractor
246///
247/// Parses the query string into type `T`.
248///
249/// # Example
250///
251/// ```rust,ignore
252/// #[derive(Deserialize)]
253/// struct Pagination {
254///     page: Option<u32>,
255///     limit: Option<u32>,
256/// }
257///
258/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
259///     // params.page, params.limit
260/// }
261/// ```
262#[derive(Debug, Clone)]
263pub struct Query<T>(pub T);
264
265impl<T: DeserializeOwned> FromRequestParts for Query<T> {
266    fn from_request_parts(req: &Request) -> Result<Self> {
267        let query = req.query_string().unwrap_or("");
268        let value: T = serde_urlencoded::from_str(query)
269            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
270        Ok(Query(value))
271    }
272}
273
274impl<T> Deref for Query<T> {
275    type Target = T;
276
277    fn deref(&self) -> &Self::Target {
278        &self.0
279    }
280}
281
282/// Path parameter extractor
283///
284/// Extracts path parameters defined in the route pattern.
285///
286/// # Example
287///
288/// For route `/users/{id}`:
289///
290/// ```rust,ignore
291/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
292///     // id is extracted from path
293/// }
294/// ```
295///
296/// For multiple params `/users/{user_id}/posts/{post_id}`:
297///
298/// ```rust,ignore
299/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
300///     // Both params extracted
301/// }
302/// ```
303#[derive(Debug, Clone)]
304pub struct Path<T>(pub T);
305
306impl<T: FromStr> FromRequestParts for Path<T>
307where
308    T::Err: std::fmt::Display,
309{
310    fn from_request_parts(req: &Request) -> Result<Self> {
311        let params = req.path_params();
312
313        // For single param, get the first one
314        if let Some((_, value)) = params.iter().next() {
315            let parsed = value
316                .parse::<T>()
317                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
318            return Ok(Path(parsed));
319        }
320
321        Err(ApiError::internal("Missing path parameter"))
322    }
323}
324
325impl<T> Deref for Path<T> {
326    type Target = T;
327
328    fn deref(&self) -> &Self::Target {
329        &self.0
330    }
331}
332
333/// State extractor
334///
335/// Extracts shared application state.
336///
337/// # Example
338///
339/// ```rust,ignore
340/// #[derive(Clone)]
341/// struct AppState {
342///     db: DbPool,
343/// }
344///
345/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
346///     // Use state.db
347/// }
348/// ```
349#[derive(Debug, Clone)]
350pub struct State<T>(pub T);
351
352impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
353    fn from_request_parts(req: &Request) -> Result<Self> {
354        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
355            ApiError::internal(format!(
356                "State of type `{}` not found. Did you forget to call .state()?",
357                std::any::type_name::<T>()
358            ))
359        })
360    }
361}
362
363impl<T> Deref for State<T> {
364    type Target = T;
365
366    fn deref(&self) -> &Self::Target {
367        &self.0
368    }
369}
370
371/// Raw body bytes extractor
372#[derive(Debug, Clone)]
373pub struct Body(pub Bytes);
374
375impl FromRequest for Body {
376    async fn from_request(req: &mut Request) -> Result<Self> {
377        let body = req
378            .take_body()
379            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
380        Ok(Body(body))
381    }
382}
383
384impl Deref for Body {
385    type Target = Bytes;
386
387    fn deref(&self) -> &Self::Target {
388        &self.0
389    }
390}
391
392/// Optional extractor wrapper
393///
394/// Makes any extractor optional - returns None instead of error on failure.
395impl<T: FromRequestParts> FromRequestParts for Option<T> {
396    fn from_request_parts(req: &Request) -> Result<Self> {
397        Ok(T::from_request_parts(req).ok())
398    }
399}
400
401/// Headers extractor
402///
403/// Provides access to all request headers as a typed map.
404///
405/// # Example
406///
407/// ```rust,ignore
408/// use rustapi_core::extract::Headers;
409///
410/// async fn handler(headers: Headers) -> impl IntoResponse {
411///     if let Some(content_type) = headers.get("content-type") {
412///         format!("Content-Type: {:?}", content_type)
413///     } else {
414///         "No Content-Type header".to_string()
415///     }
416/// }
417/// ```
418#[derive(Debug, Clone)]
419pub struct Headers(pub http::HeaderMap);
420
421impl Headers {
422    /// Get a header value by name
423    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
424        self.0.get(name)
425    }
426
427    /// Check if a header exists
428    pub fn contains(&self, name: &str) -> bool {
429        self.0.contains_key(name)
430    }
431
432    /// Get the number of headers
433    pub fn len(&self) -> usize {
434        self.0.len()
435    }
436
437    /// Check if headers are empty
438    pub fn is_empty(&self) -> bool {
439        self.0.is_empty()
440    }
441
442    /// Iterate over all headers
443    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
444        self.0.iter()
445    }
446}
447
448impl FromRequestParts for Headers {
449    fn from_request_parts(req: &Request) -> Result<Self> {
450        Ok(Headers(req.headers().clone()))
451    }
452}
453
454impl Deref for Headers {
455    type Target = http::HeaderMap;
456
457    fn deref(&self) -> &Self::Target {
458        &self.0
459    }
460}
461
462/// Single header value extractor
463///
464/// Extracts a specific header value by name. Returns an error if the header is missing.
465///
466/// # Example
467///
468/// ```rust,ignore
469/// use rustapi_core::extract::HeaderValue;
470///
471/// async fn handler(
472///     auth: HeaderValue<{ "authorization" }>,
473/// ) -> impl IntoResponse {
474///     format!("Auth header: {}", auth.0)
475/// }
476/// ```
477///
478/// Note: Due to Rust's const generics limitations, you may need to use the
479/// `HeaderValueOf` type alias or extract headers manually using the `Headers` extractor.
480#[derive(Debug, Clone)]
481pub struct HeaderValue(pub String, pub &'static str);
482
483impl HeaderValue {
484    /// Create a new HeaderValue extractor for a specific header name
485    pub fn new(name: &'static str, value: String) -> Self {
486        Self(value, name)
487    }
488
489    /// Get the header value
490    pub fn value(&self) -> &str {
491        &self.0
492    }
493
494    /// Get the header name
495    pub fn name(&self) -> &'static str {
496        self.1
497    }
498
499    /// Extract a specific header from a request
500    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
501        req.headers()
502            .get(name)
503            .and_then(|v| v.to_str().ok())
504            .map(|s| HeaderValue(s.to_string(), name))
505            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
506    }
507}
508
509impl Deref for HeaderValue {
510    type Target = String;
511
512    fn deref(&self) -> &Self::Target {
513        &self.0
514    }
515}
516
517/// Extension extractor
518///
519/// Retrieves typed data from request extensions that was inserted by middleware.
520///
521/// # Example
522///
523/// ```rust,ignore
524/// use rustapi_core::extract::Extension;
525///
526/// // Middleware inserts user data
527/// #[derive(Clone)]
528/// struct CurrentUser { id: i64 }
529///
530/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
531///     format!("User ID: {}", user.id)
532/// }
533/// ```
534#[derive(Debug, Clone)]
535pub struct Extension<T>(pub T);
536
537impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
538    fn from_request_parts(req: &Request) -> Result<Self> {
539        req.extensions()
540            .get::<T>()
541            .cloned()
542            .map(Extension)
543            .ok_or_else(|| {
544                ApiError::internal(format!(
545                    "Extension of type `{}` not found. Did middleware insert it?",
546                    std::any::type_name::<T>()
547                ))
548            })
549    }
550}
551
552impl<T> Deref for Extension<T> {
553    type Target = T;
554
555    fn deref(&self) -> &Self::Target {
556        &self.0
557    }
558}
559
560impl<T> DerefMut for Extension<T> {
561    fn deref_mut(&mut self) -> &mut Self::Target {
562        &mut self.0
563    }
564}
565
566/// Client IP address extractor
567///
568/// Extracts the client IP address from the request. When `trust_proxy` is enabled,
569/// it will use the `X-Forwarded-For` header if present.
570///
571/// # Example
572///
573/// ```rust,ignore
574/// use rustapi_core::extract::ClientIp;
575///
576/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
577///     format!("Your IP: {}", ip)
578/// }
579/// ```
580#[derive(Debug, Clone)]
581pub struct ClientIp(pub std::net::IpAddr);
582
583impl ClientIp {
584    /// Extract client IP, optionally trusting X-Forwarded-For header
585    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
586        if trust_proxy {
587            // Try X-Forwarded-For header first
588            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
589                if let Ok(forwarded_str) = forwarded.to_str() {
590                    // X-Forwarded-For can contain multiple IPs, take the first one
591                    if let Some(first_ip) = forwarded_str.split(',').next() {
592                        if let Ok(ip) = first_ip.trim().parse() {
593                            return Ok(ClientIp(ip));
594                        }
595                    }
596                }
597            }
598        }
599
600        // Fall back to socket address from extensions (if set by server)
601        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
602            return Ok(ClientIp(addr.ip()));
603        }
604
605        // Default to localhost if no IP information available
606        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
607            127, 0, 0, 1,
608        ))))
609    }
610}
611
612impl FromRequestParts for ClientIp {
613    fn from_request_parts(req: &Request) -> Result<Self> {
614        // By default, trust proxy headers
615        Self::extract_with_config(req, true)
616    }
617}
618
/// Cookies extractor
///
/// Builds a `cookie::CookieJar` from the request's `Cookie` header. A missing
/// or unreadable header yields an empty jar, and individual pairs that fail
/// to parse are skipped.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     if let Some(session) = cookies.get("session_id") {
///         format!("Session: {}", session.value())
///     } else {
///         "No session cookie".to_string()
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Look up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterate over the cookies in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Report whether a cookie with this name is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        if let Some(text) = req
            .headers()
            .get(header::COOKIE)
            .and_then(|raw| raw.to_str().ok())
        {
            // The header is a `;`-separated list of pairs; blank segments and
            // pairs that fail to parse are ignored.
            let parsed = text
                .split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty())
                .filter_map(|pair| cookie::Cookie::parse(pair.to_string()).ok());

            for entry in parsed {
                jar.add_original(entry.into_owned());
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
689
690// Implement FromRequestParts for common primitive types (path params)
691macro_rules! impl_from_request_parts_for_primitives {
692    ($($ty:ty),*) => {
693        $(
694            impl FromRequestParts for $ty {
695                fn from_request_parts(req: &Request) -> Result<Self> {
696                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
697                    Ok(value)
698                }
699            }
700        )*
701    };
702}
703
704impl_from_request_parts_for_primitives!(
705    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
706);
707
708// OperationModifier implementations for extractors
709
710use rustapi_openapi::utoipa_types::openapi;
711use rustapi_openapi::{
712    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
713    ResponseSpec, Schema, SchemaRef,
714};
715use std::collections::HashMap;
716
717// ValidatedJson - Adds request body
718impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
719    fn update_operation(op: &mut Operation) {
720        let (name, _) = T::schema();
721
722        let schema_ref = SchemaRef::Ref {
723            reference: format!("#/components/schemas/{}", name),
724        };
725
726        let mut content = HashMap::new();
727        content.insert(
728            "application/json".to_string(),
729            MediaType { schema: schema_ref },
730        );
731
732        op.request_body = Some(RequestBody {
733            required: true,
734            content,
735        });
736
737        // Add 422 Validation Error response
738        op.responses.insert(
739            "422".to_string(),
740            ResponseSpec {
741                description: "Validation Error".to_string(),
742                content: {
743                    let mut map = HashMap::new();
744                    map.insert(
745                        "application/json".to_string(),
746                        MediaType {
747                            schema: SchemaRef::Ref {
748                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
749                            },
750                        },
751                    );
752                    Some(map)
753                },
754            },
755        );
756    }
757}
758
759// Json - Adds request body (Same as ValidatedJson)
760impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
761    fn update_operation(op: &mut Operation) {
762        let (name, _) = T::schema();
763
764        let schema_ref = SchemaRef::Ref {
765            reference: format!("#/components/schemas/{}", name),
766        };
767
768        let mut content = HashMap::new();
769        content.insert(
770            "application/json".to_string(),
771            MediaType { schema: schema_ref },
772        );
773
774        op.request_body = Some(RequestBody {
775            required: true,
776            content,
777        });
778    }
779}
780
781// Path - Placeholder for path params
782impl<T> OperationModifier for Path<T> {
783    fn update_operation(_op: &mut Operation) {
784        // TODO: Implement path param extraction
785    }
786}
787
788// Query - Extracts query params using IntoParams
789impl<T: IntoParams> OperationModifier for Query<T> {
790    fn update_operation(op: &mut Operation) {
791        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
792
793        let new_params: Vec<Parameter> = params
794            .into_iter()
795            .map(|p| {
796                let schema = match p.schema {
797                    Some(schema) => match schema {
798                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
799                            reference: r.ref_location,
800                        },
801                        openapi::RefOr::T(s) => {
802                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
803                            SchemaRef::Inline(value)
804                        }
805                    },
806                    None => SchemaRef::Inline(serde_json::Value::Null),
807                };
808
809                let required = match p.required {
810                    openapi::Required::True => true,
811                    openapi::Required::False => false,
812                };
813
814                Parameter {
815                    name: p.name,
816                    location: "query".to_string(), // explicitly query
817                    required,
818                    description: p.description,
819                    schema,
820                }
821            })
822            .collect();
823
824        if let Some(existing) = &mut op.parameters {
825            existing.extend(new_params);
826        } else {
827            op.parameters = Some(new_params);
828        }
829    }
830}
831
832// State - No op
833impl<T> OperationModifier for State<T> {
834    fn update_operation(_op: &mut Operation) {}
835}
836
837// Body - Generic binary body
838impl OperationModifier for Body {
839    fn update_operation(op: &mut Operation) {
840        let mut content = HashMap::new();
841        content.insert(
842            "application/octet-stream".to_string(),
843            MediaType {
844                schema: SchemaRef::Inline(
845                    serde_json::json!({ "type": "string", "format": "binary" }),
846                ),
847            },
848        );
849
850        op.request_body = Some(RequestBody {
851            required: true,
852            content,
853        });
854    }
855}
856
857// ResponseModifier implementations for extractors
858
859// Json<T> - 200 OK with schema T
860impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
861    fn update_response(op: &mut Operation) {
862        let (name, _) = T::schema();
863
864        let schema_ref = SchemaRef::Ref {
865            reference: format!("#/components/schemas/{}", name),
866        };
867
868        op.responses.insert(
869            "200".to_string(),
870            ResponseSpec {
871                description: "Successful response".to_string(),
872                content: {
873                    let mut map = HashMap::new();
874                    map.insert(
875                        "application/json".to_string(),
876                        MediaType { schema: schema_ref },
877                    );
878                    Some(map)
879                },
880            },
881        );
882    }
883}
884
885#[cfg(test)]
886mod tests {
887    use super::*;
888    use bytes::Bytes;
889    use http::{Extensions, Method};
890    use proptest::prelude::*;
891    use proptest::test_runner::TestCaseError;
892    use std::collections::HashMap;
893    use std::sync::Arc;
894
895    /// Create a test request with the given method, path, and headers
896    fn create_test_request_with_headers(
897        method: Method,
898        path: &str,
899        headers: Vec<(&str, &str)>,
900    ) -> Request {
901        let uri: http::Uri = path.parse().unwrap();
902        let mut builder = http::Request::builder().method(method).uri(uri);
903
904        for (name, value) in headers {
905            builder = builder.header(name, value);
906        }
907
908        let req = builder.body(()).unwrap();
909        let (parts, _) = req.into_parts();
910
911        Request::new(
912            parts,
913            Bytes::new(),
914            Arc::new(Extensions::new()),
915            HashMap::new(),
916        )
917    }
918
919    /// Create a test request with extensions
920    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
921        method: Method,
922        path: &str,
923        extension: T,
924    ) -> Request {
925        let uri: http::Uri = path.parse().unwrap();
926        let builder = http::Request::builder().method(method).uri(uri);
927
928        let req = builder.body(()).unwrap();
929        let (mut parts, _) = req.into_parts();
930        parts.extensions.insert(extension);
931
932        Request::new(
933            parts,
934            Bytes::new(),
935            Arc::new(Extensions::new()),
936            HashMap::new(),
937        )
938    }
939
940    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
941    //
942    // For any request with headers H, the `Headers` extractor SHALL return a map
943    // containing all key-value pairs in H.
944    //
945    // **Validates: Requirements 5.1**
946    proptest! {
947        #![proptest_config(ProptestConfig::with_cases(100))]
948
949        #[test]
950        fn prop_headers_extractor_completeness(
951            // Generate random header names and values
952            // Using alphanumeric strings to ensure valid header names/values
953            headers in prop::collection::vec(
954                (
955                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
956                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
957                ),
958                0..10
959            )
960        ) {
961            let result: Result<(), TestCaseError> = (|| {
962                // Convert to header tuples
963                let header_tuples: Vec<(&str, &str)> = headers
964                    .iter()
965                    .map(|(k, v)| (k.as_str(), v.as_str()))
966                    .collect();
967
968                // Create request with headers
969                let request = create_test_request_with_headers(
970                    Method::GET,
971                    "/test",
972                    header_tuples.clone(),
973                );
974
975                // Extract headers
976                let extracted = Headers::from_request_parts(&request)
977                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
978
979                // Verify all original headers are present
980                // HTTP allows duplicate headers - get_all() returns all values for a header name
981                for (name, value) in &headers {
982                    // Check that the header name exists
983                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
984                    prop_assert!(
985                        !all_values.is_empty(),
986                        "Header '{}' not found",
987                        name
988                    );
989
990                    // Check that the value is among the extracted values
991                    let value_found = all_values.iter().any(|v| {
992                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
993                    });
994
995                    prop_assert!(
996                        value_found,
997                        "Header '{}' value '{}' not found in extracted values",
998                        name,
999                        value
1000                    );
1001                }
1002
1003                Ok(())
1004            })();
1005            result?;
1006        }
1007    }
1008
1009    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
1010    //
1011    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
1012    // for requests without header "X", it SHALL return an error.
1013    //
1014    // **Validates: Requirements 5.2**
1015    proptest! {
1016        #![proptest_config(ProptestConfig::with_cases(100))]
1017
1018        #[test]
1019        fn prop_header_value_extractor_correctness(
1020            header_name in "[a-z][a-z0-9-]{0,20}",
1021            header_value in "[a-zA-Z0-9 ]{1,50}",
1022            has_header in prop::bool::ANY,
1023        ) {
1024            let result: Result<(), TestCaseError> = (|| {
1025                let headers = if has_header {
1026                    vec![(header_name.as_str(), header_value.as_str())]
1027                } else {
1028                    vec![]
1029                };
1030
1031                let request = create_test_request_with_headers(Method::GET, "/test", headers);
1032
1033                // We need to use a static string for the header name in the extractor
1034                // So we'll test with a known header name
1035                let test_header = "x-test-header";
1036                let request_with_known_header = if has_header {
1037                    create_test_request_with_headers(
1038                        Method::GET,
1039                        "/test",
1040                        vec![(test_header, header_value.as_str())],
1041                    )
1042                } else {
1043                    create_test_request_with_headers(Method::GET, "/test", vec![])
1044                };
1045
1046                let result = HeaderValue::extract(&request_with_known_header, test_header);
1047
1048                if has_header {
1049                    let extracted = result
1050                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1051                    prop_assert_eq!(
1052                        extracted.value(),
1053                        header_value.as_str(),
1054                        "Header value mismatch"
1055                    );
1056                } else {
1057                    prop_assert!(
1058                        result.is_err(),
1059                        "Expected error when header is missing"
1060                    );
1061                }
1062
1063                Ok(())
1064            })();
1065            result?;
1066        }
1067    }
1068
1069    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
1070    //
1071    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
1072    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
1073    //
1074    // **Validates: Requirements 5.4**
1075    proptest! {
1076        #![proptest_config(ProptestConfig::with_cases(100))]
1077
1078        #[test]
1079        fn prop_client_ip_extractor_with_forwarding(
1080            // Generate valid IPv4 addresses
1081            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1082                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1083            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1084                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1085            has_forwarded_header in prop::bool::ANY,
1086            trust_proxy in prop::bool::ANY,
1087        ) {
1088            let result: Result<(), TestCaseError> = (|| {
1089                let headers = if has_forwarded_header {
1090                    vec![("x-forwarded-for", forwarded_ip.as_str())]
1091                } else {
1092                    vec![]
1093                };
1094
1095                // Create request with headers
1096                let uri: http::Uri = "/test".parse().unwrap();
1097                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1098                for (name, value) in &headers {
1099                    builder = builder.header(*name, *value);
1100                }
1101                let req = builder.body(()).unwrap();
1102                let (mut parts, _) = req.into_parts();
1103
1104                // Add socket address to extensions
1105                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1106                parts.extensions.insert(socket_addr);
1107
1108                let request = Request::new(
1109                    parts,
1110                    Bytes::new(),
1111                    Arc::new(Extensions::new()),
1112                    HashMap::new(),
1113                );
1114
1115                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1116                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1117
1118                if trust_proxy && has_forwarded_header {
1119                    // Should use X-Forwarded-For
1120                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1121                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1122                    prop_assert_eq!(
1123                        extracted.0,
1124                        expected_ip,
1125                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
1126                    );
1127                } else {
1128                    // Should use socket IP
1129                    prop_assert_eq!(
1130                        extracted.0,
1131                        socket_ip,
1132                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1133                    );
1134                }
1135
1136                Ok(())
1137            })();
1138            result?;
1139        }
1140    }
1141
1142    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
1143    //
1144    // For any type T and value V inserted into request extensions by middleware,
1145    // `Extension<T>` SHALL return V.
1146    //
1147    // **Validates: Requirements 5.5**
1148    proptest! {
1149        #![proptest_config(ProptestConfig::with_cases(100))]
1150
1151        #[test]
1152        fn prop_extension_extractor_retrieval(
1153            value in any::<i64>(),
1154            has_extension in prop::bool::ANY,
1155        ) {
1156            let result: Result<(), TestCaseError> = (|| {
1157                // Create a simple wrapper type for testing
1158                #[derive(Clone, Debug, PartialEq)]
1159                struct TestExtension(i64);
1160
1161                let uri: http::Uri = "/test".parse().unwrap();
1162                let builder = http::Request::builder().method(Method::GET).uri(uri);
1163                let req = builder.body(()).unwrap();
1164                let (mut parts, _) = req.into_parts();
1165
1166                if has_extension {
1167                    parts.extensions.insert(TestExtension(value));
1168                }
1169
1170                let request = Request::new(
1171                    parts,
1172                    Bytes::new(),
1173                    Arc::new(Extensions::new()),
1174                    HashMap::new(),
1175                );
1176
1177                let result = Extension::<TestExtension>::from_request_parts(&request);
1178
1179                if has_extension {
1180                    let extracted = result
1181                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1182                    prop_assert_eq!(
1183                        extracted.0,
1184                        TestExtension(value),
1185                        "Extension value mismatch"
1186                    );
1187                } else {
1188                    prop_assert!(
1189                        result.is_err(),
1190                        "Expected error when extension is missing"
1191                    );
1192                }
1193
1194                Ok(())
1195            })();
1196            result?;
1197        }
1198    }
1199
1200    // Unit tests for basic functionality
1201
// `Headers` exposes exactly the headers supplied on the request: both supplied
// names are reported as present, an unrelated name is not, and the count is two.
#[test]
fn test_headers_extractor_basic() {
    let supplied = vec![
        ("content-type", "application/json"),
        ("accept", "text/html"),
    ];
    let request = create_test_request_with_headers(Method::GET, "/test", supplied);

    let extracted = Headers::from_request_parts(&request).unwrap();

    assert_eq!(extracted.len(), 2);
    assert!(!extracted.contains("x-custom"));
    assert!(extracted.contains("accept"));
    assert!(extracted.contains("content-type"));
}
1220
// A header that is present on the request is extracted with its exact value.
#[test]
fn test_header_value_extractor_present() {
    let supplied = vec![("authorization", "Bearer token123")];
    let request = create_test_request_with_headers(Method::GET, "/test", supplied);

    let extraction = HeaderValue::extract(&request, "authorization");
    assert!(extraction.is_ok());

    let header = extraction.unwrap();
    assert_eq!(header.value(), "Bearer token123");
}
1233
// Asking for a header on a request that carries no headers yields an error.
#[test]
fn test_header_value_extractor_missing() {
    let bare_request = create_test_request_with_headers(Method::GET, "/test", vec![]);

    let extraction = HeaderValue::extract(&bare_request, "authorization");

    assert!(extraction.is_err());
}
1241
// With the trust flag set to `true`, the first entry of a comma-separated
// `x-forwarded-for` list is the address that gets reported.
#[test]
fn test_client_ip_from_forwarded_header() {
    let forwarded_chain = vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")];
    let request = create_test_request_with_headers(Method::GET, "/test", forwarded_chain);

    let client = ClientIp::extract_with_config(&request, true).unwrap();

    let first_hop = "192.168.1.100".parse::<std::net::IpAddr>().unwrap();
    assert_eq!(client.0, first_hop);
}
1253
// With the trust flag set to `false`, the `x-forwarded-for` header is ignored
// and the socket address stored in the parts' extensions is reported instead.
#[test]
fn test_client_ip_ignores_forwarded_when_not_trusted() {
    let (mut parts, _) = http::Request::builder()
        .method(Method::GET)
        .uri("/test".parse::<http::Uri>().unwrap())
        .header("x-forwarded-for", "192.168.1.100")
        .body(())
        .unwrap()
        .into_parts();

    // Peer address 10.0.0.1:8080, recorded the same way as in the other tests.
    let peer = std::net::SocketAddr::from(([10, 0, 0, 1], 8080));
    parts.extensions.insert(peer);

    let request = Request::new(
        parts,
        Bytes::new(),
        Arc::new(Extensions::new()),
        HashMap::new(),
    );

    let client = ClientIp::extract_with_config(&request, false).unwrap();

    let socket_ip = "10.0.0.1".parse::<std::net::IpAddr>().unwrap();
    assert_eq!(client.0, socket_ip);
}
1280
// A value supplied through `create_test_request_with_extensions` is returned
// unchanged by `Extension<T>`.
#[test]
fn test_extension_extractor_present() {
    #[derive(Clone, Debug, PartialEq)]
    struct MyData(String);

    let payload = MyData(String::from("hello"));
    let request = create_test_request_with_extensions(Method::GET, "/test", payload.clone());

    let extraction = Extension::<MyData>::from_request_parts(&request);
    assert!(extraction.is_ok());

    let extension = extraction.unwrap();
    assert_eq!(extension.0, payload);
}
1293
// Requesting an extension type that was never attached to the request fails.
#[test]
fn test_extension_extractor_missing() {
    #[derive(Clone, Debug)]
    struct MyData(String);

    let bare_request = create_test_request_with_headers(Method::GET, "/test", vec![]);

    let extraction = Extension::<MyData>::from_request_parts(&bare_request);

    assert!(extraction.is_err());
}
1304
1305    // Cookies tests (feature-gated)
#[cfg(feature = "cookies")]
mod cookies_tests {
    use super::*;

    // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
    //
    // For any request whose Cookie header lists cookies C, the `Cookies` extractor
    // SHALL produce a jar holding exactly the cookies in C. When a name repeats,
    // the test expects only the last value for that name to be kept.
    //
    // **Validates: Requirements 5.3**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_cookies_extractor_parsing(
            // Up to four (name, value) pairs, restricted to characters that need
            // no escaping inside a Cookie header.
            cookies in prop::collection::vec(
                (
                    "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // cookie name
                    "[a-zA-Z0-9]{1,30}"            // cookie value
                ),
                0..5
            )
        ) {
            let outcome: Result<(), TestCaseError> = (|| {
                // Render the pairs as a single `name=value; name=value` header.
                let rendered_pairs: Vec<String> = cookies
                    .iter()
                    .map(|(name, value)| format!("{}={}", name, value))
                    .collect();
                let cookie_header = rendered_pairs.join("; ");

                // No cookies generated means no Cookie header at all.
                let headers = if cookies.is_empty() {
                    vec![]
                } else {
                    vec![("cookie", cookie_header.as_str())]
                };

                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let jar = Cookies::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                // Collapse duplicates; later pairs overwrite earlier ones.
                let expected_cookies: std::collections::HashMap<&str, &str> = cookies
                    .iter()
                    .map(|(name, value)| (name.as_str(), value.as_str()))
                    .collect();

                // Each expected cookie must be in the jar with the expected value.
                for (name, expected_value) in expected_cookies.iter().map(|(k, v)| (*k, *v)) {
                    let cookie = jar.get(name)
                        .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                    prop_assert_eq!(
                        cookie.value(),
                        expected_value,
                        "Cookie '{}' value mismatch",
                        name
                    );
                }

                // The jar must not hold anything beyond the expected names.
                let extracted_count = jar.iter().count();
                prop_assert_eq!(
                    extracted_count,
                    expected_cookies.len(),
                    "Expected {} unique cookies, got {}",
                    expected_cookies.len(),
                    extracted_count
                );

                Ok(())
            })();
            outcome?;
        }
    }

    // Two cookies in one header are both found, with their values, and an
    // unrelated name is reported as absent.
    #[test]
    fn test_cookies_extractor_basic() {
        let supplied = vec![("cookie", "session=abc123; user=john")];
        let request = create_test_request_with_headers(Method::GET, "/test", supplied);

        let jar = Cookies::from_request_parts(&request).unwrap();

        assert!(jar.contains("session"));
        assert!(jar.contains("user"));
        assert!(!jar.contains("other"));

        let session = jar.get("session").unwrap();
        assert_eq!(session.value(), "abc123");
        let user = jar.get("user").unwrap();
        assert_eq!(user.value(), "john");
    }

    // A request without a Cookie header still extracts, giving an empty jar.
    #[test]
    fn test_cookies_extractor_empty() {
        let bare_request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let jar = Cookies::from_request_parts(&bare_request).unwrap();

        assert_eq!(jar.iter().count(), 0);
    }

    // A header with a single pair gives a jar holding exactly that cookie.
    #[test]
    fn test_cookies_extractor_single() {
        let supplied = vec![("cookie", "token=xyz789")];
        let request = create_test_request_with_headers(Method::GET, "/test", supplied);

        let jar = Cookies::from_request_parts(&request).unwrap();

        assert_eq!(jar.iter().count(), 1);
        let token = jar.get("token").unwrap();
        assert_eq!(token.value(), "xyz789");
    }
}
1426}