Skip to main content

rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
9//! | Extractor | Description | Consumes Body |
10//! |-----------|-------------|---------------|
11//! | [`Json<T>`] | Parse JSON request body | Yes |
12//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
13//! | [`Query<T>`] | Parse query string parameters | No |
14//! | [`Path<T>`] | Extract path parameters | No |
15//! | [`State<T>`] | Access shared application state | No |
16//! | [`Body`] | Raw request body bytes | Yes |
17//! | [`Headers`] | Access all request headers | No |
18//! | [`HeaderValue`] | Extract a specific header | No |
19//! | [`Extension<T>`] | Access middleware-injected data | No |
20//! | [`ClientIp`] | Extract client IP address | No |
21//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
51//! # Extractor Order
52//!
53//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`)
54//! must come last since they consume the request body. Non-body extractors can be
55//! in any order.
56
57use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use rustapi_openapi::schema::{RustApiSchema, SchemaCtx, SchemaRef};
68use serde::de::DeserializeOwned;
69use serde::Serialize;
70use std::collections::BTreeMap;
71use std::future::Future;
72use std::ops::{Deref, DerefMut};
73use std::str::FromStr;
74
75/// Trait for extracting data from request parts (headers, path, query)
76///
77/// This is used for extractors that don't need the request body.
78///
79/// # Example: Implementing a custom extractor that requires a specific header
80///
81/// ```rust
82/// use rustapi_core::FromRequestParts;
83/// use rustapi_core::{Request, ApiError, Result};
84/// use http::StatusCode;
85///
86/// struct ApiKey(String);
87///
88/// impl FromRequestParts for ApiKey {
89///     fn from_request_parts(req: &Request) -> Result<Self> {
90///         if let Some(key) = req.headers().get("x-api-key") {
91///             if let Ok(key_str) = key.to_str() {
92///                 return Ok(ApiKey(key_str.to_string()));
93///             }
94///         }
95///         Err(ApiError::unauthorized("Missing or invalid API key"))
96///     }
97/// }
98/// ```
99pub trait FromRequestParts: Sized {
100    /// Extract from request parts
101    fn from_request_parts(req: &Request) -> Result<Self>;
102}
103
104/// Trait for extracting data from the full request (including body)
105///
106/// This is used for extractors that consume the request body.
107///
108/// # Example: Implementing a custom extractor that consumes the body
109///
110/// ```rust
111/// use rustapi_core::FromRequest;
112/// use rustapi_core::{Request, ApiError, Result};
113/// use std::future::Future;
114///
115/// struct PlainText(String);
116///
117/// impl FromRequest for PlainText {
118///     async fn from_request(req: &mut Request) -> Result<Self> {
119///         // Ensure body is loaded
120///         req.load_body().await?;
121///         
122///         // Consume the body
123///         if let Some(bytes) = req.take_body() {
124///             if let Ok(text) = String::from_utf8(bytes.to_vec()) {
125///                 return Ok(PlainText(text));
126///             }
127///         }
128///         
129///         Err(ApiError::bad_request("Invalid plain text body"))
130///     }
131/// }
132/// ```
133pub trait FromRequest: Sized {
134    /// Extract from the full request
135    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
136}
137
138// Blanket impl: FromRequestParts -> FromRequest
139impl<T: FromRequestParts> FromRequest for T {
140    async fn from_request(req: &mut Request) -> Result<Self> {
141        T::from_request_parts(req)
142    }
143}
144
/// JSON body extractor and response wrapper.
///
/// As an extractor it deserializes the request body into `T`; as a return
/// type it serializes `T` (when `T: Serialize`). The wrapped value is the
/// public tuple field.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
///     // body is already deserialized
/// }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct Json<T>(pub T);
165
166impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
167    async fn from_request(req: &mut Request) -> Result<Self> {
168        req.load_body().await?;
169        let body = req
170            .take_body()
171            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
172
173        // Use simd-json accelerated parsing when available (2-4x faster)
174        let value: T = json::from_slice(&body)?;
175        Ok(Json(value))
176    }
177}
178
179impl<T> Deref for Json<T> {
180    type Target = T;
181
182    fn deref(&self) -> &Self::Target {
183        &self.0
184    }
185}
186
187impl<T> DerefMut for Json<T> {
188    fn deref_mut(&mut self) -> &mut Self::Target {
189        &mut self.0
190    }
191}
192
193impl<T> From<T> for Json<T> {
194    fn from(value: T) -> Self {
195        Json(value)
196    }
197}
198
199/// Default pre-allocation size for JSON response buffers (256 bytes)
200/// This covers most small to medium JSON responses without reallocation.
201const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
202
203// IntoResponse for Json - allows using Json<T> as a return type
204impl<T: Serialize> IntoResponse for Json<T> {
205    fn into_response(self) -> crate::response::Response {
206        // Use pre-allocated buffer to reduce allocations
207        match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
208            Ok(body) => http::Response::builder()
209                .status(StatusCode::OK)
210                .header(header::CONTENT_TYPE, "application/json")
211                .body(crate::response::Body::from(body))
212                .unwrap(),
213            Err(err) => {
214                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
215            }
216        }
217    }
218}
219
/// Validated JSON body extractor.
///
/// Deserializes the request body into `T` and then validates it. When
/// validation fails the request is rejected with a 422 Unprocessable Entity
/// error carrying field-level details.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///     #[validate(length(min = 8))]
///     password: String,
/// }
///
/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
///     // body is already validated!
///     // If email is invalid or password too short, a 422 error is returned automatically
/// }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wrap `value` without running any validation.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Consume the wrapper and return the wrapped value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}
259
260impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
261    async fn from_request(req: &mut Request) -> Result<Self> {
262        req.load_body().await?;
263        // First, deserialize the JSON body using simd-json when available
264        let body = req
265            .take_body()
266            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
267
268        let value: T = json::from_slice(&body)?;
269
270        // Then, validate it using the unified Validatable trait
271        value.do_validate()?;
272
273        Ok(ValidatedJson(value))
274    }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278    type Target = T;
279
280    fn deref(&self) -> &Self::Target {
281        &self.0
282    }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286    fn deref_mut(&mut self) -> &mut Self::Target {
287        &mut self.0
288    }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292    fn from(value: T) -> Self {
293        ValidatedJson(value)
294    }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298    fn into_response(self) -> crate::response::Response {
299        Json(self.0).into_response()
300    }
301}
302
/// Async validated JSON body extractor.
///
/// Deserializes the request body into `T` and validates it with the
/// `AsyncValidate` trait from `rustapi-validate`, which allows rules that
/// need to await (for example database uniqueness checks).
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use rustapi_validate::v2::prelude::*;
///
/// #[derive(Deserialize, Validate, AsyncValidate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///
///     #[validate(async_unique(table = "users", column = "email"))]
///     username: String,
/// }
///
/// async fn register(AsyncValidatedJson(body): AsyncValidatedJson<CreateUser>) -> impl IntoResponse {
///     // body is validated asynchronously (e.g. checked existing email in DB)
/// }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct AsyncValidatedJson<T>(pub T);

