rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
//! | Extractor | Description | Consumes Body |
//! |-----------|-------------|---------------|
//! | [`Json<T>`] | Parse JSON request body | Yes |
//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
//! | [`AsyncValidatedJson<T>`] | Parse JSON body and run sync + async validation | Yes |
//! | [`Query<T>`] | Parse query string parameters | No |
//! | [`Path<T>`] | Extract a single path parameter | No |
//! | [`Typed<T>`] | Deserialize all named path parameters into a struct | No |
//! | [`State<T>`] | Access shared application state | No |
//! | [`Body`] | Raw request body bytes | Yes |
//! | [`BodyStream`] | Request body as a stream of chunks | Yes |
//! | [`Headers`] | Access all request headers | No |
//! | [`HeaderValue`] | Extract a specific header | No |
//! | [`Extension<T>`] | Access middleware-injected data | No |
//! | [`ClientIp`] | Extract client IP address | No |
//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
51//! # Extractor Order
52//!
53//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`)
54//! must come last since they consume the request body. Non-body extractors can be
55//! in any order.
56
57use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use serde::de::DeserializeOwned;
68use serde::Serialize;
69use std::future::Future;
70use std::ops::{Deref, DerefMut};
71use std::str::FromStr;
72
73/// Trait for extracting data from request parts (headers, path, query)
74///
75/// This is used for extractors that don't need the request body.
76///
77/// # Example: Implementing a custom extractor that requires a specific header
78///
79/// ```rust
80/// use rustapi_core::FromRequestParts;
81/// use rustapi_core::{Request, ApiError, Result};
82/// use http::StatusCode;
83///
84/// struct ApiKey(String);
85///
86/// impl FromRequestParts for ApiKey {
87///     fn from_request_parts(req: &Request) -> Result<Self> {
88///         if let Some(key) = req.headers().get("x-api-key") {
89///             if let Ok(key_str) = key.to_str() {
90///                 return Ok(ApiKey(key_str.to_string()));
91///             }
92///         }
93///         Err(ApiError::unauthorized("Missing or invalid API key"))
94///     }
95/// }
96/// ```
97pub trait FromRequestParts: Sized {
98    /// Extract from request parts
99    fn from_request_parts(req: &Request) -> Result<Self>;
100}
101
102/// Trait for extracting data from the full request (including body)
103///
104/// This is used for extractors that consume the request body.
105///
106/// # Example: Implementing a custom extractor that consumes the body
107///
108/// ```rust
109/// use rustapi_core::FromRequest;
110/// use rustapi_core::{Request, ApiError, Result};
111/// use std::future::Future;
112///
113/// struct PlainText(String);
114///
115/// impl FromRequest for PlainText {
116///     async fn from_request(req: &mut Request) -> Result<Self> {
117///         // Ensure body is loaded
118///         req.load_body().await?;
119///         
120///         // Consume the body
121///         if let Some(bytes) = req.take_body() {
122///             if let Ok(text) = String::from_utf8(bytes.to_vec()) {
123///                 return Ok(PlainText(text));
124///             }
125///         }
126///         
127///         Err(ApiError::bad_request("Invalid plain text body"))
128///     }
129/// }
130/// ```
131pub trait FromRequest: Sized {
132    /// Extract from the full request
133    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
134}
135
136// Blanket impl: FromRequestParts -> FromRequest
137impl<T: FromRequestParts> FromRequest for T {
138    async fn from_request(req: &mut Request) -> Result<Self> {
139        T::from_request_parts(req)
140    }
141}
142
143/// JSON body extractor
144///
145/// Parses the request body as JSON and deserializes into type `T`.
146/// Also works as a response type when T: Serialize.
147///
148/// # Example
149///
150/// ```rust,ignore
151/// #[derive(Deserialize)]
152/// struct CreateUser {
153///     name: String,
154///     email: String,
155/// }
156///
157/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
158///     // body is already deserialized
159/// }
160/// ```
161#[derive(Debug, Clone, Copy, Default)]
162pub struct Json<T>(pub T);
163
164impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
165    async fn from_request(req: &mut Request) -> Result<Self> {
166        req.load_body().await?;
167        let body = req
168            .take_body()
169            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
170
171        // Use simd-json accelerated parsing when available (2-4x faster)
172        let value: T = json::from_slice(&body)?;
173        Ok(Json(value))
174    }
175}
176
177impl<T> Deref for Json<T> {
178    type Target = T;
179
180    fn deref(&self) -> &Self::Target {
181        &self.0
182    }
183}
184
185impl<T> DerefMut for Json<T> {
186    fn deref_mut(&mut self) -> &mut Self::Target {
187        &mut self.0
188    }
189}
190
191impl<T> From<T> for Json<T> {
192    fn from(value: T) -> Self {
193        Json(value)
194    }
195}
196
197/// Default pre-allocation size for JSON response buffers (256 bytes)
198/// This covers most small to medium JSON responses without reallocation.
199const JSON_RESPONSE_INITIAL_CAPACITY: usize = 256;
200
201// IntoResponse for Json - allows using Json<T> as a return type
202impl<T: Serialize> IntoResponse for Json<T> {
203    fn into_response(self) -> crate::response::Response {
204        // Use pre-allocated buffer to reduce allocations
205        match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
206            Ok(body) => http::Response::builder()
207                .status(StatusCode::OK)
208                .header(header::CONTENT_TYPE, "application/json")
209                .body(crate::response::Body::from(body))
210                .unwrap(),
211            Err(err) => {
212                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
213            }
214        }
215    }
216}
217
/// JSON body extractor with validation.
///
/// Deserializes the request body into `T` and then validates it through the
/// crate's `Validatable` trait (for example, a type deriving `validator`'s
/// `Validate`, as in the example below). If validation fails, the extractor
/// returns the validation error: a 422 Unprocessable Entity with field-level
/// details.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///     #[validate(length(min = 8))]
///     password: String,
/// }
///
/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
///     // body is already validated!
///     // If email is invalid or password too short, a 422 error is returned automatically
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wraps `value` as-is; no validation is performed.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Unwraps and returns the contained value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}
257
258impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
259    async fn from_request(req: &mut Request) -> Result<Self> {
260        req.load_body().await?;
261        // First, deserialize the JSON body using simd-json when available
262        let body = req
263            .take_body()
264            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
265
266        let value: T = json::from_slice(&body)?;
267
268        // Then, validate it using the unified Validatable trait
269        if let Err(e) = value.do_validate() {
270            return Err(e);
271        }
272
273        Ok(ValidatedJson(value))
274    }
275}
276
277impl<T> Deref for ValidatedJson<T> {
278    type Target = T;
279
280    fn deref(&self) -> &Self::Target {
281        &self.0
282    }
283}
284
285impl<T> DerefMut for ValidatedJson<T> {
286    fn deref_mut(&mut self) -> &mut Self::Target {
287        &mut self.0
288    }
289}
290
291impl<T> From<T> for ValidatedJson<T> {
292    fn from(value: T) -> Self {
293        ValidatedJson(value)
294    }
295}
296
297impl<T: Serialize> IntoResponse for ValidatedJson<T> {
298    fn into_response(self) -> crate::response::Response {
299        Json(self.0).into_response()
300    }
301}
302
/// JSON body extractor with asynchronous validation.
///
/// Deserializes the request body into `T`, then runs
/// `AsyncValidate::validate_full` from `rustapi-validate` (v2), which checks
/// both synchronous and asynchronous rules. Async rules make it suitable for
/// checks that need I/O, such as database uniqueness constraints.
///
/// Validation currently runs with `ValidationContext::default()`. Validators
/// registered in application state are not consulted yet.
///
/// On failure, every field error is collected into an `ApiError::validation`
/// error.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use rustapi_validate::v2::prelude::*;
///
/// #[derive(Deserialize, Validate, AsyncValidate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///
///     #[validate(async_unique(table = "users", column = "username"))]
///     username: String,
/// }
///
/// async fn register(AsyncValidatedJson(body): AsyncValidatedJson<CreateUser>) -> impl IntoResponse {
///     // body passed sync and async validation (e.g. the username is not already taken)
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);
331
332impl<T> AsyncValidatedJson<T> {
333    /// Create a new AsyncValidatedJson wrapper
334    pub fn new(value: T) -> Self {
335        Self(value)
336    }
337
338    /// Get the inner value
339    pub fn into_inner(self) -> T {
340        self.0
341    }
342}
343
344impl<T> Deref for AsyncValidatedJson<T> {
345    type Target = T;
346
347    fn deref(&self) -> &Self::Target {
348        &self.0
349    }
350}
351
352impl<T> DerefMut for AsyncValidatedJson<T> {
353    fn deref_mut(&mut self) -> &mut Self::Target {
354        &mut self.0
355    }
356}
357
358impl<T> From<T> for AsyncValidatedJson<T> {
359    fn from(value: T) -> Self {
360        AsyncValidatedJson(value)
361    }
362}
363
364impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
365    fn into_response(self) -> crate::response::Response {
366        Json(self.0).into_response()
367    }
368}
369
370impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
371    async fn from_request(req: &mut Request) -> Result<Self> {
372        req.load_body().await?;
373
374        let body = req
375            .take_body()
376            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
377
378        let value: T = json::from_slice(&body)?;
379
380        // Create validation context from request
381        // TODO: Extract validators from App State
382        let ctx = ValidationContext::default();
383
384        // Perform full validation (sync + async)
385        if let Err(errors) = value.validate_full(&ctx).await {
386            // Convert v2 ValidationErrors to ApiError
387            let field_errors: Vec<crate::error::FieldError> = errors
388                .fields
389                .iter()
390                .flat_map(|(field, errs)| {
391                    let field_name = field.to_string();
392                    errs.iter().map(move |e| crate::error::FieldError {
393                        field: field_name.clone(),
394                        code: e.code.to_string(),
395                        message: e.message.clone(),
396                    })
397                })
398                .collect();
399
400            return Err(ApiError::validation(field_errors));
401        }
402
403        Ok(AsyncValidatedJson(value))
404    }
405}
406
407/// Query string extractor
408///
409/// Parses the query string into type `T`.
410///
411/// # Example
412///
413/// ```rust,ignore
414/// #[derive(Deserialize)]
415/// struct Pagination {
416///     page: Option<u32>,
417///     limit: Option<u32>,
418/// }
419///
420/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
421///     // params.page, params.limit
422/// }
423/// ```
424#[derive(Debug, Clone)]
425pub struct Query<T>(pub T);
426
427impl<T: DeserializeOwned> FromRequestParts for Query<T> {
428    fn from_request_parts(req: &Request) -> Result<Self> {
429        let query = req.query_string().unwrap_or("");
430        let value: T = serde_urlencoded::from_str(query)
431            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
432        Ok(Query(value))
433    }
434}
435
436impl<T> Deref for Query<T> {
437    type Target = T;
438
439    fn deref(&self) -> &Self::Target {
440        &self.0
441    }
442}
443
444/// Path parameter extractor
445///
446/// Extracts path parameters defined in the route pattern.
447///
448/// # Example
449///
450/// For route `/users/{id}`:
451///
452/// ```rust,ignore
453/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
454///     // id is extracted from path
455/// }
456/// ```
457///
458/// For multiple params `/users/{user_id}/posts/{post_id}`:
459///
460/// ```rust,ignore
461/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
462///     // Both params extracted
463/// }
464/// ```
465#[derive(Debug, Clone)]
466pub struct Path<T>(pub T);
467
468impl<T: FromStr> FromRequestParts for Path<T>
469where
470    T::Err: std::fmt::Display,
471{
472    fn from_request_parts(req: &Request) -> Result<Self> {
473        let params = req.path_params();
474
475        // For single param, get the first one
476        if let Some((_, value)) = params.iter().next() {
477            let parsed = value
478                .parse::<T>()
479                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
480            return Ok(Path(parsed));
481        }
482
483        Err(ApiError::internal("Missing path parameter"))
484    }
485}
486
487impl<T> Deref for Path<T> {
488    type Target = T;
489
490    fn deref(&self) -> &Self::Target {
491        &self.0
492    }
493}
494
495/// Typed path extractor
496///
497/// Extracts path parameters and deserializes them into a struct implementing `Deserialize`.
498/// This is similar to `Path<T>`, but supports complex structs that can be deserialized
499/// from a map of parameter names to values (e.g. via `serde_json`).
500///
501/// # Example
502///
503/// ```rust,ignore
504/// #[derive(Deserialize)]
505/// struct UserParams {
506///     id: u64,
507///     category: String,
508/// }
509///
510/// async fn get_user(Typed(params): Typed<UserParams>) -> impl IntoResponse {
511///     // params.id, params.category
512/// }
513/// ```
514#[derive(Debug, Clone)]
515pub struct Typed<T>(pub T);
516
517impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
518    fn from_request_parts(req: &Request) -> Result<Self> {
519        let params = req.path_params();
520        let mut map = serde_json::Map::new();
521        for (k, v) in params.iter() {
522            map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
523        }
524        let value = serde_json::Value::Object(map);
525        let parsed: T = serde_json::from_value(value)
526            .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
527        Ok(Typed(parsed))
528    }
529}
530
531impl<T> Deref for Typed<T> {
532    type Target = T;
533
534    fn deref(&self) -> &Self::Target {
535        &self.0
536    }
537}
538
539/// State extractor
540///
541/// Extracts shared application state.
542///
543/// # Example
544///
545/// ```rust,ignore
546/// #[derive(Clone)]
547/// struct AppState {
548///     db: DbPool,
549/// }
550///
551/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
552///     // Use state.db
553/// }
554/// ```
555#[derive(Debug, Clone)]
556pub struct State<T>(pub T);
557
558impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
559    fn from_request_parts(req: &Request) -> Result<Self> {
560        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
561            ApiError::internal(format!(
562                "State of type `{}` not found. Did you forget to call .state()?",
563                std::any::type_name::<T>()
564            ))
565        })
566    }
567}
568
569impl<T> Deref for State<T> {
570    type Target = T;
571
572    fn deref(&self) -> &Self::Target {
573        &self.0
574    }
575}
576
577/// Raw body bytes extractor
578#[derive(Debug, Clone)]
579pub struct Body(pub Bytes);
580
581impl FromRequest for Body {
582    async fn from_request(req: &mut Request) -> Result<Self> {
583        req.load_body().await?;
584        let body = req
585            .take_body()
586            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
587        Ok(Body(body))
588    }
589}
590
591impl Deref for Body {
592    type Target = Bytes;
593
594    fn deref(&self) -> &Self::Target {
595        &self.0
596    }
597}
598
599/// Streaming body extractor
600pub struct BodyStream(pub StreamingBody);
601
602impl FromRequest for BodyStream {
603    async fn from_request(req: &mut Request) -> Result<Self> {
604        let config = StreamingConfig::default();
605
606        if let Some(stream) = req.take_stream() {
607            Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
608        } else if let Some(bytes) = req.take_body() {
609            // Handle buffered body as stream
610            let stream = futures_util::stream::once(async move { Ok(bytes) });
611            Ok(BodyStream(StreamingBody::from_stream(
612                stream,
613                config.max_body_size,
614            )))
615        } else {
616            Err(ApiError::internal("Body already consumed"))
617        }
618    }
619}
620
621impl Deref for BodyStream {
622    type Target = StreamingBody;
623
624    fn deref(&self) -> &Self::Target {
625        &self.0
626    }
627}
628
629impl DerefMut for BodyStream {
630    fn deref_mut(&mut self) -> &mut Self::Target {
631        &mut self.0
632    }
633}
634
635// Forward stream implementation
636impl futures_util::Stream for BodyStream {
637    type Item = Result<Bytes, ApiError>;
638
639    fn poll_next(
640        mut self: std::pin::Pin<&mut Self>,
641        cx: &mut std::task::Context<'_>,
642    ) -> std::task::Poll<Option<Self::Item>> {
643        std::pin::Pin::new(&mut self.0).poll_next(cx)
644    }
645}
646
647/// Optional extractor wrapper
648///
649/// Makes any extractor optional - returns None instead of error on failure.
650impl<T: FromRequestParts> FromRequestParts for Option<T> {
651    fn from_request_parts(req: &Request) -> Result<Self> {
652        Ok(T::from_request_parts(req).ok())
653    }
654}
655
656/// Headers extractor
657///
658/// Provides access to all request headers as a typed map.
659///
660/// # Example
661///
662/// ```rust,ignore
663/// use rustapi_core::extract::Headers;
664///
665/// async fn handler(headers: Headers) -> impl IntoResponse {
666///     if let Some(content_type) = headers.get("content-type") {
667///         format!("Content-Type: {:?}", content_type)
668///     } else {
669///         "No Content-Type header".to_string()
670///     }
671/// }
672/// ```
673#[derive(Debug, Clone)]
674pub struct Headers(pub http::HeaderMap);
675
676impl Headers {
677    /// Get a header value by name
678    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
679        self.0.get(name)
680    }
681
682    /// Check if a header exists
683    pub fn contains(&self, name: &str) -> bool {
684        self.0.contains_key(name)
685    }
686
687    /// Get the number of headers
688    pub fn len(&self) -> usize {
689        self.0.len()
690    }
691
692    /// Check if headers are empty
693    pub fn is_empty(&self) -> bool {
694        self.0.is_empty()
695    }
696
697    /// Iterate over all headers
698    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
699        self.0.iter()
700    }
701}
702
703impl FromRequestParts for Headers {
704    fn from_request_parts(req: &Request) -> Result<Self> {
705        Ok(Headers(req.headers().clone()))
706    }
707}
708
709impl Deref for Headers {
710    type Target = http::HeaderMap;
711
712    fn deref(&self) -> &Self::Target {
713        &self.0
714    }
715}
716
717/// Single header value extractor
718///
719/// Extracts a specific header value by name. Returns an error if the header is missing.
720///
721/// # Example
722///
723/// ```rust,ignore
724/// use rustapi_core::extract::HeaderValue;
725///
726/// async fn handler(
727///     auth: HeaderValue<{ "authorization" }>,
728/// ) -> impl IntoResponse {
729///     format!("Auth header: {}", auth.0)
730/// }
731/// ```
732///
733/// Note: Due to Rust's const generics limitations, you may need to use the
734/// `HeaderValueOf` type alias or extract headers manually using the `Headers` extractor.
735#[derive(Debug, Clone)]
736pub struct HeaderValue(pub String, pub &'static str);
737
738impl HeaderValue {
739    /// Create a new HeaderValue extractor for a specific header name
740    pub fn new(name: &'static str, value: String) -> Self {
741        Self(value, name)
742    }
743
744    /// Get the header value
745    pub fn value(&self) -> &str {
746        &self.0
747    }
748
749    /// Get the header name
750    pub fn name(&self) -> &'static str {
751        self.1
752    }
753
754    /// Extract a specific header from a request
755    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
756        req.headers()
757            .get(name)
758            .and_then(|v| v.to_str().ok())
759            .map(|s| HeaderValue(s.to_string(), name))
760            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
761    }
762}
763
764impl Deref for HeaderValue {
765    type Target = String;
766
767    fn deref(&self) -> &Self::Target {
768        &self.0
769    }
770}
771
772/// Extension extractor
773///
774/// Retrieves typed data from request extensions that was inserted by middleware.
775///
776/// # Example
777///
778/// ```rust,ignore
779/// use rustapi_core::extract::Extension;
780///
781/// // Middleware inserts user data
782/// #[derive(Clone)]
783/// struct CurrentUser { id: i64 }
784///
785/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
786///     format!("User ID: {}", user.id)
787/// }
788/// ```
789#[derive(Debug, Clone)]
790pub struct Extension<T>(pub T);
791
792impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
793    fn from_request_parts(req: &Request) -> Result<Self> {
794        req.extensions()
795            .get::<T>()
796            .cloned()
797            .map(Extension)
798            .ok_or_else(|| {
799                ApiError::internal(format!(
800                    "Extension of type `{}` not found. Did middleware insert it?",
801                    std::any::type_name::<T>()
802                ))
803            })
804    }
805}
806
807impl<T> Deref for Extension<T> {
808    type Target = T;
809
810    fn deref(&self) -> &Self::Target {
811        &self.0
812    }
813}
814
815impl<T> DerefMut for Extension<T> {
816    fn deref_mut(&mut self) -> &mut Self::Target {
817        &mut self.0
818    }
819}
820
821/// Client IP address extractor
822///
823/// Extracts the client IP address from the request. When `trust_proxy` is enabled,
824/// it will use the `X-Forwarded-For` header if present.
825///
826/// # Example
827///
828/// ```rust,ignore
829/// use rustapi_core::extract::ClientIp;
830///
831/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
832///     format!("Your IP: {}", ip)
833/// }
834/// ```
835#[derive(Debug, Clone)]
836pub struct ClientIp(pub std::net::IpAddr);
837
838impl ClientIp {
839    /// Extract client IP, optionally trusting X-Forwarded-For header
840    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
841        if trust_proxy {
842            // Try X-Forwarded-For header first
843            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
844                if let Ok(forwarded_str) = forwarded.to_str() {
845                    // X-Forwarded-For can contain multiple IPs, take the first one
846                    if let Some(first_ip) = forwarded_str.split(',').next() {
847                        if let Ok(ip) = first_ip.trim().parse() {
848                            return Ok(ClientIp(ip));
849                        }
850                    }
851                }
852            }
853        }
854
855        // Fall back to socket address from extensions (if set by server)
856        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
857            return Ok(ClientIp(addr.ip()));
858        }
859
860        // Default to localhost if no IP information available
861        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
862            127, 0, 0, 1,
863        ))))
864    }
865}
866
867impl FromRequestParts for ClientIp {
868    fn from_request_parts(req: &Request) -> Result<Self> {
869        // By default, trust proxy headers
870        Self::extract_with_config(req, true)
871    }
872}
873
/// Request cookies extractor.
///
/// Parses every `Cookie` request header into a `cookie::CookieJar`. Pairs that
/// fail to parse, and header values that are not visible ASCII, are skipped.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     if let Some(session) = cookies.get("session_id") {
///         format!("Session: {}", session.value())
///     } else {
///         "No session cookie".to_string()
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);

