Skip to main content

rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming HTTP requests.
4//! They implement the [`FromRequest`] or [`FromRequestParts`] traits and can be
5//! used as handler function parameters.
6//!
7//! # Available Extractors
8//!
9//! | Extractor | Description | Consumes Body |
10//! |-----------|-------------|---------------|
11//! | [`Json<T>`] | Parse JSON request body | Yes |
12//! | [`ValidatedJson<T>`] | Parse and validate JSON body | Yes |
13//! | [`Query<T>`] | Parse query string parameters | No |
14//! | [`Path<T>`] | Extract path parameters | No |
15//! | [`State<T>`] | Access shared application state | No |
16//! | [`Body`] | Raw request body bytes | Yes |
17//! | [`Headers`] | Access all request headers | No |
18//! | [`HeaderValue`] | Extract a specific header | No |
19//! | [`Extension<T>`] | Access middleware-injected data | No |
20//! | [`ClientIp`] | Extract client IP address | No |
21//! | [`Cookies`] | Parse request cookies (requires `cookies` feature) | No |
22//!
23//! # Example
24//!
25//! ```rust,ignore
26//! use rustapi_core::{Json, Query, Path, State};
27//! use serde::{Deserialize, Serialize};
28//!
29//! #[derive(Deserialize)]
30//! struct CreateUser {
31//!     name: String,
32//!     email: String,
33//! }
34//!
35//! #[derive(Deserialize)]
36//! struct Pagination {
37//!     page: Option<u32>,
38//!     limit: Option<u32>,
39//! }
40//!
41//! // Multiple extractors can be combined
42//! async fn create_user(
43//!     State(db): State<DbPool>,
44//!     Query(pagination): Query<Pagination>,
45//!     Json(body): Json<CreateUser>,
46//! ) -> impl IntoResponse {
47//!     // Use db, pagination, and body...
48//! }
49//! ```
50//!
51//! # Extractor Order
52//!
53//! When using multiple extractors, body-consuming extractors (like `Json` or `Body`)
54//! must come last since they consume the request body. Non-body extractors can be
55//! in any order.
56
57use crate::error::{ApiError, Result};
58use crate::json;
59use crate::request::Request;
60use crate::response::IntoResponse;
61use crate::stream::{StreamingBody, StreamingConfig};
62use crate::validation::Validatable;
63use bytes::Bytes;
64use http::{header, StatusCode};
65use rustapi_validate::v2::{AsyncValidate, ValidationContext};
66
67use serde::de::DeserializeOwned;
68use serde::Serialize;
69use std::future::Future;
70use std::ops::{Deref, DerefMut};
71use std::str::FromStr;
72
73/// Trait for extracting data from request parts (headers, path, query)
74///
75/// This is used for extractors that don't need the request body.
76///
77/// # Example: Implementing a custom extractor that requires a specific header
78///
79/// ```rust
80/// use rustapi_core::FromRequestParts;
81/// use rustapi_core::{Request, ApiError, Result};
82/// use http::StatusCode;
83///
84/// struct ApiKey(String);
85///
86/// impl FromRequestParts for ApiKey {
87///     fn from_request_parts(req: &Request) -> Result<Self> {
88///         if let Some(key) = req.headers().get("x-api-key") {
89///             if let Ok(key_str) = key.to_str() {
90///                 return Ok(ApiKey(key_str.to_string()));
91///             }
92///         }
93///         Err(ApiError::unauthorized("Missing or invalid API key"))
94///     }
95/// }
96/// ```
97pub trait FromRequestParts: Sized {
98    /// Extract from request parts
99    fn from_request_parts(req: &Request) -> Result<Self>;
100}
101
102/// Trait for extracting data from the full request (including body)
103///
104/// This is used for extractors that consume the request body.
105///
106/// # Example: Implementing a custom extractor that consumes the body
107///
108/// ```rust
109/// use rustapi_core::FromRequest;
110/// use rustapi_core::{Request, ApiError, Result};
111/// use std::future::Future;
112///
113/// struct PlainText(String);
114///
115/// impl FromRequest for PlainText {
116///     async fn from_request(req: &mut Request) -> Result<Self> {
117///         // Ensure body is loaded
118///         req.load_body().await?;
119///         
120///         // Consume the body
121///         if let Some(bytes) = req.take_body() {
122///             if let Ok(text) = String::from_utf8(bytes.to_vec()) {
123///                 return Ok(PlainText(text));
124///             }
125///         }
126///         
127///         Err(ApiError::bad_request("Invalid plain text body"))
128///     }
129/// }
130/// ```
131pub trait FromRequest: Sized {
132    /// Extract from the full request
133    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
134}
135
136// Blanket impl: FromRequestParts -> FromRequest
137impl<T: FromRequestParts> FromRequest for T {
138    async fn from_request(req: &mut Request) -> Result<Self> {
139        T::from_request_parts(req)
140    }
141}
142
/// JSON request-body extractor and JSON response wrapper.
///
/// As a handler argument, `Json<T>` reads the request body and deserializes
/// it into `T`. As a return value (when `T: Serialize`), it serializes `T`
/// into an `application/json` response.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(payload): Json<CreateUser>) -> impl IntoResponse {
///     // `payload` has already been deserialized from the body
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
163
164impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
165    async fn from_request(req: &mut Request) -> Result<Self> {
166        req.load_body().await?;
167        let body = req
168            .take_body()
169            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
170
171        // Use simd-json accelerated parsing when available (2-4x faster)
172        let value: T = json::from_slice(&body)?;
173        Ok(Json(value))
174    }
175}
176
177impl<T> Deref for Json<T> {
178    type Target = T;
179
180    fn deref(&self) -> &Self::Target {
181        &self.0
182    }
183}
184
185impl<T> DerefMut for Json<T> {
186    fn deref_mut(&mut self) -> &mut Self::Target {
187        &mut self.0
188    }
189}
190
191impl<T> From<T> for Json<T> {
192    fn from(value: T) -> Self {
193        Json(value)
194    }
195}
196
/// Initial capacity, in bytes, of the buffer that JSON response bodies are
/// serialized into. 256 bytes covers typical small-to-medium payloads without
/// reallocating.
const JSON_RESPONSE_INITIAL_CAPACITY: usize = 1 << 8;
200
201// IntoResponse for Json - allows using Json<T> as a return type
202impl<T: Serialize> IntoResponse for Json<T> {
203    fn into_response(self) -> crate::response::Response {
204        // Use pre-allocated buffer to reduce allocations
205        match json::to_vec_with_capacity(&self.0, JSON_RESPONSE_INITIAL_CAPACITY) {
206            Ok(body) => http::Response::builder()
207                .status(StatusCode::OK)
208                .header(header::CONTENT_TYPE, "application/json")
209                .body(crate::response::Body::from(body))
210                .unwrap(),
211            Err(err) => {
212                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
213            }
214        }
215    }
216}
217
/// JSON body extractor that also validates the decoded value.
///
/// The body is deserialized into `T` exactly like [`Json`], then checked
/// through the crate's `Validatable` trait. If validation fails, the
/// request is rejected with a 422 Unprocessable Entity error that lists the
/// failing fields. Like [`Json`], it can also be returned as a JSON response.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///     #[validate(length(min = 8))]
///     password: String,
/// }
///
/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
///     // `body` passed validation; an invalid email or a short password
///     // would already have produced a 422 response.
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);