impl<T> AsyncValidatedJson<T> {
    /// Wrap `value` without running any validation.
    pub fn new(value: T) -> Self {
        AsyncValidatedJson(value)
    }

    /// Consume the wrapper and return the wrapped value.
    pub fn into_inner(self) -> T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

/// Read access to the wrapped value.
impl<T> Deref for AsyncValidatedJson<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

/// Write access to the wrapped value.
impl<T> DerefMut for AsyncValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut T {
        let AsyncValidatedJson(inner) = self;
        inner
    }
}

/// Wrap a value directly; no validation is performed by this conversion.
impl<T> From<T> for AsyncValidatedJson<T> {
    fn from(inner: T) -> Self {
        Self(inner)
    }
}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365    fn into_response(self) -> crate::response::Response {
366        Json(self.0).into_response()
367    }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371    async fn from_request(req: &mut Request) -> Result<Self> {
372        req.load_body().await?;
373
374        let body = req
375            .take_body()
376            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378        let value: T = json::from_slice(&body)?;
379
380        // Create validation context from request
381        // TODO: Extract validators from App State
382        let ctx = ValidationContext::default();
383
384        // Perform full validation (sync + async)
385        if let Err(errors) = value.validate_full(&ctx).await {
386            // Convert v2 ValidationErrors to ApiError
387            let field_errors: Vec<crate::error::FieldError> = errors
388                .fields
389                .iter()
390                .flat_map(|(field, errs)| {
391                    let field_name = field.to_string();
392                    errs.iter().map(move |e| crate::error::FieldError {
393                        field: field_name.clone(),
394                        code: e.code.to_string(),
395                        message: e.message.clone(),
396                    })
397                })
398                .collect();
399
400            return Err(ApiError::validation(field_errors));
401        }
402
403        Ok(AsyncValidatedJson(value))
404    }
405}
406
407/// Query string extractor
408///
409/// Parses the query string into type `T`.
410///
411/// # Example
412///
413/// ```rust,ignore
414/// #[derive(Deserialize)]
415/// struct Pagination {
416///     page: Option<u32>,
417///     limit: Option<u32>,
418/// }
419///
420/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
421///     // params.page, params.limit
422/// }
423/// ```
424#[derive(Debug, Clone)]
425pub struct Query<T>(pub T);
426
427impl<T: DeserializeOwned> FromRequestParts for Query<T> {
428    fn from_request_parts(req: &Request) -> Result<Self> {
429        let query = req.query_string().unwrap_or("");
430        let value: T = serde_urlencoded::from_str(query)
431            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
432        Ok(Query(value))
433    }
434}
435
436impl<T> Deref for Query<T> {
437    type Target = T;
438
439    fn deref(&self) -> &Self::Target {
440        &self.0
441    }
442}
443
444/// Path parameter extractor
445///
446/// Extracts path parameters defined in the route pattern.
447///
448/// # Example
449///
450/// For route `/users/{id}`:
451///
452/// ```rust,ignore
453/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
454///     // id is extracted from path
455/// }
456/// ```
457///
458/// For multiple params `/users/{user_id}/posts/{post_id}`:
459///
460/// ```rust,ignore
461/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
462///     // Both params extracted
463/// }
464/// ```
465#[derive(Debug, Clone)]
466pub struct Path<T>(pub T);
467
468impl<T: FromStr> FromRequestParts for Path<T>
469where
470    T::Err: std::fmt::Display,
471{
472    fn from_request_parts(req: &Request) -> Result<Self> {
473        let params = req.path_params();
474
475        // For single param, get the first one
476        if let Some((_, value)) = params.iter().next() {
477            let parsed = value
478                .parse::<T>()
479                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
480            return Ok(Path(parsed));
481        }
482
483        Err(ApiError::internal("Missing path parameter"))
484    }
485}
486
487impl<T> Deref for Path<T> {
488    type Target = T;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495/// Typed path extractor
496///
497/// Extracts path parameters and deserializes them into a struct implementing `Deserialize`.
498/// This is similar to `Path<T>`, but supports complex structs that can be deserialized
499/// from a map of parameter names to values (e.g. via `serde_json`).
500///
501/// # Example
502///
503/// ```rust,ignore
504/// #[derive(Deserialize)]
505/// struct UserParams {
506///     id: u64,
507///     category: String,
508/// }
509///
510/// async fn get_user(Typed(params): Typed<UserParams>) -> impl IntoResponse {
511///     // params.id, params.category
512/// }
513/// ```
514#[derive(Debug, Clone)]
515pub struct Typed<T>(pub T);
516
517impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
518    fn from_request_parts(req: &Request) -> Result<Self> {
519        let params = req.path_params();
520        let mut map = serde_json::Map::new();
521        for (k, v) in params.iter() {
522            map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
523        }
524        let value = serde_json::Value::Object(map);
525        let parsed: T = serde_json::from_value(value)
526            .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
527        Ok(Typed(parsed))
528    }
529}
530
531impl<T> Deref for Typed<T> {
532    type Target = T;
533
534    fn deref(&self) -> &Self::Target {
535        &self.0
536    }
537}
538
539/// State extractor
540///
541/// Extracts shared application state.
542///
543/// # Example
544///
545/// ```rust,ignore
546/// #[derive(Clone)]
547/// struct AppState {
548///     db: DbPool,
549/// }
550///
551/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
552///     // Use state.db
553/// }
554/// ```
555#[derive(Debug, Clone)]
556pub struct State<T>(pub T);
557
558impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
559    fn from_request_parts(req: &Request) -> Result<Self> {
560        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
561            ApiError::internal(format!(
562                "State of type `{}` not found. Did you forget to call .state()?",
563                std::any::type_name::<T>()
564            ))
565        })
566    }
567}
568
569impl<T> Deref for State<T> {
570    type Target = T;
571
572    fn deref(&self) -> &Self::Target {
573        &self.0
574    }
575}
576
577/// Raw body bytes extractor
578#[derive(Debug, Clone)]
579pub struct Body(pub Bytes);
580
581impl FromRequest for Body {
582    async fn from_request(req: &mut Request) -> Result<Self> {
583        req.load_body().await?;
584        let body = req
585            .take_body()
586            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
587        Ok(Body(body))
588    }
589}
590
591impl Deref for Body {
592    type Target = Bytes;
593
594    fn deref(&self) -> &Self::Target {
595        &self.0
596    }
597}
598
599/// Streaming body extractor
600pub struct BodyStream(pub StreamingBody);
601
602impl FromRequest for BodyStream {
603    async fn from_request(req: &mut Request) -> Result<Self> {
604        let config = StreamingConfig::default();
605
606        if let Some(stream) = req.take_stream() {
607            Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
608        } else if let Some(bytes) = req.take_body() {
609            // Handle buffered body as stream
610            let stream = futures_util::stream::once(async move { Ok(bytes) });
611            Ok(BodyStream(StreamingBody::from_stream(
612                stream,
613                config.max_body_size,
614            )))
615        } else {
616            Err(ApiError::internal("Body already consumed"))
617        }
618    }
619}
620
621impl Deref for BodyStream {
622    type Target = StreamingBody;
623
624    fn deref(&self) -> &Self::Target {
625        &self.0
626    }
627}
628
629impl DerefMut for BodyStream {
630    fn deref_mut(&mut self) -> &mut Self::Target {
631        &mut self.0
632    }
633}
634
635// Forward stream implementation
636impl futures_util::Stream for BodyStream {
637    type Item = Result<Bytes, ApiError>;
638
639    fn poll_next(
640        mut self: std::pin::Pin<&mut Self>,
641        cx: &mut std::task::Context<'_>,
642    ) -> std::task::Poll<Option<Self::Item>> {
643        std::pin::Pin::new(&mut self.0).poll_next(cx)
644    }
645}
646
647/// Optional extractor wrapper
648///
649/// Makes any extractor optional - returns None instead of error on failure.
650impl<T: FromRequestParts> FromRequestParts for Option<T> {
651    fn from_request_parts(req: &Request) -> Result<Self> {
652        Ok(T::from_request_parts(req).ok())
653    }
654}
655
656/// Headers extractor
657///
658/// Provides access to all request headers as a typed map.
659///
660/// # Example
661///
662/// ```rust,ignore
663/// use rustapi_core::extract::Headers;
664///
665/// async fn handler(headers: Headers) -> impl IntoResponse {
666///     if let Some(content_type) = headers.get("content-type") {
667///         format!("Content-Type: {:?}", content_type)
668///     } else {
669///         "No Content-Type header".to_string()
670///     }
671/// }
672/// ```
673#[derive(Debug, Clone)]
674pub struct Headers(pub http::HeaderMap);
675
676impl Headers {
677    /// Get a header value by name
678    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
679        self.0.get(name)
680    }
681
682    /// Check if a header exists
683    pub fn contains(&self, name: &str) -> bool {
684        self.0.contains_key(name)
685    }
686
687    /// Get the number of headers
688    pub fn len(&self) -> usize {
689        self.0.len()
690    }
691
692    /// Check if headers are empty
693    pub fn is_empty(&self) -> bool {
694        self.0.is_empty()
695    }
696
697    /// Iterate over all headers
698    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
699        self.0.iter()
700    }
701}
702
703impl FromRequestParts for Headers {
704    fn from_request_parts(req: &Request) -> Result<Self> {
705        Ok(Headers(req.headers().clone()))
706    }
707}
708
709impl Deref for Headers {
710    type Target = http::HeaderMap;
711
712    fn deref(&self) -> &Self::Target {
713        &self.0
714    }
715}
716
717/// Single header value extractor
718///
719/// Extracts a specific header value by name. Returns an error if the header is missing.
720///
721/// # Example
722///
723/// ```rust,ignore
724/// use rustapi_core::extract::HeaderValue;
725///
726/// async fn handler(
727///     auth: HeaderValue<{ "authorization" }>,
728/// ) -> impl IntoResponse {
729///     format!("Auth header: {}", auth.0)
730/// }
731/// ```
732///
733/// Note: Due to Rust's const generics limitations, you may need to use the
734/// `HeaderValueOf` type alias or extract headers manually using the `Headers` extractor.
735#[derive(Debug, Clone)]
736pub struct HeaderValue(pub String, pub &'static str);
737
738impl HeaderValue {
739    /// Create a new HeaderValue extractor for a specific header name
740    pub fn new(name: &'static str, value: String) -> Self {
741        Self(value, name)
742    }
743
744    /// Get the header value
745    pub fn value(&self) -> &str {
746        &self.0
747    }
748
749    /// Get the header name
750    pub fn name(&self) -> &'static str {
751        self.1
752    }
753
754    /// Extract a specific header from a request
755    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
756        req.headers()
757            .get(name)
758            .and_then(|v| v.to_str().ok())
759            .map(|s| HeaderValue(s.to_string(), name))
760            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
761    }
762}
763
764impl Deref for HeaderValue {
765    type Target = String;
766
767    fn deref(&self) -> &Self::Target {
768        &self.0
769    }
770}
771
772/// Extension extractor
773///
774/// Retrieves typed data from request extensions that was inserted by middleware.
775///
776/// # Example
777///
778/// ```rust,ignore
779/// use rustapi_core::extract::Extension;
780///
781/// // Middleware inserts user data
782/// #[derive(Clone)]
783/// struct CurrentUser { id: i64 }
784///
785/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
786///     format!("User ID: {}", user.id)
787/// }
788/// ```
789#[derive(Debug, Clone)]
790pub struct Extension<T>(pub T);
791
792impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
793    fn from_request_parts(req: &Request) -> Result<Self> {
794        req.extensions()
795            .get::<T>()
796            .cloned()
797            .map(Extension)
798            .ok_or_else(|| {
799                ApiError::internal(format!(
800                    "Extension of type `{}` not found. Did middleware insert it?",
801                    std::any::type_name::<T>()
802                ))
803            })
804    }
805}
806
807impl<T> Deref for Extension<T> {
808    type Target = T;
809
810    fn deref(&self) -> &Self::Target {
811        &self.0
812    }
813}
814
815impl<T> DerefMut for Extension<T> {
816    fn deref_mut(&mut self) -> &mut Self::Target {
817        &mut self.0
818    }
819}
820
821/// Client IP address extractor
822///
823/// Extracts the client IP address from the request. When `trust_proxy` is enabled,
824/// it will use the `X-Forwarded-For` header if present.
825///
826/// # Example
827///
828/// ```rust,ignore
829/// use rustapi_core::extract::ClientIp;
830///
831/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
832///     format!("Your IP: {}", ip)
833/// }
834/// ```
835#[derive(Debug, Clone)]
836pub struct ClientIp(pub std::net::IpAddr);
837
838impl ClientIp {
839    /// Extract client IP, optionally trusting X-Forwarded-For header
840    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
841        if trust_proxy {
842            // Try X-Forwarded-For header first
843            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
844                if let Ok(forwarded_str) = forwarded.to_str() {
845                    // X-Forwarded-For can contain multiple IPs, take the first one
846                    if let Some(first_ip) = forwarded_str.split(',').next() {
847                        if let Ok(ip) = first_ip.trim().parse() {
848                            return Ok(ClientIp(ip));
849                        }
850                    }
851                }
852            }
853        }
854
855        // Fall back to socket address from extensions (if set by server)
856        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
857            return Ok(ClientIp(addr.ip()));
858        }
859
860        // Default to localhost if no IP information available
861        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
862            127, 0, 0, 1,
863        ))))
864    }
865}
866
867impl FromRequestParts for ClientIp {
868    fn from_request_parts(req: &Request) -> Result<Self> {
869        // By default, trust proxy headers
870        Self::extract_with_config(req, true)
871    }
872}
873
/// Cookies extractor.
///
/// Parses the request's `Cookie` header field(s) into a cookie jar.
/// Unreadable header values and unparsable pairs are skipped, so extraction
/// itself never fails.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     if let Some(session) = cookies.get("session_id") {
///         format!("Session: {}", session.value())
///     } else {
///         "No session cookie".to_string()
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Get a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterate over all cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Check whether a cookie with this name exists.
    pub fn contains(&self, name: &str) -> bool {
        self.0.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    /// Builds the jar from every `Cookie` header field on the request.
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // HTTP/2 clients may split cookies across several `Cookie` fields,
        // so walk all of them rather than only the first.
        for raw in req.headers().get_all(header::COOKIE) {
            let Ok(text) = raw.to_str() else {
                continue;
            };

            // Each field is a `;`-separated list of `name=value` pairs.
            for pair in text.split(';').map(str::trim).filter(|p| !p.is_empty()) {
                if let Ok(parsed) = cookie::Cookie::parse(pair.to_string()) {
                    jar.add_original(parsed.into_owned());
                }
            }
        }

        Ok(Cookies(jar))
    }
}

