rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
9//! | Extractor | Description | Consumes Body |
10//! |-----------|-------------|---------------|
11//! | [`Json<T>`] | Parse JSON request body | Yes |
12//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
13//! | [`Query<T>`] | Parse query string parameters | No |
14//! | [`Path<T>`] | Extract path parameters | No |
15//! | [`State<T>`] | Access shared application state | No |
16//! | [`Body`] | Raw request body bytes | Yes |
17//! | [`Headers`] | Access all request headers | No |
//! | [`HeaderValue`] | Extract a specific header by name (via `HeaderValue::extract`) | No |
19//! | [`Extension<T>`] | Access middleware-injected data | No |
20//! | [`ClientIp`] | Extract client IP address | No |
21//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
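//! # Extractor Order
//!
//! When using multiple extractors, the body-consuming extractor (such as `Json` or `Body`)
//! must come last, because it takes the request body out of the request. Only one
//! body-consuming extractor can succeed per request: a second one fails with a
//! "Body already consumed" error. Non-body extractors can appear in any order.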
56
57use crate::error::{ApiError, Result};
58use crate::request::Request;
59use crate::response::IntoResponse;
60use bytes::Bytes;
61use http::{header, StatusCode};
62use http_body_util::Full;
63use serde::de::DeserializeOwned;
64use serde::Serialize;
65use std::future::Future;
66use std::ops::{Deref, DerefMut};
67use std::str::FromStr;
68
69/// Trait for extracting data from request parts (headers, path, query)
70///
71/// This is used for extractors that don't need the request body.
72pub trait FromRequestParts: Sized {
73    /// Extract from request parts
74    fn from_request_parts(req: &Request) -> Result<Self>;
75}
76
77/// Trait for extracting data from the full request (including body)
78///
79/// This is used for extractors that consume the request body.
80pub trait FromRequest: Sized {
81    /// Extract from the full request
82    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
83}
84
85// Blanket impl: FromRequestParts -> FromRequest
86impl<T: FromRequestParts> FromRequest for T {
87    async fn from_request(req: &mut Request) -> Result<Self> {
88        T::from_request_parts(req)
89    }
90}
91
92/// JSON body extractor
93///
94/// Parses the request body as JSON and deserializes into type `T`.
95/// Also works as a response type when T: Serialize.
96///
97/// # Example
98///
99/// ```rust,ignore
100/// #[derive(Deserialize)]
101/// struct CreateUser {
102///     name: String,
103///     email: String,
104/// }
105///
106/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
107///     // body is already deserialized
108/// }
109/// ```
110#[derive(Debug, Clone, Copy, Default)]
111pub struct Json<T>(pub T);
112
113impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
114    async fn from_request(req: &mut Request) -> Result<Self> {
115        let body = req
116            .take_body()
117            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
118
119        let value: T = serde_json::from_slice(&body)?;
120        Ok(Json(value))
121    }
122}
123
124impl<T> Deref for Json<T> {
125    type Target = T;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132impl<T> DerefMut for Json<T> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.0
135    }
136}
137
138impl<T> From<T> for Json<T> {
139    fn from(value: T) -> Self {
140        Json(value)
141    }
142}
143
144// IntoResponse for Json - allows using Json<T> as a return type
145impl<T: Serialize> IntoResponse for Json<T> {
146    fn into_response(self) -> crate::response::Response {
147        match serde_json::to_vec(&self.0) {
148            Ok(body) => http::Response::builder()
149                .status(StatusCode::OK)
150                .header(header::CONTENT_TYPE, "application/json")
151                .body(Full::new(Bytes::from(body)))
152                .unwrap(),
153            Err(err) => {
154                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
155            }
156        }
157    }
158}
159
160/// Validated JSON body extractor
161///
162/// Parses the request body as JSON, deserializes into type `T`, and validates
163/// using the `Validate` trait. Returns a 422 Unprocessable Entity error with
164/// detailed field-level validation errors if validation fails.
165///
166/// # Example
167///
168/// ```rust,ignore
169/// use rustapi_rs::prelude::*;
170/// use validator::Validate;
171///
172/// #[derive(Deserialize, Validate)]
173/// struct CreateUser {
174///     #[validate(email)]
175///     email: String,
176///     #[validate(length(min = 8))]
177///     password: String,
178/// }
179///
180/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
181///     // body is already validated!
182///     // If email is invalid or password too short, a 422 error is returned automatically
183/// }
184/// ```
185#[derive(Debug, Clone, Copy, Default)]
186pub struct ValidatedJson<T>(pub T);
187
188impl<T> ValidatedJson<T> {
189    /// Create a new ValidatedJson wrapper
190    pub fn new(value: T) -> Self {
191        Self(value)
192    }
193
194    /// Get the inner value
195    pub fn into_inner(self) -> T {
196        self.0
197    }
198}
199
200impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
201    async fn from_request(req: &mut Request) -> Result<Self> {
202        let body = req
203            .take_body()
204            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
205
206        let value: T = serde_json::from_slice(&body)?;
207
208        // Then, validate it
209        if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
210            // Convert validation error to API error with 422 status
211            return Err(validation_error.into());
212        }
213
214        Ok(ValidatedJson(value))
215    }
216}
217
218impl<T> Deref for ValidatedJson<T> {
219    type Target = T;
220
221    fn deref(&self) -> &Self::Target {
222        &self.0
223    }
224}
225
226impl<T> DerefMut for ValidatedJson<T> {
227    fn deref_mut(&mut self) -> &mut Self::Target {
228        &mut self.0
229    }
230}
231
232impl<T> From<T> for ValidatedJson<T> {
233    fn from(value: T) -> Self {
234        ValidatedJson(value)
235    }
236}
237
238impl<T: Serialize> IntoResponse for ValidatedJson<T> {
239    fn into_response(self) -> crate::response::Response {
240        Json(self.0).into_response()
241    }
242}
243
244/// Query string extractor
245///
246/// Parses the query string into type `T`.
247///
248/// # Example
249///
250/// ```rust,ignore
251/// #[derive(Deserialize)]
252/// struct Pagination {
253///     page: Option<u32>,
254///     limit: Option<u32>,
255/// }
256///
257/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
258///     // params.page, params.limit
259/// }
260/// ```
261#[derive(Debug, Clone)]
262pub struct Query<T>(pub T);
263
264impl<T: DeserializeOwned> FromRequestParts for Query<T> {
265    fn from_request_parts(req: &Request) -> Result<Self> {
266        let query = req.query_string().unwrap_or("");
267        let value: T = serde_urlencoded::from_str(query)
268            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
269        Ok(Query(value))
270    }
271}
272
273impl<T> Deref for Query<T> {
274    type Target = T;
275
276    fn deref(&self) -> &Self::Target {
277        &self.0
278    }
279}
280
281/// Path parameter extractor
282///
283/// Extracts path parameters defined in the route pattern.
284///
285/// # Example
286///
287/// For route `/users/{id}`:
288///
289/// ```rust,ignore
290/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
291///     // id is extracted from path
292/// }
293/// ```
294///
295/// For multiple params `/users/{user_id}/posts/{post_id}`:
296///
297/// ```rust,ignore
298/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
299///     // Both params extracted
300/// }
301/// ```
302#[derive(Debug, Clone)]
303pub struct Path<T>(pub T);
304
305impl<T: FromStr> FromRequestParts for Path<T>
306where
307    T::Err: std::fmt::Display,
308{
309    fn from_request_parts(req: &Request) -> Result<Self> {
310        let params = req.path_params();
311
312        // For single param, get the first one
313        if let Some((_, value)) = params.iter().next() {
314            let parsed = value
315                .parse::<T>()
316                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
317            return Ok(Path(parsed));
318        }
319
320        Err(ApiError::internal("Missing path parameter"))
321    }
322}
323
324impl<T> Deref for Path<T> {
325    type Target = T;
326
327    fn deref(&self) -> &Self::Target {
328        &self.0
329    }
330}
331
332/// State extractor
333///
334/// Extracts shared application state.
335///
336/// # Example
337///
338/// ```rust,ignore
339/// #[derive(Clone)]
340/// struct AppState {
341///     db: DbPool,
342/// }
343///
344/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
345///     // Use state.db
346/// }
347/// ```
348#[derive(Debug, Clone)]
349pub struct State<T>(pub T);
350
351impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
352    fn from_request_parts(req: &Request) -> Result<Self> {
353        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
354            ApiError::internal(format!(
355                "State of type `{}` not found. Did you forget to call .state()?",
356                std::any::type_name::<T>()
357            ))
358        })
359    }
360}
361
362impl<T> Deref for State<T> {
363    type Target = T;
364
365    fn deref(&self) -> &Self::Target {
366        &self.0
367    }
368}
369
370/// Raw body bytes extractor
371#[derive(Debug, Clone)]
372pub struct Body(pub Bytes);
373
374impl FromRequest for Body {
375    async fn from_request(req: &mut Request) -> Result<Self> {
376        let body = req
377            .take_body()
378            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
379        Ok(Body(body))
380    }
381}
382
383impl Deref for Body {
384    type Target = Bytes;
385
386    fn deref(&self) -> &Self::Target {
387        &self.0
388    }
389}
390
391/// Optional extractor wrapper
392///
393/// Makes any extractor optional - returns None instead of error on failure.
394impl<T: FromRequestParts> FromRequestParts for Option<T> {
395    fn from_request_parts(req: &Request) -> Result<Self> {
396        Ok(T::from_request_parts(req).ok())
397    }
398}
399
400/// Headers extractor
401///
402/// Provides access to all request headers as a typed map.
403///
404/// # Example
405///
406/// ```rust,ignore
407/// use rustapi_core::extract::Headers;
408///
409/// async fn handler(headers: Headers) -> impl IntoResponse {
410///     if let Some(content_type) = headers.get("content-type") {
411///         format!("Content-Type: {:?}", content_type)
412///     } else {
413///         "No Content-Type header".to_string()
414///     }
415/// }
416/// ```
417#[derive(Debug, Clone)]
418pub struct Headers(pub http::HeaderMap);
419
420impl Headers {
421    /// Get a header value by name
422    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
423        self.0.get(name)
424    }
425
426    /// Check if a header exists
427    pub fn contains(&self, name: &str) -> bool {
428        self.0.contains_key(name)
429    }
430
431    /// Get the number of headers
432    pub fn len(&self) -> usize {
433        self.0.len()
434    }
435
436    /// Check if headers are empty
437    pub fn is_empty(&self) -> bool {
438        self.0.is_empty()
439    }
440
441    /// Iterate over all headers
442    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
443        self.0.iter()
444    }
445}
446
447impl FromRequestParts for Headers {
448    fn from_request_parts(req: &Request) -> Result<Self> {
449        Ok(Headers(req.headers().clone()))
450    }
451}
452
453impl Deref for Headers {
454    type Target = http::HeaderMap;
455
456    fn deref(&self) -> &Self::Target {
457        &self.0
458    }
459}
460
461/// Single header value extractor
462///
463/// Extracts a specific header value by name. Returns an error if the header is missing.
464///
465/// # Example
466///
467/// ```rust,ignore
468/// use rustapi_core::extract::HeaderValue;
469///
470/// async fn handler(
471///     auth: HeaderValue<{ "authorization" }>,
472/// ) -> impl IntoResponse {
473///     format!("Auth header: {}", auth.0)
474/// }
475/// ```
476///
477/// Note: Due to Rust's const generics limitations, you may need to use the
478/// `HeaderValueOf` type alias or extract headers manually using the `Headers` extractor.
479#[derive(Debug, Clone)]
480pub struct HeaderValue(pub String, pub &'static str);
481
482impl HeaderValue {
483    /// Create a new HeaderValue extractor for a specific header name
484    pub fn new(name: &'static str, value: String) -> Self {
485        Self(value, name)
486    }
487
488    /// Get the header value
489    pub fn value(&self) -> &str {
490        &self.0
491    }
492
493    /// Get the header name
494    pub fn name(&self) -> &'static str {
495        self.1
496    }
497
498    /// Extract a specific header from a request
499    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
500        req.headers()
501            .get(name)
502            .and_then(|v| v.to_str().ok())
503            .map(|s| HeaderValue(s.to_string(), name))
504            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
505    }
506}
507
508impl Deref for HeaderValue {
509    type Target = String;
510
511    fn deref(&self) -> &Self::Target {
512        &self.0
513    }
514}
515
516/// Extension extractor
517///
518/// Retrieves typed data from request extensions that was inserted by middleware.
519///
520/// # Example
521///
522/// ```rust,ignore
523/// use rustapi_core::extract::Extension;
524///
525/// // Middleware inserts user data
526/// #[derive(Clone)]
527/// struct CurrentUser { id: i64 }
528///
529/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
530///     format!("User ID: {}", user.id)
531/// }
532/// ```
533#[derive(Debug, Clone)]
534pub struct Extension<T>(pub T);
535
536impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
537    fn from_request_parts(req: &Request) -> Result<Self> {
538        req.extensions()
539            .get::<T>()
540            .cloned()
541            .map(Extension)
542            .ok_or_else(|| {
543                ApiError::internal(format!(
544                    "Extension of type `{}` not found. Did middleware insert it?",
545                    std::any::type_name::<T>()
546                ))
547            })
548    }
549}
550
551impl<T> Deref for Extension<T> {
552    type Target = T;
553
554    fn deref(&self) -> &Self::Target {
555        &self.0
556    }
557}
558
559impl<T> DerefMut for Extension<T> {
560    fn deref_mut(&mut self) -> &mut Self::Target {
561        &mut self.0
562    }
563}
564
565/// Client IP address extractor
566///
567/// Extracts the client IP address from the request. When `trust_proxy` is enabled,
568/// it will use the `X-Forwarded-For` header if present.
569///
570/// # Example
571///
572/// ```rust,ignore
573/// use rustapi_core::extract::ClientIp;
574///
575/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
576///     format!("Your IP: {}", ip)
577/// }
578/// ```
579#[derive(Debug, Clone)]
580pub struct ClientIp(pub std::net::IpAddr);
581
582impl ClientIp {
583    /// Extract client IP, optionally trusting X-Forwarded-For header
584    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
585        if trust_proxy {
586            // Try X-Forwarded-For header first
587            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
588                if let Ok(forwarded_str) = forwarded.to_str() {
589                    // X-Forwarded-For can contain multiple IPs, take the first one
590                    if let Some(first_ip) = forwarded_str.split(',').next() {
591                        if let Ok(ip) = first_ip.trim().parse() {
592                            return Ok(ClientIp(ip));
593                        }
594                    }
595                }
596            }
597        }
598
599        // Fall back to socket address from extensions (if set by server)
600        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
601            return Ok(ClientIp(addr.ip()));
602        }
603
604        // Default to localhost if no IP information available
605        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
606            127, 0, 0, 1,
607        ))))
608    }
609}
610
611impl FromRequestParts for ClientIp {
612    fn from_request_parts(req: &Request) -> Result<Self> {
613        // By default, trust proxy headers
614        Self::extract_with_config(req, true)
615    }
616}
617
/// Cookies extractor
///
/// Parses the request's `Cookie` header(s) into a [`cookie::CookieJar`]. Every `Cookie`
/// header field is read, not just the first, because HTTP/2 allows the header to be
/// split across several fields.
///
/// The following are skipped silently, so extraction itself never fails:
/// - header values that are not visible ASCII,
/// - empty segments,
/// - `name=value` pairs that fail to parse.
///
/// If a name appears more than once, the later occurrence replaces the earlier one in the
/// jar (`CookieJar::add_original` semantics).
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     if let Some(session) = cookies.get("session_id") {
///         format!("Session: {}", session.value())
///     } else {
///         "No session cookie".to_string()
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Get a cookie by name
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterate over all cookies
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Check if a cookie exists
    pub fn contains(&self, name: &str) -> bool {
        self.0.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // HTTP/2 (RFC 9113, section 8.2.3) may split cookies across several `cookie`
        // header fields, so walk all of them rather than only the first.
        for pair in req
            .headers()
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|list| list.split(';'))
            .map(str::trim)
            .filter(|pair| !pair.is_empty())
        {
            if let Ok(cookie) = cookie::Cookie::parse(pair.to_owned()) {
                jar.add_original(cookie.into_owned());
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
688
689// Implement FromRequestParts for common primitive types (path params)
690macro_rules! impl_from_request_parts_for_primitives {
691    ($($ty:ty),*) => {
692        $(
693            impl FromRequestParts for $ty {
694                fn from_request_parts(req: &Request) -> Result<Self> {
695                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
696                    Ok(value)
697                }
698            }
699        )*
700    };
701}
702
703impl_from_request_parts_for_primitives!(
704    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
705);
706
707// OperationModifier implementations for extractors
708
709use rustapi_openapi::utoipa_types::openapi;
710use rustapi_openapi::{
711    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
712    ResponseSpec, Schema, SchemaRef,
713};
714use std::collections::HashMap;
715
716// ValidatedJson - Adds request body
717impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
718    fn update_operation(op: &mut Operation) {
719        let (name, _) = T::schema();
720
721        let schema_ref = SchemaRef::Ref {
722            reference: format!("#/components/schemas/{}", name),
723        };
724
725        let mut content = HashMap::new();
726        content.insert(
727            "application/json".to_string(),
728            MediaType { schema: schema_ref },
729        );
730
731        op.request_body = Some(RequestBody {
732            required: true,
733            content,
734        });
735
736        // Add 422 Validation Error response
737        op.responses.insert(
738            "422".to_string(),
739            ResponseSpec {
740                description: "Validation Error".to_string(),
741                content: {
742                    let mut map = HashMap::new();
743                    map.insert(
744                        "application/json".to_string(),
745                        MediaType {
746                            schema: SchemaRef::Ref {
747                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
748                            },
749                        },
750                    );
751                    Some(map)
752                },
753            },
754        );
755    }
756}
757
758// Json - Adds request body (Same as ValidatedJson)
759impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
760    fn update_operation(op: &mut Operation) {
761        let (name, _) = T::schema();
762
763        let schema_ref = SchemaRef::Ref {
764            reference: format!("#/components/schemas/{}", name),
765        };
766
767        let mut content = HashMap::new();
768        content.insert(
769            "application/json".to_string(),
770            MediaType { schema: schema_ref },
771        );
772
773        op.request_body = Some(RequestBody {
774            required: true,
775            content,
776        });
777    }
778}
779
780// Path - Path parameters are automatically extracted from route patterns
781// The add_path_params_to_operation function in app.rs handles OpenAPI documentation
782// based on the {param} syntax in route paths (e.g., "/users/{id}")
783impl<T> OperationModifier for Path<T> {
784    fn update_operation(_op: &mut Operation) {
785        // Path parameters are automatically documented by add_path_params_to_operation
786        // in app.rs based on the route pattern. No additional implementation needed here.
787        //
788        // For typed path params, the schema type defaults to "string" but will be
789        // inferred from the actual type T when more sophisticated type introspection
790        // is implemented.
791    }
792}
793
794// Query - Extracts query params using IntoParams
795impl<T: IntoParams> OperationModifier for Query<T> {
796    fn update_operation(op: &mut Operation) {
797        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
798
799        let new_params: Vec<Parameter> = params
800            .into_iter()
801            .map(|p| {
802                let schema = match p.schema {
803                    Some(schema) => match schema {
804                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
805                            reference: r.ref_location,
806                        },
807                        openapi::RefOr::T(s) => {
808                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
809                            SchemaRef::Inline(value)
810                        }
811                    },
812                    None => SchemaRef::Inline(serde_json::Value::Null),
813                };
814
815                let required = match p.required {
816                    openapi::Required::True => true,
817                    openapi::Required::False => false,
818                };
819
820                Parameter {
821                    name: p.name,
822                    location: "query".to_string(), // explicitly query
823                    required,
824                    description: p.description,
825                    schema,
826                }
827            })
828            .collect();
829
830        if let Some(existing) = &mut op.parameters {
831            existing.extend(new_params);
832        } else {
833            op.parameters = Some(new_params);
834        }
835    }
836}
837
838// State - No op
839impl<T> OperationModifier for State<T> {
840    fn update_operation(_op: &mut Operation) {}
841}
842
843// Body - Generic binary body
844impl OperationModifier for Body {
845    fn update_operation(op: &mut Operation) {
846        let mut content = HashMap::new();
847        content.insert(
848            "application/octet-stream".to_string(),
849            MediaType {
850                schema: SchemaRef::Inline(
851                    serde_json::json!({ "type": "string", "format": "binary" }),
852                ),
853            },
854        );
855
856        op.request_body = Some(RequestBody {
857            required: true,
858            content,
859        });
860    }
861}
862
863// ResponseModifier implementations for extractors
864
865// Json<T> - 200 OK with schema T
866impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
867    fn update_response(op: &mut Operation) {
868        let (name, _) = T::schema();
869
870        let schema_ref = SchemaRef::Ref {
871            reference: format!("#/components/schemas/{}", name),
872        };
873
874        op.responses.insert(
875            "200".to_string(),
876            ResponseSpec {
877                description: "Successful response".to_string(),
878                content: {
879                    let mut map = HashMap::new();
880                    map.insert(
881                        "application/json".to_string(),
882                        MediaType { schema: schema_ref },
883                    );
884                    Some(map)
885                },
886            },
887        );
888    }
889}
890
891#[cfg(test)]
892mod tests {
893    use super::*;
894    use bytes::Bytes;
895    use http::{Extensions, Method};
896    use proptest::prelude::*;
897    use proptest::test_runner::TestCaseError;
898    use std::collections::HashMap;
899    use std::sync::Arc;
900
901    /// Create a test request with the given method, path, and headers
902    fn create_test_request_with_headers(
903        method: Method,
904        path: &str,
905        headers: Vec<(&str, &str)>,
906    ) -> Request {
907        let uri: http::Uri = path.parse().unwrap();
908        let mut builder = http::Request::builder().method(method).uri(uri);
909
910        for (name, value) in headers {
911            builder = builder.header(name, value);
912        }
913
914        let req = builder.body(()).unwrap();
915        let (parts, _) = req.into_parts();
916
917        Request::new(
918            parts,
919            Bytes::new(),
920            Arc::new(Extensions::new()),
921            HashMap::new(),
922        )
923    }
924
925    /// Create a test request with extensions
926    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
927        method: Method,
928        path: &str,
929        extension: T,
930    ) -> Request {
931        let uri: http::Uri = path.parse().unwrap();
932        let builder = http::Request::builder().method(method).uri(uri);
933
934        let req = builder.body(()).unwrap();
935        let (mut parts, _) = req.into_parts();
936        parts.extensions.insert(extension);
937
938        Request::new(
939            parts,
940            Bytes::new(),
941            Arc::new(Extensions::new()),
942            HashMap::new(),
943        )
944    }
945
946    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
947    //
948    // For any request with headers H, the `Headers` extractor SHALL return a map
949    // containing all key-value pairs in H.
950    //
951    // **Validates: Requirements 5.1**
952    proptest! {
953        #![proptest_config(ProptestConfig::with_cases(100))]
954
955        #[test]
956        fn prop_headers_extractor_completeness(
957            // Generate random header names and values
958            // Using alphanumeric strings to ensure valid header names/values
959            headers in prop::collection::vec(
960                (
961                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
962                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
963                ),
964                0..10
965            )
966        ) {
967            let result: Result<(), TestCaseError> = (|| {
968                // Convert to header tuples
969                let header_tuples: Vec<(&str, &str)> = headers
970                    .iter()
971                    .map(|(k, v)| (k.as_str(), v.as_str()))
972                    .collect();
973
974                // Create request with headers
975                let request = create_test_request_with_headers(
976                    Method::GET,
977                    "/test",
978                    header_tuples.clone(),
979                );
980
981                // Extract headers
982                let extracted = Headers::from_request_parts(&request)
983                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
984
985                // Verify all original headers are present
986                // HTTP allows duplicate headers - get_all() returns all values for a header name
987                for (name, value) in &headers {
988                    // Check that the header name exists
989                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
990                    prop_assert!(
991                        !all_values.is_empty(),
992                        "Header '{}' not found",
993                        name
994                    );
995
996                    // Check that the value is among the extracted values
997                    let value_found = all_values.iter().any(|v| {
998                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
999                    });
1000
1001                    prop_assert!(
1002                        value_found,
1003                        "Header '{}' value '{}' not found in extracted values",
1004                        name,
1005                        value
1006                    );
1007                }
1008
1009                Ok(())
1010            })();
1011            result?;
1012        }
1013    }
1014
1015    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
1016    //
1017    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
1018    // for requests without header "X", it SHALL return an error.
1019    //
1020    // **Validates: Requirements 5.2**
1021    proptest! {
1022        #![proptest_config(ProptestConfig::with_cases(100))]
1023
1024        #[test]
1025        fn prop_header_value_extractor_correctness(
1026            header_name in "[a-z][a-z0-9-]{0,20}",
1027            header_value in "[a-zA-Z0-9 ]{1,50}",
1028            has_header in prop::bool::ANY,
1029        ) {
1030            let result: Result<(), TestCaseError> = (|| {
1031                let headers = if has_header {
1032                    vec![(header_name.as_str(), header_value.as_str())]
1033                } else {
1034                    vec![]
1035                };
1036
1037                let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1038
1039                // We need to use a static string for the header name in the extractor
1040                // So we'll test with a known header name
1041                let test_header = "x-test-header";
1042                let request_with_known_header = if has_header {
1043                    create_test_request_with_headers(
1044                        Method::GET,
1045                        "/test",
1046                        vec![(test_header, header_value.as_str())],
1047                    )
1048                } else {
1049                    create_test_request_with_headers(Method::GET, "/test", vec![])
1050                };
1051
1052                let result = HeaderValue::extract(&request_with_known_header, test_header);
1053
1054                if has_header {
1055                    let extracted = result
1056                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1057                    prop_assert_eq!(
1058                        extracted.value(),
1059                        header_value.as_str(),
1060                        "Header value mismatch"
1061                    );
1062                } else {
1063                    prop_assert!(
1064                        result.is_err(),
1065                        "Expected error when header is missing"
1066                    );
1067                }
1068
1069                Ok(())
1070            })();
1071            result?;
1072        }
1073    }
1074
1075    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
1076    //
1077    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
1078    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
1079    //
1080    // **Validates: Requirements 5.4**
1081    proptest! {
1082        #![proptest_config(ProptestConfig::with_cases(100))]
1083
1084        #[test]
1085        fn prop_client_ip_extractor_with_forwarding(
1086            // Generate valid IPv4 addresses
1087            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1088                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1089            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1090                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1091            has_forwarded_header in prop::bool::ANY,
1092            trust_proxy in prop::bool::ANY,
1093        ) {
1094            let result: Result<(), TestCaseError> = (|| {
1095                let headers = if has_forwarded_header {
1096                    vec![("x-forwarded-for", forwarded_ip.as_str())]
1097                } else {
1098                    vec![]
1099                };
1100
1101                // Create request with headers
1102                let uri: http::Uri = "/test".parse().unwrap();
1103                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1104                for (name, value) in &headers {
1105                    builder = builder.header(*name, *value);
1106                }
1107                let req = builder.body(()).unwrap();
1108                let (mut parts, _) = req.into_parts();
1109
1110                // Add socket address to extensions
1111                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1112                parts.extensions.insert(socket_addr);
1113
1114                let request = Request::new(
1115                    parts,
1116                    Bytes::new(),
1117                    Arc::new(Extensions::new()),
1118                    HashMap::new(),
1119                );
1120
1121                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1122                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1123
1124                if trust_proxy && has_forwarded_header {
1125                    // Should use X-Forwarded-For
1126                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1127                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1128                    prop_assert_eq!(
1129                        extracted.0,
1130                        expected_ip,
1131                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
1132                    );
1133                } else {
1134                    // Should use socket IP
1135                    prop_assert_eq!(
1136                        extracted.0,
1137                        socket_ip,
1138                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1139                    );
1140                }
1141
1142                Ok(())
1143            })();
1144            result?;
1145        }
1146    }
1147
1148    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
1149    //
1150    // For any type T and value V inserted into request extensions by middleware,
1151    // `Extension<T>` SHALL return V.
1152    //
1153    // **Validates: Requirements 5.5**
1154    proptest! {
1155        #![proptest_config(ProptestConfig::with_cases(100))]
1156
1157        #[test]
1158        fn prop_extension_extractor_retrieval(
1159            value in any::<i64>(),
1160            has_extension in prop::bool::ANY,
1161        ) {
1162            let result: Result<(), TestCaseError> = (|| {
1163                // Create a simple wrapper type for testing
1164                #[derive(Clone, Debug, PartialEq)]
1165                struct TestExtension(i64);
1166
1167                let uri: http::Uri = "/test".parse().unwrap();
1168                let builder = http::Request::builder().method(Method::GET).uri(uri);
1169                let req = builder.body(()).unwrap();
1170                let (mut parts, _) = req.into_parts();
1171
1172                if has_extension {
1173                    parts.extensions.insert(TestExtension(value));
1174                }
1175
1176                let request = Request::new(
1177                    parts,
1178                    Bytes::new(),
1179                    Arc::new(Extensions::new()),
1180                    HashMap::new(),
1181                );
1182
1183                let result = Extension::<TestExtension>::from_request_parts(&request);
1184
1185                if has_extension {
1186                    let extracted = result
1187                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1188                    prop_assert_eq!(
1189                        extracted.0,
1190                        TestExtension(value),
1191                        "Extension value mismatch"
1192                    );
1193                } else {
1194                    prop_assert!(
1195                        result.is_err(),
1196                        "Expected error when extension is missing"
1197                    );
1198                }
1199
1200                Ok(())
1201            })();
1202            result?;
1203        }
1204    }
1205
1206    // Unit tests for basic functionality
1207
1208    #[test]
1209    fn test_headers_extractor_basic() {
1210        let request = create_test_request_with_headers(
1211            Method::GET,
1212            "/test",
1213            vec![
1214                ("content-type", "application/json"),
1215                ("accept", "text/html"),
1216            ],
1217        );
1218
1219        let headers = Headers::from_request_parts(&request).unwrap();
1220
1221        assert!(headers.contains("content-type"));
1222        assert!(headers.contains("accept"));
1223        assert!(!headers.contains("x-custom"));
1224        assert_eq!(headers.len(), 2);
1225    }
1226
1227    #[test]
1228    fn test_header_value_extractor_present() {
1229        let request = create_test_request_with_headers(
1230            Method::GET,
1231            "/test",
1232            vec![("authorization", "Bearer token123")],
1233        );
1234
1235        let result = HeaderValue::extract(&request, "authorization");
1236        assert!(result.is_ok());
1237        assert_eq!(result.unwrap().value(), "Bearer token123");
1238    }
1239
1240    #[test]
1241    fn test_header_value_extractor_missing() {
1242        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1243
1244        let result = HeaderValue::extract(&request, "authorization");
1245        assert!(result.is_err());
1246    }
1247
1248    #[test]
1249    fn test_client_ip_from_forwarded_header() {
1250        let request = create_test_request_with_headers(
1251            Method::GET,
1252            "/test",
1253            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1254        );
1255
1256        let ip = ClientIp::extract_with_config(&request, true).unwrap();
1257        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1258    }
1259
1260    #[test]
1261    fn test_client_ip_ignores_forwarded_when_not_trusted() {
1262        let uri: http::Uri = "/test".parse().unwrap();
1263        let builder = http::Request::builder()
1264            .method(Method::GET)
1265            .uri(uri)
1266            .header("x-forwarded-for", "192.168.1.100");
1267        let req = builder.body(()).unwrap();
1268        let (mut parts, _) = req.into_parts();
1269
1270        let socket_addr = std::net::SocketAddr::new(
1271            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1272            8080,
1273        );
1274        parts.extensions.insert(socket_addr);
1275
1276        let request = Request::new(
1277            parts,
1278            Bytes::new(),
1279            Arc::new(Extensions::new()),
1280            HashMap::new(),
1281        );
1282
1283        let ip = ClientIp::extract_with_config(&request, false).unwrap();
1284        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1285    }
1286
1287    #[test]
1288    fn test_extension_extractor_present() {
1289        #[derive(Clone, Debug, PartialEq)]
1290        struct MyData(String);
1291
1292        let request =
1293            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1294
1295        let result = Extension::<MyData>::from_request_parts(&request);
1296        assert!(result.is_ok());
1297        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1298    }
1299
1300    #[test]
1301    fn test_extension_extractor_missing() {
1302        #[derive(Clone, Debug)]
1303        #[allow(dead_code)]
1304        struct MyData(String);
1305
1306        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1307
1308        let result = Extension::<MyData>::from_request_parts(&request);
1309        assert!(result.is_err());
1310    }
1311
1312    // Cookies tests (feature-gated)
1313    #[cfg(feature = "cookies")]
1314    mod cookies_tests {
1315        use super::*;
1316
1317        // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
1318        //
1319        // For any request with Cookie header containing cookies C, the `Cookies` extractor
1320        // SHALL return a CookieJar containing exactly the cookies in C.
1321        // Note: Duplicate cookie names result in only the last value being kept.
1322        //
1323        // **Validates: Requirements 5.3**
1324        proptest! {
1325            #![proptest_config(ProptestConfig::with_cases(100))]
1326
1327            #[test]
1328            fn prop_cookies_extractor_parsing(
1329                // Generate random cookie names and values
1330                // Using alphanumeric strings to ensure valid cookie names/values
1331                cookies in prop::collection::vec(
1332                    (
1333                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // Valid cookie name pattern
1334                        "[a-zA-Z0-9]{1,30}"            // Valid cookie value pattern (no special chars)
1335                    ),
1336                    0..5
1337                )
1338            ) {
1339                let result: Result<(), TestCaseError> = (|| {
1340                    // Build cookie header string
1341                    let cookie_header = cookies
1342                        .iter()
1343                        .map(|(name, value)| format!("{}={}", name, value))
1344                        .collect::<Vec<_>>()
1345                        .join("; ");
1346
1347                    let headers = if !cookies.is_empty() {
1348                        vec![("cookie", cookie_header.as_str())]
1349                    } else {
1350                        vec![]
1351                    };
1352
1353                    let request = create_test_request_with_headers(Method::GET, "/test", headers);
1354
1355                    // Extract cookies
1356                    let extracted = Cookies::from_request_parts(&request)
1357                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1358
1359                    // Build expected cookies map - last value wins for duplicate names
1360                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1361                    for (name, value) in &cookies {
1362                        expected_cookies.insert(name.as_str(), value.as_str());
1363                    }
1364
1365                    // Verify all expected cookies are present with correct values
1366                    for (name, expected_value) in &expected_cookies {
1367                        let cookie = extracted.get(name)
1368                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1369
1370                        prop_assert_eq!(
1371                            cookie.value(),
1372                            *expected_value,
1373                            "Cookie '{}' value mismatch",
1374                            name
1375                        );
1376                    }
1377
1378                    // Count cookies in jar should match unique cookie names
1379                    let extracted_count = extracted.iter().count();
1380                    prop_assert_eq!(
1381                        extracted_count,
1382                        expected_cookies.len(),
1383                        "Expected {} unique cookies, got {}",
1384                        expected_cookies.len(),
1385                        extracted_count
1386                    );
1387
1388                    Ok(())
1389                })();
1390                result?;
1391            }
1392        }
1393
1394        #[test]
1395        fn test_cookies_extractor_basic() {
1396            let request = create_test_request_with_headers(
1397                Method::GET,
1398                "/test",
1399                vec![("cookie", "session=abc123; user=john")],
1400            );
1401
1402            let cookies = Cookies::from_request_parts(&request).unwrap();
1403
1404            assert!(cookies.contains("session"));
1405            assert!(cookies.contains("user"));
1406            assert!(!cookies.contains("other"));
1407
1408            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1409            assert_eq!(cookies.get("user").unwrap().value(), "john");
1410        }
1411
1412        #[test]
1413        fn test_cookies_extractor_empty() {
1414            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1415
1416            let cookies = Cookies::from_request_parts(&request).unwrap();
1417            assert_eq!(cookies.iter().count(), 0);
1418        }
1419
1420        #[test]
1421        fn test_cookies_extractor_single() {
1422            let request = create_test_request_with_headers(
1423                Method::GET,
1424                "/test",
1425                vec![("cookie", "token=xyz789")],
1426            );
1427
1428            let cookies = Cookies::from_request_parts(&request).unwrap();
1429            assert_eq!(cookies.iter().count(), 1);
1430            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1431        }
1432    }
1433}