impl<T> ValidatedJson<T> {
    /// Wrap a value without running any validation.
    pub fn new(value: T) -> Self {
        ValidatedJson(value)
    }

    /// Unwrap and return the contained value.
    pub fn into_inner(self) -> T {
        let ValidatedJson(inner) = self;
        inner
    }
}
257
258impl<T: DeserializeOwned + Validatable + Send> FromRequest for ValidatedJson<T> {
259    async fn from_request(req: &mut Request) -> Result<Self> {
260        req.load_body().await?;
261        // First, deserialize the JSON body using simd-json when available
262        let body = req
263            .take_body()
264            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
265
266        let value: T = json::from_slice(&body)?;
267
268        // Then, validate it using the unified Validatable trait
269        value.do_validate()?;
270
271        Ok(ValidatedJson(value))
272    }
273}
274
275impl<T> Deref for ValidatedJson<T> {
276    type Target = T;
277
278    fn deref(&self) -> &Self::Target {
279        &self.0
280    }
281}
282
283impl<T> DerefMut for ValidatedJson<T> {
284    fn deref_mut(&mut self) -> &mut Self::Target {
285        &mut self.0
286    }
287}
288
289impl<T> From<T> for ValidatedJson<T> {
290    fn from(value: T) -> Self {
291        ValidatedJson(value)
292    }
293}
294
295impl<T: Serialize> IntoResponse for ValidatedJson<T> {
296    fn into_response(self) -> crate::response::Response {
297        Json(self.0).into_response()
298    }
299}
300
/// JSON body extractor with asynchronous validation.
///
/// Deserializes the request body into `T` like [`Json`], then runs the full
/// (synchronous and asynchronous) rule set from `rustapi-validate`'s
/// `AsyncValidate` trait. Async rules can do things such as checking a
/// database for uniqueness. Failures are reported as field-level validation
/// errors.
///
/// Note: `from_request` currently validates with
/// `ValidationContext::default()` (see the TODO there). Rules that depend on
/// validators registered in application state may therefore not have them
/// available yet.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use rustapi_validate::v2::prelude::*;
///
/// #[derive(Deserialize, Validate, AsyncValidate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///
///     #[validate(async_unique(table = "users", column = "username"))]
///     username: String,
/// }
///
/// async fn register(AsyncValidatedJson(body): AsyncValidatedJson<CreateUser>) -> impl IntoResponse {
///     // `body` passed validation, including the async check that no existing
///     // row in `users` already has this username.
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncValidatedJson<T>(pub T);

impl<T> AsyncValidatedJson<T> {
    /// Wrap a value without running any validation.
    pub fn new(value: T) -> Self {
        Self(value)
    }