#[cfg(feature = "cookies")]
impl Cookies {
    /// Returns the cookie called `name`, if the request sent one.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        self.0.get(name)
    }

    /// Iterates over all parsed cookies.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        self.0.iter()
    }

    /// Returns `true` if a cookie called `name` was sent.
    pub fn contains(&self, name: &str) -> bool {
        self.0.get(name).is_some()
    }
}

#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();
        let headers = req.headers();

        // HTTP/2 clients may split cookies across several `Cookie` header
        // fields (RFC 9113 §8.2.3). Read every occurrence, not just the first.
        let pairs = headers
            .get_all(header::COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|list| list.split(';'))
            .map(str::trim)
            .filter(|pair| !pair.is_empty());

        for pair in pairs {
            if let Ok(parsed) = cookie::Cookie::parse(pair.to_string()) {
                jar.add_original(parsed.into_owned());
            }
        }

        Ok(Cookies(jar))
    }
}

#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
944
945// Implement FromRequestParts for common primitive types (path params)
946macro_rules! impl_from_request_parts_for_primitives {
947    ($($ty:ty),*) => {
948        $(
949            impl FromRequestParts for $ty {
950                fn from_request_parts(req: &Request) -> Result<Self> {
951                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
952                    Ok(value)
953                }
954            }
955        )*
956    };
957}
958
959impl_from_request_parts_for_primitives!(
960    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
961);
962
963// OperationModifier implementations for extractors
964
965use rustapi_openapi::utoipa_types::openapi;
966use rustapi_openapi::{
967    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
968    ResponseSpec, Schema, SchemaRef,
969};
970use std::collections::HashMap;
971
972// ValidatedJson - Adds request body
973impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
974    fn update_operation(op: &mut Operation) {
975        let (name, _) = T::schema();
976
977        let schema_ref = SchemaRef::Ref {
978            reference: format!("#/components/schemas/{}", name),
979        };
980
981        let mut content = HashMap::new();
982        content.insert(
983            "application/json".to_string(),
984            MediaType { schema: schema_ref },
985        );
986
987        op.request_body = Some(RequestBody {
988            required: true,
989            content,
990        });
991
992        // Add 422 Validation Error response
993        op.responses.insert(
994            "422".to_string(),
995            ResponseSpec {
996                description: "Validation Error".to_string(),
997                content: {
998                    let mut map = HashMap::new();
999                    map.insert(
1000                        "application/json".to_string(),
1001                        MediaType {
1002                            schema: SchemaRef::Ref {
1003                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1004                            },
1005                        },
1006                    );
1007                    Some(map)
1008                },
1009            },
1010        );
1011    }
1012}
1013
1014// Json - Adds request body (Same as ValidatedJson)
1015impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
1016    fn update_operation(op: &mut Operation) {
1017        let (name, _) = T::schema();
1018
1019        let schema_ref = SchemaRef::Ref {
1020            reference: format!("#/components/schemas/{}", name),
1021        };
1022
1023        let mut content = HashMap::new();
1024        content.insert(
1025            "application/json".to_string(),
1026            MediaType { schema: schema_ref },
1027        );
1028
1029        op.request_body = Some(RequestBody {
1030            required: true,
1031            content,
1032        });
1033    }
1034}
1035
1036// Path - Path parameters are automatically extracted from route patterns
1037// The add_path_params_to_operation function in app.rs handles OpenAPI documentation
1038// based on the {param} syntax in route paths (e.g., "/users/{id}")
1039impl<T> OperationModifier for Path<T> {
1040    fn update_operation(_op: &mut Operation) {
1041        // Path parameters are automatically documented by add_path_params_to_operation
1042        // in app.rs based on the route pattern. No additional implementation needed here.
1043        //
1044        // For typed path params, the schema type defaults to "string" but will be
1045        // inferred from the actual type T when more sophisticated type introspection
1046        // is implemented.
1047    }
1048}
1049
1050// Typed - Same as Path, parameters are documented by route pattern
1051impl<T> OperationModifier for Typed<T> {
1052    fn update_operation(_op: &mut Operation) {
1053        // No-op, managed by route registration
1054    }
1055}
1056
1057// Query - Extracts query params using IntoParams
1058impl<T: IntoParams> OperationModifier for Query<T> {
1059    fn update_operation(op: &mut Operation) {
1060        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
1061
1062        let new_params: Vec<Parameter> = params
1063            .into_iter()
1064            .map(|p| {
1065                let schema = match p.schema {
1066                    Some(schema) => match schema {
1067                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
1068                            reference: r.ref_location,
1069                        },
1070                        openapi::RefOr::T(s) => {
1071                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
1072                            SchemaRef::Inline(value)
1073                        }
1074                    },
1075                    None => SchemaRef::Inline(serde_json::Value::Null),
1076                };
1077
1078                let required = match p.required {
1079                    openapi::Required::True => true,
1080                    openapi::Required::False => false,
1081                };
1082
1083                Parameter {
1084                    name: p.name,
1085                    location: "query".to_string(), // explicitly query
1086                    required,
1087                    description: p.description,
1088                    schema,
1089                }
1090            })
1091            .collect();
1092
1093        if let Some(existing) = &mut op.parameters {
1094            existing.extend(new_params);
1095        } else {
1096            op.parameters = Some(new_params);
1097        }
1098    }
1099}
1100
1101// State - No op
1102impl<T> OperationModifier for State<T> {
1103    fn update_operation(_op: &mut Operation) {}
1104}
1105
1106// Body - Generic binary body
1107impl OperationModifier for Body {
1108    fn update_operation(op: &mut Operation) {
1109        let mut content = HashMap::new();
1110        content.insert(
1111            "application/octet-stream".to_string(),
1112            MediaType {
1113                schema: SchemaRef::Inline(
1114                    serde_json::json!({ "type": "string", "format": "binary" }),
1115                ),
1116            },
1117        );
1118
1119        op.request_body = Some(RequestBody {
1120            required: true,
1121            content,
1122        });
1123    }
1124}
1125
1126// BodyStream - Generic binary stream (Same as Body)
1127impl OperationModifier for BodyStream {
1128    fn update_operation(op: &mut Operation) {
1129        let mut content = HashMap::new();
1130        content.insert(
1131            "application/octet-stream".to_string(),
1132            MediaType {
1133                schema: SchemaRef::Inline(
1134                    serde_json::json!({ "type": "string", "format": "binary" }),
1135                ),
1136            },
1137        );
1138
1139        op.request_body = Some(RequestBody {
1140            required: true,
1141            content,
1142        });
1143    }
1144}
1145
1146// ResponseModifier implementations for extractors
1147
1148// Json<T> - 200 OK with schema T
1149impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
1150    fn update_response(op: &mut Operation) {
1151        let (name, _) = T::schema();
1152
1153        let schema_ref = SchemaRef::Ref {
1154            reference: format!("#/components/schemas/{}", name),
1155        };
1156
1157        op.responses.insert(
1158            "200".to_string(),
1159            ResponseSpec {
1160                description: "Successful response".to_string(),
1161                content: {
1162                    let mut map = HashMap::new();
1163                    map.insert(
1164                        "application/json".to_string(),
1165                        MediaType { schema: schema_ref },
1166                    );
1167                    Some(map)
1168                },
1169            },
1170        );
1171    }
1172}
1173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::Arc;

    /// Create a test request with the given method, path, and headers.
    ///
    /// The body is empty, there are no path params, and the shared extensions
    /// are empty.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Create a test request with `extension` inserted into the request parts'
    /// extensions (where middleware-injected data lives).
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
    //
    // For any request with headers H, the `Headers` extractor SHALL return a map
    // containing all key-value pairs in H.
    //
    // **Validates: Requirements 5.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            // Generate random header names and values
            // Using alphanumeric strings to ensure valid header names/values
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Convert to header tuples
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                // Create request with headers
                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples.clone(),
                );

                // Extract headers
                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                // Verify all original headers are present
                // HTTP allows duplicate headers - get_all() returns all values for a header name
                for (name, value) in &headers {
                    // Check that the header name exists
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    // Check that the value is among the extracted values
                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
    //
    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
    // for requests without header "X", it SHALL return an error.
    // Every request also carries an unrelated, randomly named header, which must
    // never be returned in place of "X".
    //
    // **Validates: Requirements 5.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            // Name of an unrelated "distractor" header present on every request
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // The extractor takes the header name as a static string, so the
                // header under test is fixed; the generated name is used for a
                // distractor header instead.
                let test_header = "x-test-header";
                prop_assume!(header_name != test_header);

                let mut headers = vec![(header_name.as_str(), "distractor")];
                if has_header {
                    headers.push((test_header, header_value.as_str()));
                }
                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let result = HeaderValue::extract(&request, test_header);

                if has_header {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
    //
    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
    //
    // **Validates: Requirements 5.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            // Generate valid IPv4 addresses
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Create request with headers
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                // Add socket address to extensions
                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    // Should use X-Forwarded-For
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    // Should use socket IP
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
    //
    // For any type T and value V inserted into request extensions by middleware,
    // `Extension<T>` SHALL return V.
    //
    // **Validates: Requirements 5.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Create a simple wrapper type for testing
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // Unit tests for basic functionality

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        // With a proxy chain, the first (left-most) entry is the client.
        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        // MyData is only used as a type to look up and is never constructed,
        // so its field would otherwise trigger a dead_code warning.
        #[derive(Clone, Debug)]
        #[allow(dead_code)]
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    // Cookies tests (feature-gated)
    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
        //
        // For any request with Cookie header containing cookies C, the `Cookies` extractor
        // SHALL return a CookieJar containing exactly the cookies in C.
        // Note: Duplicate cookie names result in only the last value being kept.
        //
        // **Validates: Requirements 5.3**
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                // Generate random cookie names and values
                // Using alphanumeric strings to ensure valid cookie names/values
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // Valid cookie name pattern
                        "[a-zA-Z0-9]{1,30}"            // Valid cookie value pattern (no special chars)
                    ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    // Build cookie header string
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    // Extract cookies
                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    // Build expected cookies map - last value wins for duplicate names
                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    // Verify all expected cookies are present with correct values
                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    // Count cookies in jar should match unique cookie names
                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}