rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
9//! | Extractor | Description | Consumes Body |
10//! |-----------|-------------|---------------|
11//! | [`Json<T>`] | Parse JSON request body | Yes |
12//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
13//! | [`Query<T>`] | Parse query string parameters | No |
14//! | [`Path<T>`] | Extract path parameters | No |
15//! | [`State<T>`] | Access shared application state | No |
//! | [`Body`] | Raw request body bytes | Yes |
//! | [`BodyStream`] | Request body as a stream of chunks | Yes |
17//! | [`Headers`] | Access all request headers | No |
//! | [`HeaderValue`] | A single header, read via `HeaderValue::extract` (not a handler parameter) | No |
19//! | [`Extension<T>`] | Access middleware-injected data | No |
20//! | [`ClientIp`] | Extract client IP address | No |
21//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
51//! # Extractor Order
52//!
53//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`)
54//! must come last since they consume the request body. Non-body extractors can be
55//! in any order.
56
57use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use bytes::Bytes;
63use http::{header, StatusCode};
64use http_body_util::Full;
65use serde::de::DeserializeOwned;
66use serde::Serialize;
67use std::future::Future;
68use std::ops::{Deref, DerefMut};
69use std::str::FromStr;
70
71/// Trait for extracting data from request parts (headers, path, query)
72///
73/// This is used for extractors that don't need the request body.
74pub trait FromRequestParts: Sized {
75    /// Extract from request parts
76    fn from_request_parts(req: &Request) -> Result<Self>;
77}
78
79/// Trait for extracting data from the full request (including body)
80///
81/// This is used for extractors that consume the request body.
82pub trait FromRequest: Sized {
83    /// Extract from the full request
84    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
85}
86
87// Blanket impl: FromRequestParts -> FromRequest
88impl<T: FromRequestParts> FromRequest for T {
89    async fn from_request(req: &mut Request) -> Result<Self> {
90        T::from_request_parts(req)
91    }
92}
93
94/// JSON body extractor
95///
96/// Parses the request body as JSON and deserializes into type `T`.
97/// Also works as a response type when T: Serialize.
98///
99/// # Example
100///
101/// ```rust,ignore
102/// #[derive(Deserialize)]
103/// struct CreateUser {
104///     name: String,
105///     email: String,
106/// }
107///
108/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
109///     // body is already deserialized
110/// }
111/// ```
112#[derive(Debug, Clone, Copy, Default)]
113pub struct Json<T>(pub T);
114
115impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
116    async fn from_request(req: &mut Request) -> Result<Self> {
117        req.load_body().await?;
118        let body = req
119            .take_body()
120            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
121
122        // Use simd-json accelerated parsing when available (2-4x faster)
123        let value: T = json::from_slice(&body)?;
124        Ok(Json(value))
125    }
126}
127
128impl<T> Deref for Json<T> {
129    type Target = T;
130
131    fn deref(&self) -> &Self::Target {
132        &self.0
133    }
134}
135
136impl<T> DerefMut for Json<T> {
137    fn deref_mut(&mut self) -> &mut Self::Target {
138        &mut self.0
139    }
140}
141
142impl<T> From<T> for Json<T> {
143    fn from(value: T) -> Self {
144        Json(value)
145    }
146}
147
148/// Default pre-allocation size for JSON response buffers (256 bytes)
149/// This covers most small to medium JSON responses without reallocation.
150const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
151
152// IntoResponse for Json - allows using Json<T> as a return type
153impl<T: Serialize> IntoResponse for Json<T> {
154    fn into_response(self) -> crate::response::Response {
155        // Use pre-allocated buffer to reduce allocations
156        match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
157            Ok(body) => http::Response::builder()
158                .status(StatusCode::OK)
159                .header(header::CONTENT_TYPE, "application/json")
160                .body(Full::new(Bytes::from(body)))
161                .unwrap(),
162            Err(err) => {
163                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
164            }
165        }
166    }
167}
168
169/// Validated JSON body extractor
170///
171/// Parses the request body as JSON, deserializes into type `T`, and validates
172/// using the `Validate` trait. Returns a 422 Unprocessable Entity error with
173/// detailed field-level validation errors if validation fails.
174///
175/// # Example
176///
177/// ```rust,ignore
178/// use rustapi_rs::prelude::*;
179/// use validator::Validate;
180///
181/// #[derive(Deserialize, Validate)]
182/// struct CreateUser {
183///     #[validate(email)]
184///     email: String,
185///     #[validate(length(min = 8))]
186///     password: String,
187/// }
188///
189/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
190///     // body is already validated!
191///     // If email is invalid or password too short, a 422 error is returned automatically
192/// }
193/// ```
194#[derive(Debug, Clone, Copy, Default)]
195pub struct ValidatedJson<T>(pub T);
196
197impl<T> ValidatedJson<T> {
198    /// Create a new ValidatedJson wrapper
199    pub fn new(value: T) -> Self {
200        Self(value)
201    }
202
203    /// Get the inner value
204    pub fn into_inner(self) -> T {
205        self.0
206    }
207}
208
209impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
210    async fn from_request(req: &mut Request) -> Result<Self> {
211        req.load_body().await?;
212        // First, deserialize the JSON body using simd-json when available
213        let body = req
214            .take_body()
215            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
216
217        let value: T = json::from_slice(&body)?;
218
219        // Then, validate it
220        if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
221            // Convert validation error to API error with 422 status
222            return Err(validation_error.into());
223        }
224
225        Ok(ValidatedJson(value))
226    }
227}
228
229impl<T> Deref for ValidatedJson<T> {
230    type Target = T;
231
232    fn deref(&self) -> &Self::Target {
233        &self.0
234    }
235}
236
237impl<T> DerefMut for ValidatedJson<T> {
238    fn deref_mut(&mut self) -> &mut Self::Target {
239        &mut self.0
240    }
241}
242
243impl<T> From<T> for ValidatedJson<T> {
244    fn from(value: T) -> Self {
245        ValidatedJson(value)
246    }
247}
248
249impl<T: Serialize> IntoResponse for ValidatedJson<T> {
250    fn into_response(self) -> crate::response::Response {
251        Json(self.0).into_response()
252    }
253}
254
255/// Query string extractor
256///
257/// Parses the query string into type `T`.
258///
259/// # Example
260///
261/// ```rust,ignore
262/// #[derive(Deserialize)]
263/// struct Pagination {
264///     page: Option<u32>,
265///     limit: Option<u32>,
266/// }
267///
268/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
269///     // params.page, params.limit
270/// }
271/// ```
272#[derive(Debug, Clone)]
273pub struct Query<T>(pub T);
274
275impl<T: DeserializeOwned> FromRequestParts for Query<T> {
276    fn from_request_parts(req: &Request) -> Result<Self> {
277        let query = req.query_string().unwrap_or("");
278        let value: T = serde_urlencoded::from_str(query)
279            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
280        Ok(Query(value))
281    }
282}
283
284impl<T> Deref for Query<T> {
285    type Target = T;
286
287    fn deref(&self) -> &Self::Target {
288        &self.0
289    }
290}
291
292/// Path parameter extractor
293///
294/// Extracts path parameters defined in the route pattern.
295///
296/// # Example
297///
298/// For route `/users/{id}`:
299///
300/// ```rust,ignore
301/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
302///     // id is extracted from path
303/// }
304/// ```
305///
306/// For multiple params `/users/{user_id}/posts/{post_id}`:
307///
308/// ```rust,ignore
309/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
310///     // Both params extracted
311/// }
312/// ```
313#[derive(Debug, Clone)]
314pub struct Path<T>(pub T);
315
316impl<T: FromStr> FromRequestParts for Path<T>
317where
318    T::Err: std::fmt::Display,
319{
320    fn from_request_parts(req: &Request) -> Result<Self> {
321        let params = req.path_params();
322
323        // For single param, get the first one
324        if let Some((_, value)) = params.iter().next() {
325            let parsed = value
326                .parse::<T>()
327                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
328            return Ok(Path(parsed));
329        }
330
331        Err(ApiError::internal("Missing path parameter"))
332    }
333}
334
335impl<T> Deref for Path<T> {
336    type Target = T;
337
338    fn deref(&self) -> &Self::Target {
339        &self.0
340    }
341}
342
343/// State extractor
344///
345/// Extracts shared application state.
346///
347/// # Example
348///
349/// ```rust,ignore
350/// #[derive(Clone)]
351/// struct AppState {
352///     db: DbPool,
353/// }
354///
355/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
356///     // Use state.db
357/// }
358/// ```
359#[derive(Debug, Clone)]
360pub struct State<T>(pub T);
361
362impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
363    fn from_request_parts(req: &Request) -> Result<Self> {
364        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
365            ApiError::internal(format!(
366                "State of type `{}` not found. Did you forget to call .state()?",
367                std::any::type_name::<T>()
368            ))
369        })
370    }
371}
372
373impl<T> Deref for State<T> {
374    type Target = T;
375
376    fn deref(&self) -> &Self::Target {
377        &self.0
378    }
379}
380
381/// Raw body bytes extractor
382#[derive(Debug, Clone)]
383pub struct Body(pub Bytes);
384
385impl FromRequest for Body {
386    async fn from_request(req: &mut Request) -> Result<Self> {
387        req.load_body().await?;
388        let body = req
389            .take_body()
390            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
391        Ok(Body(body))
392    }
393}
394
395impl Deref for Body {
396    type Target = Bytes;
397
398    fn deref(&self) -> &Self::Target {
399        &self.0
400    }
401}
402
403/// Streaming body extractor
404pub struct BodyStream(pub StreamingBody);
405
406impl FromRequest for BodyStream {
407    async fn from_request(req: &mut Request) -> Result<Self> {
408        let config = StreamingConfig::default();
409
410        if let Some(stream) = req.take_stream() {
411            Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
412        } else if let Some(bytes) = req.take_body() {
413            // Handle buffered body as stream
414            let stream = futures_util::stream::once(async move { Ok(bytes) });
415            Ok(BodyStream(StreamingBody::from_stream(
416                stream,
417                config.max_body_size,
418            )))
419        } else {
420            Err(ApiError::internal("Body already consumed"))
421        }
422    }
423}
424
425impl Deref for BodyStream {
426    type Target = StreamingBody;
427
428    fn deref(&self) -> &Self::Target {
429        &self.0
430    }
431}
432
433impl DerefMut for BodyStream {
434    fn deref_mut(&mut self) -> &mut Self::Target {
435        &mut self.0
436    }
437}
438
439// Forward stream implementation
440impl futures_util::Stream for BodyStream {
441    type Item = Result<Bytes, ApiError>;
442
443    fn poll_next(
444        mut self: std::pin::Pin<&mut Self>,
445        cx: &mut std::task::Context<'_>,
446    ) -> std::task::Poll<Option<Self::Item>> {
447        std::pin::Pin::new(&mut self.0).poll_next(cx)
448    }
449}
450
451/// Optional extractor wrapper
452///
453/// Makes any extractor optional - returns None instead of error on failure.
454impl<T: FromRequestParts> FromRequestParts for Option<T> {
455    fn from_request_parts(req: &Request) -> Result<Self> {
456        Ok(T::from_request_parts(req).ok())
457    }
458}
459
460/// Headers extractor
461///
462/// Provides access to all request headers as a typed map.
463///
464/// # Example
465///
466/// ```rust,ignore
467/// use rustapi_core::extract::Headers;
468///
469/// async fn handler(headers: Headers) -> impl IntoResponse {
470///     if let Some(content_type) = headers.get("content-type") {
471///         format!("Content-Type: {:?}", content_type)
472///     } else {
473///         "No Content-Type header".to_string()
474///     }
475/// }
476/// ```
477#[derive(Debug, Clone)]
478pub struct Headers(pub http::HeaderMap);
479
480impl Headers {
481    /// Get a header value by name
482    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
483        self.0.get(name)
484    }
485
486    /// Check if a header exists
487    pub fn contains(&self, name: &str) -> bool {
488        self.0.contains_key(name)
489    }
490
491    /// Get the number of headers
492    pub fn len(&self) -> usize {
493        self.0.len()
494    }
495
496    /// Check if headers are empty
497    pub fn is_empty(&self) -> bool {
498        self.0.is_empty()
499    }
500
501    /// Iterate over all headers
502    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
503        self.0.iter()
504    }
505}
506
507impl FromRequestParts for Headers {
508    fn from_request_parts(req: &Request) -> Result<Self> {
509        Ok(Headers(req.headers().clone()))
510    }
511}
512
513impl Deref for Headers {
514    type Target = http::HeaderMap;
515
516    fn deref(&self) -> &Self::Target {
517        &self.0
518    }
519}
520
521/// Single header value extractor
522///
523/// Extracts a specific header value by name. Returns an error if the header is missing.
524///
525/// # Example
526///
527/// ```rust,ignore
528/// use rustapi_core::extract::HeaderValue;
529///
530/// async fn handler(
531///     auth: HeaderValue<{ "authorization" }>,
532/// ) -> impl IntoResponse {
533///     format!("Auth header: {}", auth.0)
534/// }
535/// ```
536///
537/// Note: Due to Rust's const generics limitations, you may need to use the
538/// `HeaderValueOf` type alias or extract headers manually using the `Headers` extractor.
539#[derive(Debug, Clone)]
540pub struct HeaderValue(pub String, pub &'static str);
541
542impl HeaderValue {
543    /// Create a new HeaderValue extractor for a specific header name
544    pub fn new(name: &'static str, value: String) -> Self {
545        Self(value, name)
546    }
547
548    /// Get the header value
549    pub fn value(&self) -> &str {
550        &self.0
551    }
552
553    /// Get the header name
554    pub fn name(&self) -> &'static str {
555        self.1
556    }
557
558    /// Extract a specific header from a request
559    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
560        req.headers()
561            .get(name)
562            .and_then(|v| v.to_str().ok())
563            .map(|s| HeaderValue(s.to_string(), name))
564            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
565    }
566}
567
568impl Deref for HeaderValue {
569    type Target = String;
570
571    fn deref(&self) -> &Self::Target {
572        &self.0
573    }
574}
575
576/// Extension extractor
577///
578/// Retrieves typed data from request extensions that was inserted by middleware.
579///
580/// # Example
581///
582/// ```rust,ignore
583/// use rustapi_core::extract::Extension;
584///
585/// // Middleware inserts user data
586/// #[derive(Clone)]
587/// struct CurrentUser { id: i64 }
588///
589/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
590///     format!("User ID: {}", user.id)
591/// }
592/// ```
593#[derive(Debug, Clone)]
594pub struct Extension<T>(pub T);
595
596impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
597    fn from_request_parts(req: &Request) -> Result<Self> {
598        req.extensions()
599            .get::<T>()
600            .cloned()
601            .map(Extension)
602            .ok_or_else(|| {
603                ApiError::internal(format!(
604                    "Extension of type `{}` not found. Did middleware insert it?",
605                    std::any::type_name::<T>()
606                ))
607            })
608    }
609}
610
611impl<T> Deref for Extension<T> {
612    type Target = T;
613
614    fn deref(&self) -> &Self::Target {
615        &self.0
616    }
617}
618
619impl<T> DerefMut for Extension<T> {
620    fn deref_mut(&mut self) -> &mut Self::Target {
621        &mut self.0
622    }
623}
624
625/// Client IP address extractor
626///
627/// Extracts the client IP address from the request. When `trust_proxy` is enabled,
628/// it will use the `X-Forwarded-For` header if present.
629///
630/// # Example
631///
632/// ```rust,ignore
633/// use rustapi_core::extract::ClientIp;
634///
635/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
636///     format!("Your IP: {}", ip)
637/// }
638/// ```
639#[derive(Debug, Clone)]
640pub struct ClientIp(pub std::net::IpAddr);
641
642impl ClientIp {
643    /// Extract client IP, optionally trusting X-Forwarded-For header
644    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
645        if trust_proxy {
646            // Try X-Forwarded-For header first
647            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
648                if let Ok(forwarded_str) = forwarded.to_str() {
649                    // X-Forwarded-For can contain multiple IPs, take the first one
650                    if let Some(first_ip) = forwarded_str.split(',').next() {
651                        if let Ok(ip) = first_ip.trim().parse() {
652                            return Ok(ClientIp(ip));
653                        }
654                    }
655                }
656            }
657        }
658
659        // Fall back to socket address from extensions (if set by server)
660        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
661            return Ok(ClientIp(addr.ip()));
662        }
663
664        // Default to localhost if no IP information available
665        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
666            127, 0, 0, 1,
667        ))))
668    }
669}
670
671impl FromRequestParts for ClientIp {
672    fn from_request_parts(req: &Request) -> Result<Self> {
673        // By default, trust proxy headers
674        Self::extract_with_config(req, true)
675    }
676}
677
/// Cookies extractor (requires the `cookies` feature).
///
/// Parses every `Cookie` request header into a [`cookie::CookieJar`].
/// A request without cookies yields an empty jar, and extraction never fails.
/// Header values that are not visible ASCII, and pairs that do not parse,
/// are skipped.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     if let Some(session) = cookies.get("session_id") {
///         format!("Session: {}", session.value())
///     } else {
///         "No session cookie".to_string()
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Looks up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterates over every parsed cookie.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Returns `true` when a cookie with the given name was sent.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // HTTP/2 clients may split cookies across several `Cookie` header
        // fields (RFC 9113 §8.2.3). Every field is therefore read, not just
        // the first one.
        let pairs = req
            .headers()
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|line| line.split(';'))
            .map(str::trim)
            .filter(|pair| !pair.is_empty());

        for pair in pairs {
            // Malformed pairs are skipped rather than failing the request.
            if let Ok(parsed) = cookie::Cookie::parse(pair.to_string()) {
                jar.add_original(parsed.into_owned());
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
748
749// Implement FromRequestParts for common primitive types (path params)
750macro_rules! impl_from_request_parts_for_primitives {
751    ($($ty:ty),*) => {
752        $(
753            impl FromRequestParts for $ty {
754                fn from_request_parts(req: &Request) -> Result<Self> {
755                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
756                    Ok(value)
757                }
758            }
759        )*
760    };
761}
762
763impl_from_request_parts_for_primitives!(
764    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
765);
766
767// OperationModifier implementations for extractors
768
769use rustapi_openapi::utoipa_types::openapi;
770use rustapi_openapi::{
771    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
772    ResponseSpec, Schema, SchemaRef,
773};
774use std::collections::HashMap;
775
776// ValidatedJson - Adds request body
777impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
778    fn update_operation(op: &mut Operation) {
779        let (name, _) = T::schema();
780
781        let schema_ref = SchemaRef::Ref {
782            reference: format!("#/components/schemas/{}", name),
783        };
784
785        let mut content = HashMap::new();
786        content.insert(
787            "application/json".to_string(),
788            MediaType { schema: schema_ref },
789        );
790
791        op.request_body = Some(RequestBody {
792            required: true,
793            content,
794        });
795
796        // Add 422 Validation Error response
797        op.responses.insert(
798            "422".to_string(),
799            ResponseSpec {
800                description: "Validation Error".to_string(),
801                content: {
802                    let mut map = HashMap::new();
803                    map.insert(
804                        "application/json".to_string(),
805                        MediaType {
806                            schema: SchemaRef::Ref {
807                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
808                            },
809                        },
810                    );
811                    Some(map)
812                },
813            },
814        );
815    }
816}
817
818// Json - Adds request body (Same as ValidatedJson)
819impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
820    fn update_operation(op: &mut Operation) {
821        let (name, _) = T::schema();
822
823        let schema_ref = SchemaRef::Ref {
824            reference: format!("#/components/schemas/{}", name),
825        };
826
827        let mut content = HashMap::new();
828        content.insert(
829            "application/json".to_string(),
830            MediaType { schema: schema_ref },
831        );
832
833        op.request_body = Some(RequestBody {
834            required: true,
835            content,
836        });
837    }
838}
839
840// Path - Path parameters are automatically extracted from route patterns
841// The add_path_params_to_operation function in app.rs handles OpenAPI documentation
842// based on the {param} syntax in route paths (e.g., "/users/{id}")
843impl<T> OperationModifier for Path<T> {
844    fn update_operation(_op: &mut Operation) {
845        // Path parameters are automatically documented by add_path_params_to_operation
846        // in app.rs based on the route pattern. No additional implementation needed here.
847        //
848        // For typed path params, the schema type defaults to "string" but will be
849        // inferred from the actual type T when more sophisticated type introspection
850        // is implemented.
851    }
852}
853
854// Query - Extracts query params using IntoParams
855impl<T: IntoParams> OperationModifier for Query<T> {
856    fn update_operation(op: &mut Operation) {
857        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
858
859        let new_params: Vec<Parameter> = params
860            .into_iter()
861            .map(|p| {
862                let schema = match p.schema {
863                    Some(schema) => match schema {
864                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
865                            reference: r.ref_location,
866                        },
867                        openapi::RefOr::T(s) => {
868                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
869                            SchemaRef::Inline(value)
870                        }
871                    },
872                    None => SchemaRef::Inline(serde_json::Value::Null),
873                };
874
875                let required = match p.required {
876                    openapi::Required::True => true,
877                    openapi::Required::False => false,
878                };
879
880                Parameter {
881                    name: p.name,
882                    location: "query".to_string(), // explicitly query
883                    required,
884                    description: p.description,
885                    schema,
886                }
887            })
888            .collect();
889
890        if let Some(existing) = &mut op.parameters {
891            existing.extend(new_params);
892        } else {
893            op.parameters = Some(new_params);
894        }
895    }
896}
897
898// State - No op
899impl<T> OperationModifier for State<T> {
900    fn update_operation(_op: &mut Operation) {}
901}
902
903// Body - Generic binary body
904impl OperationModifier for Body {
905    fn update_operation(op: &mut Operation) {
906        let mut content = HashMap::new();
907        content.insert(
908            "application/octet-stream".to_string(),
909            MediaType {
910                schema: SchemaRef::Inline(
911                    serde_json::json!({ "type": "string", "format": "binary" }),
912                ),
913            },
914        );
915
916        op.request_body = Some(RequestBody {
917            required: true,
918            content,
919        });
920    }
921}
922
923// BodyStream - Generic binary stream (Same as Body)
924impl OperationModifier for BodyStream {
925    fn update_operation(op: &mut Operation) {
926        let mut content = HashMap::new();
927        content.insert(
928            "application/octet-stream".to_string(),
929            MediaType {
930                schema: SchemaRef::Inline(
931                    serde_json::json!({ "type": "string", "format": "binary" }),
932                ),
933            },
934        );
935
936        op.request_body = Some(RequestBody {
937            required: true,
938            content,
939        });
940    }
941}
942
943// ResponseModifier implementations for extractors
944
945// Json<T> - 200 OK with schema T
946impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
947    fn update_response(op: &mut Operation) {
948        let (name, _) = T::schema();
949
950        let schema_ref = SchemaRef::Ref {
951            reference: format!("#/components/schemas/{}", name),
952        };
953
954        op.responses.insert(
955            "200".to_string(),
956            ResponseSpec {
957                description: "Successful response".to_string(),
958                content: {
959                    let mut map = HashMap::new();
960                    map.insert(
961                        "application/json".to_string(),
962                        MediaType { schema: schema_ref },
963                    );
964                    Some(map)
965                },
966            },
967        );
968    }
969}
970
// Unit and property-based tests for the extractors defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::Arc;

    /// Build bare `http` request parts for `method` + `path`, adding every
    /// `(name, value)` pair in `headers` as a request header (repeated names are
    /// appended, not replaced).
    ///
    /// Panics if `path` is not a valid URI or a header name/value is invalid; the
    /// tests only pass well-formed inputs.
    fn build_test_parts(
        method: Method,
        path: &str,
        headers: &[(&str, &str)],
    ) -> http::request::Parts {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (parts, _) = builder.body(()).unwrap().into_parts();
        parts
    }

    /// Wrap `parts` into a RustAPI [`Request`] with an empty buffered body, a fresh
    /// empty `Arc<Extensions>`, and no path parameters.
    fn request_from_parts(parts: http::request::Parts) -> Request {
        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Create a test request with the given method, path, and headers.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        request_from_parts(build_test_parts(method, path, &headers))
    }

    /// Create a header-less test request whose `http` extensions contain `extension`.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let mut parts = build_test_parts(method, path, &[]);
        parts.extensions.insert(extension);
        request_from_parts(parts)
    }

    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
    //
    // For any request with headers H, the `Headers` extractor SHALL return a map
    // containing all key-value pairs in H.
    //
    // **Validates: Requirements 5.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            // Generate random header names and values
            // Using alphanumeric strings to ensure valid header names/values
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Borrow the generated pairs as `(&str, &str)` header tuples
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                // `header_tuples` is not needed afterwards, so it is moved in directly
                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples,
                );

                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                // Verify all original headers are present.
                // HTTP allows duplicate headers - get_all() returns all values for a header name
                for (name, value) in &headers {
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    // The generated value must be one of the values stored under that name
                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
    //
    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
    // for requests without header "X", it SHALL return an error.
    //
    // **Validates: Requirements 5.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // The extractor is exercised with static header names; leak the
                // generated name so it can be passed wherever a `&'static str` is
                // required. This is one tiny allocation per test case, test-only.
                let name: &'static str = Box::leak(header_name.clone().into_boxed_str());

                let headers = if has_header {
                    vec![(name, header_value.as_str())]
                } else {
                    vec![]
                };
                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let result = HeaderValue::extract(&request, name);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
    //
    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
    //
    // **Validates: Requirements 5.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            // Generate valid IPv4 addresses
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers: Vec<(&str, &str)> = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // The socket address lives in the `http` extensions of the request parts
                let mut parts = build_test_parts(Method::GET, "/test", &headers);
                parts.extensions.insert(std::net::SocketAddr::new(socket_ip, 8080));
                let request = request_from_parts(parts);

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    // Should use X-Forwarded-For
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    // Should use socket IP
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
    //
    // For any type T and value V inserted into request extensions by middleware,
    // `Extension<T>` SHALL return V.
    //
    // **Validates: Requirements 5.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Create a simple wrapper type for testing
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let mut parts = build_test_parts(Method::GET, "/test", &[]);
                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }
                let request = request_from_parts(parts);

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // Unit tests for basic functionality

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let header = HeaderValue::extract(&request, "authorization")
            .expect("authorization header should be extracted");
        assert_eq!(header.value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let mut parts =
            build_test_parts(Method::GET, "/test", &[("x-forwarded-for", "192.168.1.100")]);
        parts.extensions.insert(std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        ));
        let request = request_from_parts(parts);

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let extension = Extension::<MyData>::from_request_parts(&request)
            .expect("extension inserted into the request should be extracted");
        assert_eq!(extension.0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        // Only used as a type parameter; never constructed.
        #[derive(Clone, Debug)]
        #[allow(dead_code)]
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    // Cookies tests (feature-gated)
    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
        //
        // For any request with Cookie header containing cookies C, the `Cookies` extractor
        // SHALL return a CookieJar containing exactly the cookies in C.
        // Note: Duplicate cookie names result in only the last value being kept.
        //
        // **Validates: Requirements 5.3**
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                // Generate random cookie names and values
                // Using alphanumeric strings to ensure valid cookie names/values
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // Valid cookie name pattern
                        "[a-zA-Z0-9]{1,30}"            // Valid cookie value pattern (no special chars)
                    ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    // Build cookie header string
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    // Extract cookies
                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    // Build expected cookies map - last value wins for duplicate names
                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    // Verify all expected cookies are present with correct values
                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    // Count cookies in jar should match unique cookie names
                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}