    /// Unwrap and return the contained value.
    pub fn into_inner(self) -> T {
        self.0
    }
}
341
342impl<T> Deref for AsyncValidatedJson<T> {
343    type Target = T;
344
345    fn deref(&self) -> &Self::Target {
346        &self.0
347    }
348}
349
350impl<T> DerefMut for AsyncValidatedJson<T> {
351    fn deref_mut(&mut self) -> &mut Self::Target {
352        &mut self.0
353    }
354}
355
356impl<T> From<T> for AsyncValidatedJson<T> {
357    fn from(value: T) -> Self {
358        AsyncValidatedJson(value)
359    }
360}
361
362impl<T: Serialize> IntoResponse for AsyncValidatedJson<T> {
363    fn into_response(self) -> crate::response::Response {
364        Json(self.0).into_response()
365    }
366}
367
368impl<T: DeserializeOwned + AsyncValidate + Send + Sync> FromRequest for AsyncValidatedJson<T> {
369    async fn from_request(req: &mut Request) -> Result<Self> {
370        req.load_body().await?;
371
372        let body = req
373            .take_body()
374            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
375
376        let value: T = json::from_slice(&body)?;
377
378        // Create validation context from request
379        // TODO: Extract validators from App State
380        let ctx = ValidationContext::default();
381
382        // Perform full validation (sync + async)
383        if let Err(errors) = value.validate_full(&ctx).await {
384            // Convert v2 ValidationErrors to ApiError
385            let field_errors: Vec<crate::error::FieldError> = errors
386                .fields
387                .iter()
388                .flat_map(|(field, errs)| {
389                    let field_name = field.to_string();
390                    errs.iter().map(move |e| crate::error::FieldError {
391                        field: field_name.clone(),
392                        code: e.code.to_string(),
393                        message: e.message.clone(),
394                    })
395                })
396                .collect();
397
398            return Err(ApiError::validation(field_errors));
399        }
400
401        Ok(AsyncValidatedJson(value))
402    }
403}
404
/// Query-string extractor.
///
/// Deserializes the URL query string (`?page=2&limit=10`) into `T`. A request
/// without a query string is treated as having an empty one.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: Option<u32>,
///     limit: Option<u32>,
/// }
///
/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
///     // use params.page and params.limit
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);
424
425impl<T: DeserializeOwned> FromRequestParts for Query<T> {
426    fn from_request_parts(req: &Request) -> Result<Self> {
427        let query = req.query_string().unwrap_or("");
428        let value: T = serde_urlencoded::from_str(query)
429            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
430        Ok(Query(value))
431    }
432}
433
434impl<T> Deref for Query<T> {
435    type Target = T;
436
437    fn deref(&self) -> &Self::Target {
438        &self.0
439    }
440}
441
/// Path parameter extractor.
///
/// Parses a route's path parameter into `T` using `T`'s [`FromStr`]
/// implementation. A value that fails to parse produces a bad-request error.
///
/// # Example
///
/// For route `/users/{id}`:
///
/// ```rust,ignore
/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
///     // `id` was parsed from the `{id}` segment
/// }
/// ```
///
/// `Path<T>` reads a single parameter only. Tuples do not implement
/// `FromStr`, so `Path<(i64, i64)>` does not compile. For routes with several
/// parameters, such as `/users/{user_id}/posts/{post_id}`, deserialize them
/// by name into a struct with [`Typed`]:
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct PostParams {
///     user_id: String,
///     post_id: String,
/// }
///
/// async fn get_post(Typed(params): Typed<PostParams>) -> impl IntoResponse {
///     // params.user_id, params.post_id
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
465
466impl<T: FromStr> FromRequestParts for Path<T>
467where
468    T::Err: std::fmt::Display,
469{
470    fn from_request_parts(req: &Request) -> Result<Self> {
471        let params = req.path_params();
472
473        // For single param, get the first one
474        if let Some((_, value)) = params.iter().next() {
475            let parsed = value
476                .parse::<T>()
477                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
478            return Ok(Path(parsed));
479        }
480
481        Err(ApiError::internal("Missing path parameter"))
482    }
483}
484
485impl<T> Deref for Path<T> {
486    type Target = T;
487
488    fn deref(&self) -> &Self::Target {
489        &self.0
490    }
491}
492
/// Typed path extractor.
///
/// Collects every path parameter into a map from parameter name to value and
/// deserializes `T` from that map. Unlike [`Path`], which handles a single
/// `FromStr` value, it supports structs with one field per route parameter.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct UserParams {
///     id: u64,
///     category: String,
/// }
///
/// async fn get_user(Typed(params): Typed<UserParams>) -> impl IntoResponse {
///     // params.id, params.category
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Typed<T>(pub T);
514
515impl<T: DeserializeOwned + Send> FromRequestParts for Typed<T> {
516    fn from_request_parts(req: &Request) -> Result<Self> {
517        let params = req.path_params();
518        let mut map = serde_json::Map::new();
519        for (k, v) in params.iter() {
520            map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
521        }
522        let value = serde_json::Value::Object(map);
523        let parsed: T = serde_json::from_value(value)
524            .map_err(|e| ApiError::bad_request(format!("Invalid path parameters: {}", e)))?;
525        Ok(Typed(parsed))
526    }
527}
528
529impl<T> Deref for Typed<T> {
530    type Target = T;
531
532    fn deref(&self) -> &Self::Target {
533        &self.0
534    }
535}
536
/// Shared application state extractor.
///
/// Looks up the state value of type `T` registered with the application and
/// hands the handler its own clone of it.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct AppState {
///     db: DbPool,
/// }
///
/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
///     // use state.db
/// }
/// ```
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
555
556impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
557    fn from_request_parts(req: &Request) -> Result<Self> {
558        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
559            ApiError::internal(format!(
560                "State of type `{}` not found. Did you forget to call .state()?",
561                std::any::type_name::<T>()
562            ))
563        })
564    }
565}
566
567impl<T> Deref for State<T> {
568    type Target = T;
569
570    fn deref(&self) -> &Self::Target {
571        &self.0
572    }
573}
574
575/// Raw body bytes extractor
576#[derive(Debug, Clone)]
577pub struct Body(pub Bytes);
578
579impl FromRequest for Body {
580    async fn from_request(req: &mut Request) -> Result<Self> {
581        req.load_body().await?;
582        let body = req
583            .take_body()
584            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
585        Ok(Body(body))
586    }
587}
588
589impl Deref for Body {
590    type Target = Bytes;
591
592    fn deref(&self) -> &Self::Target {
593        &self.0
594    }
595}
596
597/// Streaming body extractor
598pub struct BodyStream(pub StreamingBody);
599
600impl FromRequest for BodyStream {
601    async fn from_request(req: &mut Request) -> Result<Self> {
602        let config = StreamingConfig::default();
603
604        if let Some(stream) = req.take_stream() {
605            Ok(BodyStream(StreamingBody::new(stream, config.max_body_size)))
606        } else if let Some(bytes) = req.take_body() {
607            // Handle buffered body as stream
608            let stream = futures_util::stream::once(async move { Ok(bytes) });
609            Ok(BodyStream(StreamingBody::from_stream(
610                stream,
611                config.max_body_size,
612            )))
613        } else {
614            Err(ApiError::internal("Body already consumed"))
615        }
616    }
617}
618
619impl Deref for BodyStream {
620    type Target = StreamingBody;
621
622    fn deref(&self) -> &Self::Target {
623        &self.0
624    }
625}
626
627impl DerefMut for BodyStream {
628    fn deref_mut(&mut self) -> &mut Self::Target {
629        &mut self.0
630    }
631}
632
633// Forward stream implementation
634impl futures_util::Stream for BodyStream {
635    type Item = Result<Bytes, ApiError>;
636
637    fn poll_next(
638        mut self: std::pin::Pin<&mut Self>,
639        cx: &mut std::task::Context<'_>,
640    ) -> std::task::Poll<Option<Self::Item>> {
641        std::pin::Pin::new(&mut self.0).poll_next(cx)
642    }
643}
644
645/// Optional extractor wrapper
646///
647/// Makes any extractor optional - returns None instead of error on failure.
648impl<T: FromRequestParts> FromRequestParts for Option<T> {
649    fn from_request_parts(req: &Request) -> Result<Self> {
650        Ok(T::from_request_parts(req).ok())
651    }
652}
653
654/// Headers extractor
655///
656/// Provides access to all request headers as a typed map.
657///
658/// # Example
659///
660/// ```rust,ignore
661/// use rustapi_core::extract::Headers;
662///
663/// async fn handler(headers: Headers) -> impl IntoResponse {
664///     if let Some(content_type) = headers.get("content-type") {
665///         format!("Content-Type: {:?}", content_type)
666///     } else {
667///         "No Content-Type header".to_string()
668///     }
669/// }
670/// ```
671#[derive(Debug, Clone)]
672pub struct Headers(pub http::HeaderMap);
673
674impl Headers {
675    /// Get a header value by name
676    pub fn get(&self, name: &str) -> Option<&http::HeaderValue> {
677        self.0.get(name)
678    }
679
680    /// Check if a header exists
681    pub fn contains(&self, name: &str) -> bool {
682        self.0.contains_key(name)
683    }
684
685    /// Get the number of headers
686    pub fn len(&self) -> usize {
687        self.0.len()
688    }
689
690    /// Check if headers are empty
691    pub fn is_empty(&self) -> bool {
692        self.0.is_empty()
693    }
694
695    /// Iterate over all headers
696    pub fn iter(&self) -> http::header::Iter<'_, http::HeaderValue> {
697        self.0.iter()
698    }
699}
700
701impl FromRequestParts for Headers {
702    fn from_request_parts(req: &Request) -> Result<Self> {
703        Ok(Headers(req.headers().clone()))
704    }
705}
706
707impl Deref for Headers {
708    type Target = http::HeaderMap;
709
710    fn deref(&self) -> &Self::Target {
711        &self.0
712    }
713}
714
/// Single header value.
///
/// Holds the value of one named header as a `String` (field `.0`) together
/// with the header's name (field `.1`). The header name is supplied at
/// runtime, not encoded in the type, so the usual way to obtain one is
/// [`HeaderValue::extract`], typically inside a small custom
/// [`FromRequestParts`] extractor.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::HeaderValue;
/// use rustapi_core::{FromRequestParts, Request, Result};
///
/// struct Authorization(HeaderValue);
///
/// impl FromRequestParts for Authorization {
///     fn from_request_parts(req: &Request) -> Result<Self> {
///         HeaderValue::extract(req, "authorization").map(Authorization)
///     }
/// }
///
/// async fn handler(Authorization(auth): Authorization) -> impl IntoResponse {
///     format!("Auth header: {}", auth.value())
/// }
/// ```
///
/// For ad-hoc access to arbitrary headers, use the [`Headers`] extractor.
#[derive(Debug, Clone)]
pub struct HeaderValue(pub String, pub &'static str);
735
736impl HeaderValue {
737    /// Create a new HeaderValue extractor for a specific header name
738    pub fn new(name: &'static str, value: String) -> Self {
739        Self(value, name)
740    }
741
742    /// Get the header value
743    pub fn value(&self) -> &str {
744        &self.0
745    }
746
747    /// Get the header name
748    pub fn name(&self) -> &'static str {
749        self.1
750    }
751
752    /// Extract a specific header from a request
753    pub fn extract(req: &Request, name: &'static str) -> Result<Self> {
754        req.headers()
755            .get(name)
756            .and_then(|v| v.to_str().ok())
757            .map(|s| HeaderValue(s.to_string(), name))
758            .ok_or_else(|| ApiError::bad_request(format!("Missing required header: {}", name)))
759    }
760}
761
762impl Deref for HeaderValue {
763    type Target = String;
764
765    fn deref(&self) -> &Self::Target {
766        &self.0
767    }
768}
769
/// Request-extension extractor.
///
/// Fetches a value of type `T` that earlier middleware stored in the request
/// extensions, and hands the handler a clone of it.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Extension;
///
/// // Inserted into the request by some middleware.
/// #[derive(Clone)]
/// struct CurrentUser { id: i64 }
///
/// async fn handler(Extension(user): Extension<CurrentUser>) -> impl IntoResponse {
///     format!("User ID: {}", user.id)
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Extension<T>(pub T);
789
790impl<T: Clone + Send + Sync + 'static> FromRequestParts for Extension<T> {
791    fn from_request_parts(req: &Request) -> Result<Self> {
792        req.extensions()
793            .get::<T>()
794            .cloned()
795            .map(Extension)
796            .ok_or_else(|| {
797                ApiError::internal(format!(
798                    "Extension of type `{}` not found. Did middleware insert it?",
799                    std::any::type_name::<T>()
800                ))
801            })
802    }
803}
804
805impl<T> Deref for Extension<T> {
806    type Target = T;
807
808    fn deref(&self) -> &Self::Target {
809        &self.0
810    }
811}
812
813impl<T> DerefMut for Extension<T> {
814    fn deref_mut(&mut self) -> &mut Self::Target {
815        &mut self.0
816    }
817}
818
/// Client IP address extractor.
///
/// Resolves the caller's IP address; [`ClientIp::extract_with_config`]
/// documents the lookup order. When used as a handler argument it trusts the
/// `X-Forwarded-For` header, which clients can set themselves. Only rely on
/// the result for security decisions when a trusted proxy overwrites that
/// header.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::ClientIp;
///
/// async fn handler(ClientIp(ip): ClientIp) -> impl IntoResponse {
///     format!("Your IP: {}", ip)
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ClientIp(pub std::net::IpAddr);
835
836impl ClientIp {
837    /// Extract client IP, optionally trusting X-Forwarded-For header
838    pub fn extract_with_config(req: &Request, trust_proxy: bool) -> Result<Self> {
839        if trust_proxy {
840            // Try X-Forwarded-For header first
841            if let Some(forwarded) = req.headers().get("x-forwarded-for") {
842                if let Ok(forwarded_str) = forwarded.to_str() {
843                    // X-Forwarded-For can contain multiple IPs, take the first one
844                    if let Some(first_ip) = forwarded_str.split(',').next() {
845                        if let Ok(ip) = first_ip.trim().parse() {
846                            return Ok(ClientIp(ip));
847                        }
848                    }
849                }
850            }
851        }
852
853        // Fall back to socket address from extensions (if set by server)
854        if let Some(addr) = req.extensions().get::<std::net::SocketAddr>() {
855            return Ok(ClientIp(addr.ip()));
856        }
857
858        // Default to localhost if no IP information available
859        Ok(ClientIp(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
860            127, 0, 0, 1,
861        ))))
862    }
863}
864
865impl FromRequestParts for ClientIp {
866    fn from_request_parts(req: &Request) -> Result<Self> {
867        // By default, trust proxy headers
868        Self::extract_with_config(req, true)
869    }
870}
871
/// Request cookies extractor.
///
/// Parses the `Cookie` request header into a [`cookie::CookieJar`]. Requires
/// the `cookies` feature.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_core::extract::Cookies;
///
/// async fn handler(cookies: Cookies) -> impl IntoResponse {
///     match cookies.get("session_id") {
///         Some(session) => format!("Session: {}", session.value()),
///         None => "No session cookie".to_string(),
///     }
/// }
/// ```
#[cfg(feature = "cookies")]
#[derive(Debug, Clone)]
pub struct Cookies(pub cookie::CookieJar);
892
#[cfg(feature = "cookies")]
impl Cookies {
    /// Look up a cookie by name.
    pub fn get(&self, name: &str) -> Option<&cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.get(name)
    }

    /// Iterate over every cookie in the jar.
    pub fn iter(&self) -> impl Iterator<Item = &cookie::Cookie<'static>> {
        let Cookies(jar) = self;
        jar.iter()
    }

    /// Whether a cookie with this name is present.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }
}
910
#[cfg(feature = "cookies")]
impl FromRequestParts for Cookies {
    fn from_request_parts(req: &Request) -> Result<Self> {
        let mut jar = cookie::CookieJar::new();

        // A missing `Cookie` header, or one that `to_str` rejects, yields an
        // empty jar. Individual `name=value` pairs that fail to parse are
        // skipped silently.
        if let Some(text) = req
            .headers()
            .get(header::COOKIE)
            .and_then(|value| value.to_str().ok())
        {
            text.split(';')
                .map(str::trim)
                .filter(|pair| !pair.is_empty())
                .filter_map(|pair| cookie::Cookie::parse(pair.to_string()).ok())
                .for_each(|parsed| jar.add_original(parsed.into_owned()));
        }

        Ok(Cookies(jar))
    }
}
933
#[cfg(feature = "cookies")]
impl Deref for Cookies {
    type Target = cookie::CookieJar;

