rustapi_core/
extract.rs

1//! Extractors for RustAPI
2//!
3//! Extractors automatically parse and validate data from incoming requests.
4
5use crate::error::{ApiError, Result};
6use crate::request::Request;
7use crate::response::IntoResponse;
8use bytes::Bytes;
9use http::{header, StatusCode};
10use http_body_util::Full;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use std::future::Future;
14use std::ops::{Deref, DerefMut};
15use std::str::FromStr;
16
17/// Trait for extracting data from request parts (headers, path, query)
18///
19/// This is used for extractors that don't need the request body.
20pub trait FromRequestParts: Sized {
21    /// Extract from request parts
22    fn from_request_parts(req: &Request) -> Result<Self>;
23}
24
25/// Trait for extracting data from the full request (including body)
26///
27/// This is used for extractors that consume the request body.
28pub trait FromRequest: Sized {
29    /// Extract from the full request
30    fn from_request(req: &mut Request) -> impl Future<Output = Result<Self>> + Send;
31}
32
33// Blanket impl: FromRequestParts -> FromRequest
34impl<T: FromRequestParts> FromRequest for T {
35    async fn from_request(req: &mut Request) -> Result<Self> {
36        T::from_request_parts(req)
37    }
38}
39
/// Wrapper marking a value as JSON.
///
/// As an extractor, the request body is deserialized from JSON into `T`.
/// When `T: Serialize`, the same wrapper can be returned from a handler to
/// produce a JSON response.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// async fn create_user(Json(body): Json<CreateUser>) -> impl IntoResponse {
///     // `body` has already been deserialized from the request
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);
60
61impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
62    async fn from_request(req: &mut Request) -> Result<Self> {
63        let body = req
64            .take_body()
65            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
66
67        let value: T = serde_json::from_slice(&body)?;
68        Ok(Json(value))
69    }
70}
71
72impl<T> Deref for Json<T> {
73    type Target = T;
74
75    fn deref(&self) -> &Self::Target {
76        &self.0
77    }
78}
79
80impl<T> DerefMut for Json<T> {
81    fn deref_mut(&mut self) -> &mut Self::Target {
82        &mut self.0
83    }
84}
85
86impl<T> From<T> for Json<T> {
87    fn from(value: T) -> Self {
88        Json(value)
89    }
90}
91
92// IntoResponse for Json - allows using Json<T> as a return type
93impl<T: Serialize> IntoResponse for Json<T> {
94    fn into_response(self) -> crate::response::Response {
95        match serde_json::to_vec(&self.0) {
96            Ok(body) => http::Response::builder()
97                .status(StatusCode::OK)
98                .header(header::CONTENT_TYPE, "application/json")
99                .body(Full::new(Bytes::from(body)))
100                .unwrap(),
101            Err(err) => {
102                ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
103            }
104        }
105    }
106}
107
/// JSON body extractor that also validates the deserialized value.
///
/// The request body is deserialized from JSON into `T` and then checked via
/// the `Validate` trait. A validation failure is converted into the API
/// error type and returned to the client; it is documented as a
/// 422 Unprocessable Entity response with field-level details.
///
/// # Example
///
/// ```rust,ignore
/// use rustapi_rs::prelude::*;
/// use validator::Validate;
///
/// #[derive(Deserialize, Validate)]
/// struct CreateUser {
///     #[validate(email)]
///     email: String,
///     #[validate(length(min = 8))]
///     password: String,
/// }
///
/// async fn register(ValidatedJson(body): ValidatedJson<CreateUser>) -> impl IntoResponse {
///     // `body` has passed validation here.
///     // An invalid email or a too-short password is rejected before the handler runs.
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
135
136impl<T> ValidatedJson<T> {
137    /// Create a new ValidatedJson wrapper
138    pub fn new(value: T) -> Self {
139        Self(value)
140    }
141
142    /// Get the inner value
143    pub fn into_inner(self) -> T {
144        self.0
145    }
146}
147
148impl<T: DeserializeOwned + rustapi_validate::Validate + Send> FromRequest for ValidatedJson<T> {
149    async fn from_request(req: &mut Request) -> Result<Self> {
150        // First, deserialize the JSON body
151        let body = req
152            .take_body()
153            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
154
155        let value: T = serde_json::from_slice(&body)?;
156
157        // Then, validate it
158        if let Err(validation_error) = rustapi_validate::Validate::validate(&value) {
159            // Convert validation error to API error with 422 status
160            return Err(validation_error.into());
161        }
162
163        Ok(ValidatedJson(value))
164    }
165}
166
167impl<T> Deref for ValidatedJson<T> {
168    type Target = T;
169
170    fn deref(&self) -> &Self::Target {
171        &self.0
172    }
173}
174
175impl<T> DerefMut for ValidatedJson<T> {
176    fn deref_mut(&mut self) -> &mut Self::Target {
177        &mut self.0
178    }
179}
180
181impl<T> From<T> for ValidatedJson<T> {
182    fn from(value: T) -> Self {
183        ValidatedJson(value)
184    }
185}
186
187impl<T: Serialize> IntoResponse for ValidatedJson<T> {
188    fn into_response(self) -> crate::response::Response {
189        Json(self.0).into_response()
190    }
191}
192
/// Extractor for the URL query string.
///
/// The query string is deserialized into `T`; a request without a query
/// string is treated as an empty one.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: Option<u32>,
///     limit: Option<u32>,
/// }
///
/// async fn list_users(Query(params): Query<Pagination>) -> impl IntoResponse {
///     // use params.page and params.limit
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);
212
213impl<T: DeserializeOwned> FromRequestParts for Query<T> {
214    fn from_request_parts(req: &Request) -> Result<Self> {
215        let query = req.query_string().unwrap_or("");
216        let value: T = serde_urlencoded::from_str(query)
217            .map_err(|e| ApiError::bad_request(format!("Invalid query string: {}", e)))?;
218        Ok(Query(value))
219    }
220}
221
222impl<T> Deref for Query<T> {
223    type Target = T;
224
225    fn deref(&self) -> &Self::Target {
226        &self.0
227    }
228}
229
/// Extractor for parameters captured by the route pattern.
///
/// # Example
///
/// With the route `/users/{id}`:
///
/// ```rust,ignore
/// async fn get_user(Path(id): Path<i64>) -> impl IntoResponse {
///     // `id` was parsed from the path
/// }
/// ```
///
/// With several parameters, `/users/{user_id}/posts/{post_id}`:
///
/// ```rust,ignore
/// async fn get_post(Path((user_id, post_id)): Path<(i64, i64)>) -> impl IntoResponse {
///     // both parameters extracted
/// }
/// ```
///
/// NOTE(review): the `FromRequestParts` impl for `Path<T>` in this file is
/// bounded by `T: FromStr` and parses only the first entry yielded by
/// `req.path_params()`. Standard tuples do not implement `FromStr`, so
/// confirm that the tuple form above is actually supported before relying
/// on it.
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);
253
254impl<T: FromStr> FromRequestParts for Path<T>
255where
256    T::Err: std::fmt::Display,
257{
258    fn from_request_parts(req: &Request) -> Result<Self> {
259        let params = req.path_params();
260
261        // For single param, get the first one
262        if let Some((_, value)) = params.iter().next() {
263            let parsed = value
264                .parse::<T>()
265                .map_err(|e| ApiError::bad_request(format!("Invalid path parameter: {}", e)))?;
266            return Ok(Path(parsed));
267        }
268
269        Err(ApiError::internal("Missing path parameter"))
270    }
271}
272
273impl<T> Deref for Path<T> {
274    type Target = T;
275
276    fn deref(&self) -> &Self::Target {
277        &self.0
278    }
279}
280
/// Extractor for shared application state.
///
/// Hands the handler a clone of the state value of type `T` registered with
/// the application.
///
/// # Example
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct AppState {
///     db: DbPool,
/// }
///
/// async fn handler(State(state): State<AppState>) -> impl IntoResponse {
///     // work with state.db
/// }
/// ```
#[derive(Debug, Clone)]
pub struct State<T>(pub T);
299
300impl<T: Clone + Send + Sync + 'static> FromRequestParts for State<T> {
301    fn from_request_parts(req: &Request) -> Result<Self> {
302        req.state().get::<T>().cloned().map(State).ok_or_else(|| {
303            ApiError::internal(format!(
304                "State of type `{}` not found. Did you forget to call .state()?",
305                std::any::type_name::<T>()
306            ))
307        })
308    }
309}
310
311impl<T> Deref for State<T> {
312    type Target = T;
313
314    fn deref(&self) -> &Self::Target {
315        &self.0
316    }
317}
318
319/// Raw body bytes extractor
320#[derive(Debug, Clone)]
321pub struct Body(pub Bytes);
322
323impl FromRequest for Body {
324    async fn from_request(req: &mut Request) -> Result<Self> {
325        let body = req
326            .take_body()
327            .ok_or_else(|| ApiError::internal("Body already consumed"))?;
328        Ok(Body(body))
329    }
330}
331
332impl Deref for Body {
333    type Target = Bytes;
334
335    fn deref(&self) -> &Self::Target {
336        &self.0
337    }
338}
339
340/// Optional extractor wrapper
341///
342/// Makes any extractor optional - returns None instead of error on failure.
343impl<T: FromRequestParts> FromRequestParts for Option<T> {
344    fn from_request_parts(req: &Request) -> Result<Self> {
345        Ok(T::from_request_parts(req).ok())
346    }
347}
348
349// Implement FromRequestParts for common primitive types (path params)
350macro_rules! impl_from_request_parts_for_primitives {
351    ($($ty:ty),*) => {
352        $(
353            impl FromRequestParts for $ty {
354                fn from_request_parts(req: &Request) -> Result<Self> {
355                    let Path(value) = Path::<$ty>::from_request_parts(req)?;
356                    Ok(value)
357                }
358            }
359        )*
360    };
361}
362
363impl_from_request_parts_for_primitives!(
364    i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, String
365);
366
367// OperationModifier implementations for extractors
368
369use rustapi_openapi::utoipa_types::openapi;
370use rustapi_openapi::{
371    IntoParams, MediaType, Operation, OperationModifier, Parameter, RequestBody, ResponseModifier,
372    ResponseSpec, Schema, SchemaRef,
373};
374use std::collections::HashMap;
375
376// ValidatedJson - Adds request body
377impl<T: for<'a> Schema<'a>> OperationModifier for ValidatedJson<T> {
378    fn update_operation(op: &mut Operation) {
379        let (name, _) = T::schema();
380
381        let schema_ref = SchemaRef::Ref {
382            reference: format!("#/components/schemas/{}", name),
383        };
384
385        let mut content = HashMap::new();
386        content.insert(
387            "application/json".to_string(),
388            MediaType { schema: schema_ref },
389        );
390
391        op.request_body = Some(RequestBody {
392            required: true,
393            content,
394        });
395
396        // Add 422 Validation Error response
397        op.responses.insert(
398            "422".to_string(),
399            ResponseSpec {
400                description: "Validation Error".to_string(),
401                content: {
402                    let mut map = HashMap::new();
403                    map.insert(
404                        "application/json".to_string(),
405                        MediaType {
406                            schema: SchemaRef::Ref {
407                                reference: "#/components/schemas/ValidationErrorSchema".to_string(),
408                            },
409                        },
410                    );
411                    Some(map)
412                },
413                ..Default::default()
414            },
415        );
416    }
417}
418
419// Json - Adds request body (Same as ValidatedJson)
420impl<T: for<'a> Schema<'a>> OperationModifier for Json<T> {
421    fn update_operation(op: &mut Operation) {
422        let (name, _) = T::schema();
423
424        let schema_ref = SchemaRef::Ref {
425            reference: format!("#/components/schemas/{}", name),
426        };
427
428        let mut content = HashMap::new();
429        content.insert(
430            "application/json".to_string(),
431            MediaType { schema: schema_ref },
432        );
433
434        op.request_body = Some(RequestBody {
435            required: true,
436            content,
437        });
438    }
439}
440
441// Path - Placeholder for path params
442impl<T> OperationModifier for Path<T> {
443    fn update_operation(_op: &mut Operation) {
444        // TODO: Implement path param extraction
445    }
446}
447
448// Query - Extracts query params using IntoParams
449impl<T: IntoParams> OperationModifier for Query<T> {
450    fn update_operation(op: &mut Operation) {
451        let params = T::into_params(|| Some(openapi::path::ParameterIn::Query));
452
453        let new_params: Vec<Parameter> = params
454            .into_iter()
455            .map(|p| {
456                let schema = match p.schema {
457                    Some(schema) => match schema {
458                        openapi::RefOr::Ref(r) => SchemaRef::Ref {
459                            reference: r.ref_location,
460                        },
461                        openapi::RefOr::T(s) => {
462                            let value = serde_json::to_value(s).unwrap_or(serde_json::Value::Null);
463                            SchemaRef::Inline(value)
464                        }
465                    },
466                    None => SchemaRef::Inline(serde_json::Value::Null),
467                };
468
469                let required = match p.required {
470                    openapi::Required::True => true,
471                    openapi::Required::False => false,
472                };
473
474                Parameter {
475                    name: p.name,
476                    location: "query".to_string(), // explicitly query
477                    required,
478                    description: p.description,
479                    schema,
480                }
481            })
482            .collect();
483
484        if let Some(existing) = &mut op.parameters {
485            existing.extend(new_params);
486        } else {
487            op.parameters = Some(new_params);
488        }
489    }
490}
491
492// State - No op
493impl<T> OperationModifier for State<T> {
494    fn update_operation(_op: &mut Operation) {}
495}
496
497// Body - Generic binary body
498impl OperationModifier for Body {
499    fn update_operation(op: &mut Operation) {
500        let mut content = HashMap::new();
501        content.insert(
502            "application/octet-stream".to_string(),
503            MediaType {
504                schema: SchemaRef::Inline(
505                    serde_json::json!({ "type": "string", "format": "binary" }),
506                ),
507            },
508        );
509
510        op.request_body = Some(RequestBody {
511            required: true,
512            content,
513        });
514    }
515}
516
517// ResponseModifier implementations for extractors
518
519// Json<T> - 200 OK with schema T
520impl<T: for<'a> Schema<'a>> ResponseModifier for Json<T> {
521    fn update_response(op: &mut Operation) {
522        let (name, _) = T::schema();
523
524        let schema_ref = SchemaRef::Ref {
525            reference: format!("#/components/schemas/{}", name),
526        };
527
528        op.responses.insert(
529            "200".to_string(),
530            ResponseSpec {
531                description: "Successful response".to_string(),
532                content: {
533                    let mut map = HashMap::new();
534                    map.insert(
535                        "application/json".to_string(),
536                        MediaType { schema: schema_ref },
537                    );
538                    Some(map)
539                },
540                ..Default::default()
541            },
542        );
543    }
544}