/// Read access to the underlying cookie jar.
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
944
945// Implement FromRequestParts for common primitive types (path params)
946macro_rules! impl_from_request_parts_for_primitives {
947    ($($ty:ty),*) => {
948        $(
949            impl FromRequestParts for $ty {
950                fn from_request_parts(req: &Request) -> Result<Self> {
951                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
952                    Ok(value)
953                }
954            }
955        )*
956    };
957}
958
959impl_from_request_parts_for_primitives!(
960    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
961);
962
963// OperationModifier implementations for extractors
964
965use rustapi_openapi::{
966    MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier, ResponseSpec,
967};
968
969// ValidatedJson - Adds request body
970impl<T: RustApiSchema> OperationModifier for ValidatedJson<T> {
971    fn update_operation(op: &mut Operation) {
972        let mut ctx = SchemaCtx::new();
973        let schema_ref = T::schema(&mut ctx);
974
975        let mut content = BTreeMap::new();
976        content.insert(
977            "application/json".to_string(),
978            MediaType {
979                schema: Some(schema_ref),
980                example: None,
981            },
982        );
983
984        op.request_body = Some(RequestBody {
985            description: None,
986            required: Some(true),
987            content,
988        });
989
990        // Add 422 Validation Error response
991        let mut responses_content = BTreeMap::new();
992        responses_content.insert(
993            "application/json".to_string(),
994            MediaType {
995                schema: Some(SchemaRef::Ref {
996                    reference: "#/components/schemas/ValidationErrorSchema".to_string(),
997                }),
998                example: None,
999            },
1000        );
1001
1002        op.responses.insert(
1003            "422".to_string(),
1004            ResponseSpec {
1005                description: "Validation Error".to_string(),
1006                content: responses_content,
1007                headers: BTreeMap::new(),
1008            },
1009        );
1010    }
1011}
1012
1013// Json - Adds request body (Same as ValidatedJson)
1014impl<T: RustApiSchema> OperationModifier for Json<T> {
1015    fn update_operation(op: &mut Operation) {
1016        let mut ctx = SchemaCtx::new();
1017        let schema_ref = T::schema(&mut ctx);
1018
1019        let mut content = BTreeMap::new();
1020        content.insert(
1021            "application/json".to_string(),
1022            MediaType {
1023                schema: Some(schema_ref),
1024                example: None,
1025            },
1026        );
1027
1028        op.request_body = Some(RequestBody {
1029            description: None,
1030            required: Some(true),
1031            content,
1032        });
1033    }
1034}
1035
1036// Path - No op (handled by app routing)
1037impl<T> OperationModifier for Path<T> {
1038    fn update_operation(_op: &mut Operation) {}
1039}
1040
1041// Typed - No op
1042impl<T> OperationModifier for Typed<T> {
1043    fn update_operation(_op: &mut Operation) {}
1044}
1045
1046// Query - Extracts query params using field_schemas
1047impl<T: RustApiSchema> OperationModifier for Query<T> {
1048    fn update_operation(op: &mut Operation) {
1049        let mut ctx = SchemaCtx::new();
1050        if let Some(fields) = T::field_schemas(&mut ctx) {
1051            let new_params: Vec<Parameter> = fields
1052                .into_iter()
1053                .map(|(name, schema)| {
1054                    Parameter {
1055                        name,
1056                        location: "query".to_string(),
1057                        required: false, // Assume optional
1058                        deprecated: None,
1059                        description: None,
1060                        schema: Some(schema),
1061                    }
1062                })
1063                .collect();
1064
1065            op.parameters.extend(new_params);
1066        }
1067    }
1068}
1069
1070// State - No op
1071impl<T> OperationModifier for State<T> {
1072    fn update_operation(_op: &mut Operation) {}
1073}
1074
1075// Body - Generic binary body
1076impl OperationModifier for Body {
1077    fn update_operation(op: &mut Operation) {
1078        let mut content = BTreeMap::new();
1079        content.insert(
1080            "application/octet-stream".to_string(),
1081            MediaType {
1082                schema: Some(SchemaRef::Inline(
1083                    serde_json::json!({ "type": "string", "format": "binary" }),
1084                )),
1085                example: None,
1086            },
1087        );
1088
1089        op.request_body = Some(RequestBody {
1090            description: None,
1091            required: Some(true),
1092            content,
1093        });
1094    }
1095}
1096
1097// BodyStream - Generic binary stream
1098impl OperationModifier for BodyStream {
1099    fn update_operation(op: &mut Operation) {
1100        let mut content = BTreeMap::new();
1101        content.insert(
1102            "application/octet-stream".to_string(),
1103            MediaType {
1104                schema: Some(SchemaRef::Inline(
1105                    serde_json::json!({ "type": "string", "format": "binary" }),
1106                )),
1107                example: None,
1108            },
1109        );
1110
1111        op.request_body = Some(RequestBody {
1112            description: None,
1113            required: Some(true),
1114            content,
1115        });
1116    }
1117}
1118
1119// ResponseModifier implementations for extractors
1120
1121// Json<T> - 200 OK with schema T
1122impl<T: RustApiSchema> ResponseModifier for Json<T> {
1123    fn update_response(op: &mut Operation) {
1124        let mut ctx = SchemaCtx::new();
1125        let schema_ref = T::schema(&mut ctx);
1126
1127        let mut content = BTreeMap::new();
1128        content.insert(
1129            "application/json".to_string(),
1130            MediaType {
1131                schema: Some(schema_ref),
1132                example: None,
1133            },
1134        );
1135
1136        op.responses.insert(
1137            "200".to_string(),
1138            ResponseSpec {
1139                description: "Successful response".to_string(),
1140                content,
1141                headers: BTreeMap::new(),
1142            },
1143        );
1144    }
1145}
1146
1147// RustApiSchema implementations
1148
1149impl<T: RustApiSchema> RustApiSchema for Json<T> {
1150    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1151        T::schema(ctx)
1152    }
1153}
1154
1155impl<T: RustApiSchema> RustApiSchema for ValidatedJson<T> {
1156    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1157        T::schema(ctx)
1158    }
1159}
1160
1161impl<T: RustApiSchema> RustApiSchema for AsyncValidatedJson<T> {
1162    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1163        T::schema(ctx)
1164    }
1165}
1166
1167impl<T: RustApiSchema> RustApiSchema for Query<T> {
1168    fn schema(ctx: &mut SchemaCtx) -> SchemaRef {
1169        T::schema(ctx)
1170    }
1171    fn field_schemas(ctx: &mut SchemaCtx) -> Option<BTreeMap<String, SchemaRef>> {
1172        T::field_schemas(ctx)
1173    }
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178    use super::*;
1179    use crate::path_params::PathParams;
1180    use bytes::Bytes;
1181    use http::{Extensions, Method};
1182    use proptest::prelude::*;
1183    use proptest::test_runner::TestCaseError;
1184    use std::sync::Arc;
1185
1186    /// Create a test request with the given method, path, and headers
1187    fn create_test_request_with_headers(
1188        method: Method,
1189        path: &str,
1190        headers: Vec<(&str, &str)>,
1191    ) -> Request {
1192        let uri: http::Uri = path.parse().unwrap();
1193        let mut builder = http::Request::builder().method(method).uri(uri);
1194
1195        for (name, value) in headers {
1196            builder = builder.header(name, value);
1197        }
1198
1199        let req = builder.body(()).unwrap();
1200        let (parts, _) = req.into_parts();
1201
1202        Request::new(
1203            parts,
1204            crate::request::BodyVariant::Buffered(Bytes::new()),
1205            Arc::new(Extensions::new()),
1206            PathParams::new(),
1207        )
1208    }
1209
1210    /// Create a test request with extensions
1211    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
1212        method: Method,
1213        path: &str,
1214        extension: T,
1215    ) -> Request {
1216        let uri: http::Uri = path.parse().unwrap();
1217        let builder = http::Request::builder().method(method).uri(uri);
1218
1219        let req = builder.body(()).unwrap();
1220        let (mut parts, _) = req.into_parts();
1221        parts.extensions.insert(extension);
1222
1223        Request::new(
1224            parts,
1225            crate::request::BodyVariant::Buffered(Bytes::new()),
1226            Arc::new(Extensions::new()),
1227            PathParams::new(),
1228        )
1229    }
1230
1231    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
1232    //
1233    // For any request with headers H, the `Headers` extractor SHALL return a map
1234    // containing all key-value pairs in H.
1235    //
1236    // **Validates: Requirements 5.1**
1237    proptest! {
1238        #![proptest_config(ProptestConfig::with_cases(100))]
1239
1240        #[test]
1241        fn prop_headers_extractor_completeness(
1242            // Generate random header names and values
1243            // Using alphanumeric strings to ensure valid header names/values
1244            headers in prop::collection::vec(
1245                (
1246                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
1247                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
1248                ),
1249                0..10
1250            )
1251        ) {
1252            let result: Result<(), TestCaseError> = (|| {
1253                // Convert to header tuples
1254                let header_tuples: Vec<(&str, &str)> = headers
1255                    .iter()
1256                    .map(|(k, v)| (k.as_str(), v.as_str()))
1257                    .collect();
1258
1259                // Create request with headers
1260                let request = create_test_request_with_headers(
1261                    Method::GET,
1262                    "/test",
1263                    header_tuples.clone(),
1264                );
1265
1266                // Extract headers
1267                let extracted = Headers::from_request_parts(&request)
1268                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;
1269
1270                // Verify all original headers are present
1271                // HTTP allows duplicate headers - get_all() returns all values for a header name
1272                for (name, value) in &headers {
1273                    // Check that the header name exists
1274                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
1275                    prop_assert!(
1276                        !all_values.is_empty(),
1277                        "Header '{}' not found",
1278                        name
1279                    );
1280
1281                    // Check that the value is among the extracted values
1282                    let value_found = all_values.iter().any(|v| {
1283                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
1284                    });
1285
1286                    prop_assert!(
1287                        value_found,
1288                        "Header '{}' value '{}' not found in extracted values",
1289                        name,
1290                        value
1291                    );
1292                }
1293
1294                Ok(())
1295            })();
1296            result?;
1297        }
1298    }
1299
1300    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
1301    //
1302    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
1303    // for requests without header "X", it SHALL return an error.
1304    //
1305    // **Validates: Requirements 5.2**
1306    proptest! {
1307        #![proptest_config(ProptestConfig::with_cases(100))]
1308
1309        #[test]
1310        fn prop_header_value_extractor_correctness(
1311            header_name in "[a-z][a-z0-9-]{0,20}",
1312            header_value in "[a-zA-Z0-9 ]{1,50}",
1313            has_header in prop::bool::ANY,
1314        ) {
1315            let result: Result<(), TestCaseError> = (|| {
1316                let headers = if has_header {
1317                    vec![(header_name.as_str(), header_value.as_str())]
1318                } else {
1319                    vec![]
1320                };
1321
1322                let _request = create_test_request_with_headers(Method::GET, "/test", headers);
1323
1324                // We need to use a static string for the header name in the extractor
1325                // So we'll test with a known header name
1326                let test_header = "x-test-header";
1327                let request_with_known_header = if has_header {
1328                    create_test_request_with_headers(
1329                        Method::GET,
1330                        "/test",
1331                        vec![(test_header, header_value.as_str())],
1332                    )
1333                } else {
1334                    create_test_request_with_headers(Method::GET, "/test", vec![])
1335                };
1336
1337                let result = HeaderValue::extract(&request_with_known_header, test_header);
1338
1339                if has_header {
1340                    let extracted = result
1341                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
1342                    prop_assert_eq!(
1343                        extracted.value(),
1344                        header_value.as_str(),
1345                        "Header value mismatch"
1346                    );
1347                } else {
1348                    prop_assert!(
1349                        result.is_err(),
1350                        "Expected error when header is missing"
1351                    );
1352                }
1353
1354                Ok(())
1355            })();
1356            result?;
1357        }
1358    }
1359
1360    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
1361    //
1362    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
1363    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
1364    //
1365    // **Validates: Requirements 5.4**
1366    proptest! {
1367        #![proptest_config(ProptestConfig::with_cases(100))]
1368
1369        #[test]
1370        fn prop_client_ip_extractor_with_forwarding(
1371            // Generate valid IPv4 addresses
1372            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1373                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
1374            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
1375                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
1376            has_forwarded_header in prop::bool::ANY,
1377            trust_proxy in prop::bool::ANY,
1378        ) {
1379            let result: Result<(), TestCaseError> = (|| {
1380                let headers = if has_forwarded_header {
1381                    vec![("x-forwarded-for", forwarded_ip.as_str())]
1382                } else {
1383                    vec![]
1384                };
1385
1386                // Create request with headers
1387                let uri: http::Uri = "/test".parse().unwrap();
1388                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
1389                for (name, value) in &headers {
1390                    builder = builder.header(*name, *value);
1391                }
1392                let req = builder.body(()).unwrap();
1393                let (mut parts, _) = req.into_parts();
1394
1395                // Add socket address to extensions
1396                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
1397                parts.extensions.insert(socket_addr);
1398
1399                let request = Request::new(
1400                    parts,
1401                    crate::request::BodyVariant::Buffered(Bytes::new()),
1402                    Arc::new(Extensions::new()),
1403                    PathParams::new(),
1404                );
1405
1406                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
1407                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;
1408
1409                if trust_proxy && has_forwarded_header {
1410                    // Should use X-Forwarded-For
1411                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
1412                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
1413                    prop_assert_eq!(
1414                        extracted.0,
1415                        expected_ip,
1416                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
1417                    );
1418                } else {
1419                    // Should use socket IP
1420                    prop_assert_eq!(
1421                        extracted.0,
1422                        socket_ip,
1423                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
1424                    );
1425                }
1426
1427                Ok(())
1428            })();
1429            result?;
1430        }
1431    }
1432
1433    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
1434    //
1435    // For any type T and value V inserted into request extensions by middleware,
1436    // `Extension<T>` SHALL return V.
1437    //
1438    // **Validates: Requirements 5.5**
1439    proptest! {
1440        #![proptest_config(ProptestConfig::with_cases(100))]
1441
1442        #[test]
1443        fn prop_extension_extractor_retrieval(
1444            value in any::<i64>(),
1445            has_extension in prop::bool::ANY,
1446        ) {
1447            let result: Result<(), TestCaseError> = (|| {
1448                // Create a simple wrapper type for testing
1449                #[derive(Clone, Debug, PartialEq)]
1450                struct TestExtension(i64);
1451
1452                let uri: http::Uri = "/test".parse().unwrap();
1453                let builder = http::Request::builder().method(Method::GET).uri(uri);
1454                let req = builder.body(()).unwrap();
1455                let (mut parts, _) = req.into_parts();
1456
1457                if has_extension {
1458                    parts.extensions.insert(TestExtension(value));
1459                }
1460
1461                let request = Request::new(
1462                    parts,
1463                    crate::request::BodyVariant::Buffered(Bytes::new()),
1464                    Arc::new(Extensions::new()),
1465                    PathParams::new(),
1466                );
1467
1468                let result = Extension::<TestExtension>::from_request_parts(&request);
1469
1470                if has_extension {
1471                    let extracted = result
1472                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
1473                    prop_assert_eq!(
1474                        extracted.0,
1475                        TestExtension(value),
1476                        "Extension value mismatch"
1477                    );
1478                } else {
1479                    prop_assert!(
1480                        result.is_err(),
1481                        "Expected error when extension is missing"
1482                    );
1483                }
1484
1485                Ok(())
1486            })();
1487            result?;
1488        }
1489    }
1490
1491    // Unit tests for basic functionality
1492
1493    #[test]
1494    fn test_headers_extractor_basic() {
1495        let request = create_test_request_with_headers(
1496            Method::GET,
1497            "/test",
1498            vec![
1499                ("content-type", "application/json"),
1500                ("accept", "text/html"),
1501            ],
1502        );
1503
1504        let headers = Headers::from_request_parts(&request).unwrap();
1505
1506        assert!(headers.contains("content-type"));
1507        assert!(headers.contains("accept"));
1508        assert!(!headers.contains("x-custom"));
1509        assert_eq!(headers.len(), 2);
1510    }
1511
1512    #[test]
1513    fn test_header_value_extractor_present() {
1514        let request = create_test_request_with_headers(
1515            Method::GET,
1516            "/test",
1517            vec![("authorization", "Bearer token123")],
1518        );
1519
1520        let result = HeaderValue::extract(&request, "authorization");
1521        assert!(result.is_ok());
1522        assert_eq!(result.unwrap().value(), "Bearer token123");
1523    }
1524
1525    #[test]
1526    fn test_header_value_extractor_missing() {
1527        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1528
1529        let result = HeaderValue::extract(&request, "authorization");
1530        assert!(result.is_err());
1531    }
1532
1533    #[test]
1534    fn test_client_ip_from_forwarded_header() {
1535        let request = create_test_request_with_headers(
1536            Method::GET,
1537            "/test",
1538            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
1539        );
1540
1541        let ip = ClientIp::extract_with_config(&request, true).unwrap();
1542        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
1543    }
1544
1545    #[test]
1546    fn test_client_ip_ignores_forwarded_when_not_trusted() {
1547        let uri: http::Uri = "/test".parse().unwrap();
1548        let builder = http::Request::builder()
1549            .method(Method::GET)
1550            .uri(uri)
1551            .header("x-forwarded-for", "192.168.1.100");
1552        let req = builder.body(()).unwrap();
1553        let (mut parts, _) = req.into_parts();
1554
1555        let socket_addr = std::net::SocketAddr::new(
1556            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
1557            8080,
1558        );
1559        parts.extensions.insert(socket_addr);
1560
1561        let request = Request::new(
1562            parts,
1563            crate::request::BodyVariant::Buffered(Bytes::new()),
1564            Arc::new(Extensions::new()),
1565            PathParams::new(),
1566        );
1567
1568        let ip = ClientIp::extract_with_config(&request, false).unwrap();
1569        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
1570    }
1571
1572    #[test]
1573    fn test_extension_extractor_present() {
1574        #[derive(Clone, Debug, PartialEq)]
1575        struct MyData(String);
1576
1577        let request =
1578            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));
1579
1580        let result = Extension::<MyData>::from_request_parts(&request);
1581        assert!(result.is_ok());
1582        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
1583    }
1584
1585    #[test]
1586    fn test_extension_extractor_missing() {
1587        #[derive(Clone, Debug)]
1588        #[allow(dead_code)]
1589        struct MyData(String);
1590
1591        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1592
1593        let result = Extension::<MyData>::from_request_parts(&request);
1594        assert!(result.is_err());
1595    }
1596
1597    // Cookies tests (feature-gated)
1598    #[cfg(feature = "cookies")]
1599    mod cookies_tests {
1600        use super::*;
1601
1602        // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
1603        //
1604        // For any request with Cookie header containing cookies C, the `Cookies` extractor
1605        // SHALL return a CookieJar containing exactly the cookies in C.
1606        // Note: Duplicate cookie names result in only the last value being kept.
1607        //
1608        // **Validates: Requirements 5.3**
1609        proptest! {
1610            #![proptest_config(ProptestConfig::with_cases(100))]
1611
1612            #[test]
1613            fn prop_cookies_extractor_parsing(
1614                // Generate random cookie names and values
1615                // Using alphanumeric strings to ensure valid cookie names/values
1616                cookies in prop::collection::vec(
1617                    (
1618                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // Valid cookie name pattern
1619                        "[a-zA-Z0-9]{1,30}"            // Valid cookie value pattern (no special chars)
1620                    ),
1621                    0..5
1622                )
1623            ) {
1624                let result: Result<(), TestCaseError> = (|| {
1625                    // Build cookie header string
1626                    let cookie_header = cookies
1627                        .iter()
1628                        .map(|(name, value)| format!("{}={}", name, value))
1629                        .collect::<Vec<_>>()
1630                        .join("; ");
1631
1632                    let headers = if !cookies.is_empty() {
1633                        vec![("cookie", cookie_header.as_str())]
1634                    } else {
1635                        vec![]
1636                    };
1637
1638                    let request = create_test_request_with_headers(Method::GET, "/test", headers);
1639
1640                    // Extract cookies
1641                    let extracted = Cookies::from_request_parts(&request)
1642                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;
1643
1644                    // Build expected cookies map - last value wins for duplicate names
1645                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
1646                    for (name, value) in &cookies {
1647                        expected_cookies.insert(name.as_str(), value.as_str());
1648                    }
1649
1650                    // Verify all expected cookies are present with correct values
1651                    for (name, expected_value) in &expected_cookies {
1652                        let cookie = extracted.get(name)
1653                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;
1654
1655                        prop_assert_eq!(
1656                            cookie.value(),
1657                            *expected_value,
1658                            "Cookie '{}' value mismatch",
1659                            name
1660                        );
1661                    }
1662
1663                    // Count cookies in jar should match unique cookie names
1664                    let extracted_count = extracted.iter().count();
1665                    prop_assert_eq!(
1666                        extracted_count,
1667                        expected_cookies.len(),
1668                        "Expected {} unique cookies, got {}",
1669                        expected_cookies.len(),
1670                        extracted_count
1671                    );
1672
1673                    Ok(())
1674                })();
1675                result?;
1676            }
1677        }
1678
1679        #[test]
1680        fn test_cookies_extractor_basic() {
1681            let request = create_test_request_with_headers(
1682                Method::GET,
1683                "/test",
1684                vec![("cookie", "session=abc123; user=john")],
1685            );
1686
1687            let cookies = Cookies::from_request_parts(&request).unwrap();
1688
1689            assert!(cookies.contains("session"));
1690            assert!(cookies.contains("user"));
1691            assert!(!cookies.contains("other"));
1692
1693            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
1694            assert_eq!(cookies.get("user").unwrap().value(), "john");
1695        }
1696
1697        #[test]
1698        fn test_cookies_extractor_empty() {
1699            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);
1700
1701            let cookies = Cookies::from_request_parts(&request).unwrap();
1702            assert_eq!(cookies.iter().count(), 0);
1703        }
1704
1705        #[test]
1706        fn test_cookies_extractor_single() {
1707            let request = create_test_request_with_headers(
1708                Method::GET,
1709                "/test",
1710                vec![("cookie", "token=xyz789")],
1711            );
1712
1713            let cookies = Cookies::from_request_parts(&request).unwrap();
1714            assert_eq!(cookies.iter().count(), 1);
1715            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
1716        }
1717    }
1718}