    /// Borrow the parsed cookie jar.
    fn deref(&self) -> &cookie::CookieJar {
        let Cookies(jar) = self;
        jar
    }
}
942
943// Implement FromRequestParts for common primitive types (path params)
944macro_rules! impl_from_request_parts_for_primitives {
945    ($($ty:ty),*) => {
946        $(
947            impl FromRequestParts for $ty {
948                fn from_request_parts(req: &Request) -> Result<Self> {
949                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
950                    Ok(value)
951                }
952            }
953        )*
954    };
955}
956
957impl_from_request_parts_for_primitives!(
958    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
959);
960
961// OperationModifier implementations for extractors
962
963use rustapi_openapi::utoipa_types::openapi;
964use rustapi_openapi::{
965    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
966    ResponseSpec, Schema, SchemaRef,
967};
968use std::collections::HashMap;
969
970// ValidatedJson - Adds request body
971impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
972    fn update_operation(op: &mut Operation) {
973        let (name, _) = T::schema();
974
975        let schema_ref = SchemaRef::Ref {
976            reference: format!("#/components/schemas/{}", name),
977        };
978
979        let mut content = HashMap::new();
980        content.insert(
981            "application/json".to_string(),
982            MediaType { schema: schema_ref },
983        );
984
985        op.request_body = Some(RequestBody {
986            required: true,
987            content,
988        });
989
990        // Add 422 Validation Error response
991        op.responses.insert(
992            "422".to_string(),
993            ResponseSpec {
994                description: "Validation Error".to_string(),
995                content: {
996                    let mut map = HashMap::new();
997                    map.insert(
998                        "application/json".to_string(),
999                        MediaType {
1000                            schema: SchemaRef::Ref {
1001                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
1002                            },
1003                        },
1004                    );
1005                    Some(map)
1006                },
1007            },
1008        );
1009    }
1010}
1011
1012// Json - Adds request body (Same as ValidatedJson)
1013impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
1014    fn update_operation(op: &mut Operation) {
1015        let (name, _) = T::schema();
1016
1017        let schema_ref = SchemaRef::Ref {
1018            reference: format!("#/components/schemas/{}", name),
1019        };
1020
1021        let mut content = HashMap::new();
1022        content.insert(
1023            "application/json".to_string(),
1024            MediaType { schema: schema_ref },
1025        );
1026
1027        op.request_body = Some(RequestBody {
1028            required: true,
1029            content,
1030        });
1031    }
1032}
1033
1034// Path - Path parameters are automatically extracted from route patterns
1035// The add_path_params_to_operation function in app.rs handles OpenAPI documentation
1036// based on the {param} syntax in route paths (e.g., "/users/{id}")
1037impl<T> OperationModifier for Path<T> {
1038    fn update_operation(_op: &mut Operation) {
1039        // Path parameters are automatically documented by add_path_params_to_operation
1040        // in app.rs based on the route pattern. No additional implementation needed here.
1041        //
1042        // For typed path params, the schema type defaults to "string" but will be
1043        // inferred from the actual type T when more sophisticated type introspection
1044        // is implemented.
1045    }
1046}
1047
1048// Typed - Same as Path, parameters are documented by route pattern
1049impl<T> OperationModifier for Typed<T> {
1050    fn update_operation(_op: &mut Operation) {
1051        // No-op, managed by route registration
1052    }
1053}
1054
1055// Query - Extracts query params using IntoParams
1056impl<T: IntoParams> OperationModifier for Query<T> {
1057    fn update_operation(op: &mut Operation) {
1058        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
1059
1060        let new_params: Vec<Parameter> = params
1061            .into_iter()
1062            .map(|p| {
1063                let schema = match p.schema {
1064                    Some(schema) => match schema {
1065                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
1066                            reference: r.ref_location,
1067                        },
1068                        openapi::RefOr::T(s) => {
1069                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
1070                            SchemaRef::Inline(value)
1071                        }
1072                    },
1073                    None => SchemaRef::Inline(serde_json::Value::Null),
1074                };
1075
1076                let required = match p.required {
1077                    openapi::Required::True => true,
1078                    openapi::Required::False => false,
1079                };
1080
1081                Parameter {
1082                    name: p.name,
1083                    location: "query".to_string(), // explicitly query
1084                    required,
1085                    description: p.description,
1086                    schema,
1087                }
1088            })
1089            .collect();
1090
1091        if let Some(existing) = &mut op.parameters {
1092            existing.extend(new_params);
1093        } else {
1094            op.parameters = Some(new_params);
1095        }
1096    }
1097}
1098
1099// State - No op
1100impl<T> OperationModifier for State<T> {
1101    fn update_operation(_op: &mut Operation) {}
1102}
1103
1104// Body - Generic binary body
1105impl OperationModifier for Body {
1106    fn update_operation(op: &mut Operation) {
1107        let mut content = HashMap::new();
1108        content.insert(
1109            "application/octet-stream".to_string(),
1110            MediaType {
1111                schema: SchemaRef::Inline(
1112                    serde_json::json!({ "type": "string", "format": "binary" }),
1113                ),
1114            },
1115        );
1116
1117        op.request_body = Some(RequestBody {
1118            required: true,
1119            content,
1120        });
1121    }
1122}
1123
1124// BodyStream - Generic binary stream (Same as Body)
1125impl OperationModifier for BodyStream {
1126    fn update_operation(op: &mut Operation) {
1127        let mut content = HashMap::new();
1128        content.insert(
1129            "application/octet-stream".to_string(),
1130            MediaType {
1131                schema: SchemaRef::Inline(
1132                    serde_json::json!({ "type": "string", "format": "binary" }),
1133                ),
1134            },
1135        );
1136
1137        op.request_body = Some(RequestBody {
1138            required: true,
1139            content,
1140        });
1141    }
1142}
1143
1144// ResponseModifier implementations for extractors
1145
1146// Json<T> - 200 OK with schema T
1147impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
1148    fn update_response(op: &mut Operation) {
1149        let (name, _) = T::schema();
1150
1151        let schema_ref = SchemaRef::Ref {
1152            reference: format!("#/components/schemas/{}", name),
1153        };
1154
1155        op.responses.insert(
1156            "200".to_string(),
1157            ResponseSpec {
1158                description: "Successful response".to_string(),
1159                content: {
1160                    let mut map = HashMap::new();
1161                    map.insert(
1162                        "application/json".to_string(),
1163                        MediaType { schema: schema_ref },
1164                    );
1165                    Some(map)
1166                },
1167            },
1168        );
1169    }
1170}
1171
#[cfg(test)]
mod tests {
    //! Unit and property-based (proptest) tests for the request extractors.
    //!
    //! Requests are assembled directly from `http::Request` parts with an empty
    //! buffered body, so extractors are exercised without a running server.

    use super::*;
    use crate::path_params::PathParams;
    use bytes::Bytes;
    use http::{Extensions, Method};
    use proptest::prelude::*;
    use proptest::test_runner::TestCaseError;
    use std::sync::Arc;

    /// Create a test request with the given method, path, and headers.
    ///
    /// The request has an empty buffered body and no path parameters.
    ///
    /// # Panics
    ///
    /// Panics if `path` is not a valid URI or a header name/value is rejected by
    /// the `http` request builder.
    fn create_test_request_with_headers(
        method: Method,
        path: &str,
        headers: Vec<(&str, &str)>,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let mut builder = http::Request::builder().method(method).uri(uri);

        for (name, value) in headers {
            builder = builder.header(name, value);
        }

        let req = builder.body(()).unwrap();
        let (parts, _) = req.into_parts();

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    /// Create a test request whose `http` extensions contain `extension`.
    ///
    /// # Panics
    ///
    /// Panics if `path` is not a valid URI.
    fn create_test_request_with_extensions<T: Clone + Send + Sync + 'static>(
        method: Method,
        path: &str,
        extension: T,
    ) -> Request {
        let uri: http::Uri = path.parse().unwrap();
        let builder = http::Request::builder().method(method).uri(uri);

        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        parts.extensions.insert(extension);

        Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        )
    }

    // **Feature: phase3-batteries-included, Property 14: Headers extractor completeness**
    //
    // For any request with headers H, the `Headers` extractor SHALL return a map
    // containing all key-value pairs in H.
    //
    // **Validates: Requirements 5.1**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_headers_extractor_completeness(
            // Generate random header names and values
            // Using alphanumeric strings to ensure valid header names/values
            headers in prop::collection::vec(
                (
                    "[a-z][a-z0-9-]{0,20}",  // Valid header name pattern
                    "[a-zA-Z0-9 ]{1,50}"     // Valid header value pattern
                ),
                0..10
            )
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Convert to header tuples
                let header_tuples: Vec<(&str, &str)> = headers
                    .iter()
                    .map(|(k, v)| (k.as_str(), v.as_str()))
                    .collect();

                // Create request with headers
                let request = create_test_request_with_headers(
                    Method::GET,
                    "/test",
                    header_tuples,
                );

                // Extract headers
                let extracted = Headers::from_request_parts(&request)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract headers: {}", e)))?;

                // Verify all original headers are present
                // HTTP allows duplicate headers - get_all() returns all values for a header name
                for (name, value) in &headers {
                    // Check that the header name exists
                    let all_values: Vec<_> = extracted.get_all(name.as_str()).iter().collect();
                    prop_assert!(
                        !all_values.is_empty(),
                        "Header '{}' not found",
                        name
                    );

                    // Check that the value is among the extracted values
                    let value_found = all_values.iter().any(|v| {
                        v.to_str().map(|s| s == value.as_str()).unwrap_or(false)
                    });

                    prop_assert!(
                        value_found,
                        "Header '{}' value '{}' not found in extracted values",
                        name,
                        value
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 15: HeaderValue extractor correctness**
    //
    // For any request with header "X" having value V, `HeaderValue::extract(req, "X")` SHALL return V;
    // for requests without header "X", it SHALL return an error.
    //
    // **Validates: Requirements 5.2**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_header_value_extractor_correctness(
            header_name in "[a-z][a-z0-9-]{0,20}",
            header_value in "[a-zA-Z0-9 ]{1,50}",
            has_header in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // The extractor is called with a static string as the header name, so
                // the generated name is leaked to obtain a `&'static str`. Each case
                // leaks at most ~21 bytes, which is negligible in a test.
                let test_header: &'static str = Box::leak(header_name.clone().into_boxed_str());

                let headers = if has_header {
                    vec![(test_header, header_value.as_str())]
                } else {
                    vec![]
                };
                let request = create_test_request_with_headers(Method::GET, "/test", headers);

                let extraction = HeaderValue::extract(&request, test_header);

                if has_header {
                    let extracted = extraction
                        .map_err(|e| TestCaseError::fail(format!("Expected header to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.value(),
                        header_value.as_str(),
                        "Header value mismatch"
                    );
                } else {
                    prop_assert!(
                        extraction.is_err(),
                        "Expected error when header is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 17: ClientIp extractor with forwarding**
    //
    // For any request with socket IP S and X-Forwarded-For header F, when forwarding is enabled,
    // `ClientIp` SHALL return the first IP in F; when disabled, it SHALL return S.
    //
    // **Validates: Requirements 5.4**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_client_ip_extractor_with_forwarding(
            // Generate valid IPv4 addresses
            forwarded_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| format!("{}.{}.{}.{}", a, b, c, d)),
            socket_ip in (0u8..=255, 0u8..=255, 0u8..=255, 0u8..=255)
                .prop_map(|(a, b, c, d)| std::net::IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d))),
            has_forwarded_header in prop::bool::ANY,
            trust_proxy in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                let headers = if has_forwarded_header {
                    vec![("x-forwarded-for", forwarded_ip.as_str())]
                } else {
                    vec![]
                };

                // Create request with headers
                let uri: http::Uri = "/test".parse().unwrap();
                let mut builder = http::Request::builder().method(Method::GET).uri(uri);
                for (name, value) in &headers {
                    builder = builder.header(*name, *value);
                }
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                // Add socket address to extensions
                let socket_addr = std::net::SocketAddr::new(socket_ip, 8080);
                parts.extensions.insert(socket_addr);

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let extracted = ClientIp::extract_with_config(&request, trust_proxy)
                    .map_err(|e| TestCaseError::fail(format!("Failed to extract ClientIp: {}", e)))?;

                if trust_proxy && has_forwarded_header {
                    // Should use X-Forwarded-For
                    let expected_ip: std::net::IpAddr = forwarded_ip.parse()
                        .map_err(|e| TestCaseError::fail(format!("Invalid IP: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        expected_ip,
                        "Should use X-Forwarded-For IP when trust_proxy is enabled"
                    );
                } else {
                    // Should use socket IP
                    prop_assert_eq!(
                        extracted.0,
                        socket_ip,
                        "Should use socket IP when trust_proxy is disabled or no X-Forwarded-For"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // **Feature: phase3-batteries-included, Property 18: Extension extractor retrieval**
    //
    // For any type T and value V inserted into request extensions by middleware,
    // `Extension<T>` SHALL return V.
    //
    // **Validates: Requirements 5.5**
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_extension_extractor_retrieval(
            value in any::<i64>(),
            has_extension in prop::bool::ANY,
        ) {
            let result: Result<(), TestCaseError> = (|| {
                // Create a simple wrapper type for testing
                #[derive(Clone, Debug, PartialEq)]
                struct TestExtension(i64);

                let uri: http::Uri = "/test".parse().unwrap();
                let builder = http::Request::builder().method(Method::GET).uri(uri);
                let req = builder.body(()).unwrap();
                let (mut parts, _) = req.into_parts();

                if has_extension {
                    parts.extensions.insert(TestExtension(value));
                }

                let request = Request::new(
                    parts,
                    crate::request::BodyVariant::Buffered(Bytes::new()),
                    Arc::new(Extensions::new()),
                    PathParams::new(),
                );

                let result = Extension::<TestExtension>::from_request_parts(&request);

                if has_extension {
                    let extracted = result
                        .map_err(|e| TestCaseError::fail(format!("Expected extension to be found: {}", e)))?;
                    prop_assert_eq!(
                        extracted.0,
                        TestExtension(value),
                        "Extension value mismatch"
                    );
                } else {
                    prop_assert!(
                        result.is_err(),
                        "Expected error when extension is missing"
                    );
                }

                Ok(())
            })();
            result?;
        }
    }

    // Unit tests for basic functionality

    #[test]
    fn test_headers_extractor_basic() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![
                ("content-type", "application/json"),
                ("accept", "text/html"),
            ],
        );

        let headers = Headers::from_request_parts(&request).unwrap();

        assert!(headers.contains("content-type"));
        assert!(headers.contains("accept"));
        assert!(!headers.contains("x-custom"));
        assert_eq!(headers.len(), 2);
    }

    #[test]
    fn test_header_value_extractor_present() {
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("authorization", "Bearer token123")],
        );

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value(), "Bearer token123");
    }

    #[test]
    fn test_header_value_extractor_missing() {
        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = HeaderValue::extract(&request, "authorization");
        assert!(result.is_err());
    }

    #[test]
    fn test_client_ip_from_forwarded_header() {
        // No socket address is attached here; with trust_proxy enabled the first
        // entry of the comma-separated X-Forwarded-For list is expected.
        let request = create_test_request_with_headers(
            Method::GET,
            "/test",
            vec![("x-forwarded-for", "192.168.1.100, 10.0.0.1")],
        );

        let ip = ClientIp::extract_with_config(&request, true).unwrap();
        assert_eq!(ip.0, "192.168.1.100".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_client_ip_ignores_forwarded_when_not_trusted() {
        let uri: http::Uri = "/test".parse().unwrap();
        let builder = http::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .header("x-forwarded-for", "192.168.1.100");
        let req = builder.body(()).unwrap();
        let (mut parts, _) = req.into_parts();

        let socket_addr = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(10, 0, 0, 1)),
            8080,
        );
        parts.extensions.insert(socket_addr);

        let request = Request::new(
            parts,
            crate::request::BodyVariant::Buffered(Bytes::new()),
            Arc::new(Extensions::new()),
            PathParams::new(),
        );

        let ip = ClientIp::extract_with_config(&request, false).unwrap();
        assert_eq!(ip.0, "10.0.0.1".parse::<std::net::IpAddr>().unwrap());
    }

    #[test]
    fn test_extension_extractor_present() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyData(String);

        let request =
            create_test_request_with_extensions(Method::GET, "/test", MyData("hello".to_string()));

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().0, MyData("hello".to_string()));
    }

    #[test]
    fn test_extension_extractor_missing() {
        // The field is never read; the type only serves as the extension lookup key.
        #[derive(Clone, Debug)]
        #[allow(dead_code)]
        struct MyData(String);

        let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

        let result = Extension::<MyData>::from_request_parts(&request);
        assert!(result.is_err());
    }

    // Cookies tests (feature-gated)
    #[cfg(feature = "cookies")]
    mod cookies_tests {
        use super::*;

        // **Feature: phase3-batteries-included, Property 16: Cookies extractor parsing**
        //
        // For any request with Cookie header containing cookies C, the `Cookies` extractor
        // SHALL return a CookieJar containing exactly the cookies in C.
        // Note: Duplicate cookie names result in only the last value being kept.
        //
        // **Validates: Requirements 5.3**
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_cookies_extractor_parsing(
                // Generate random cookie names and values
                // Using alphanumeric strings to ensure valid cookie names/values
                cookies in prop::collection::vec(
                    (
                        "[a-zA-Z][a-zA-Z0-9_]{0,15}",  // Valid cookie name pattern
                        "[a-zA-Z0-9]{1,30}"            // Valid cookie value pattern (no special chars)
                    ),
                    0..5
                )
            ) {
                let result: Result<(), TestCaseError> = (|| {
                    // Build cookie header string
                    let cookie_header = cookies
                        .iter()
                        .map(|(name, value)| format!("{}={}", name, value))
                        .collect::<Vec<_>>()
                        .join("; ");

                    let headers = if !cookies.is_empty() {
                        vec![("cookie", cookie_header.as_str())]
                    } else {
                        vec![]
                    };

                    let request = create_test_request_with_headers(Method::GET, "/test", headers);

                    // Extract cookies
                    let extracted = Cookies::from_request_parts(&request)
                        .map_err(|e| TestCaseError::fail(format!("Failed to extract cookies: {}", e)))?;

                    // Build expected cookies map - last value wins for duplicate names
                    let mut expected_cookies: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
                    for (name, value) in &cookies {
                        expected_cookies.insert(name.as_str(), value.as_str());
                    }

                    // Verify all expected cookies are present with correct values
                    for (name, expected_value) in &expected_cookies {
                        let cookie = extracted.get(name)
                            .ok_or_else(|| TestCaseError::fail(format!("Cookie '{}' not found", name)))?;

                        prop_assert_eq!(
                            cookie.value(),
                            *expected_value,
                            "Cookie '{}' value mismatch",
                            name
                        );
                    }

                    // Count cookies in jar should match unique cookie names
                    let extracted_count = extracted.iter().count();
                    prop_assert_eq!(
                        extracted_count,
                        expected_cookies.len(),
                        "Expected {} unique cookies, got {}",
                        expected_cookies.len(),
                        extracted_count
                    );

                    Ok(())
                })();
                result?;
            }
        }

        #[test]
        fn test_cookies_extractor_basic() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "session=abc123; user=john")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();

            assert!(cookies.contains("session"));
            assert!(cookies.contains("user"));
            assert!(!cookies.contains("other"));

            assert_eq!(cookies.get("session").unwrap().value(), "abc123");
            assert_eq!(cookies.get("user").unwrap().value(), "john");
        }

        #[test]
        fn test_cookies_extractor_empty() {
            let request = create_test_request_with_headers(Method::GET, "/test", vec![]);

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 0);
        }

        #[test]
        fn test_cookies_extractor_single() {
            let request = create_test_request_with_headers(
                Method::GET,
                "/test",
                vec![("cookie", "token=xyz789")],
            );

            let cookies = Cookies::from_request_parts(&request).unwrap();
            assert_eq!(cookies.iter().count(), 1);
            assert_eq!(cookies.get("token").unwrap().value(), "xyz789");
        }
    }
}