Skip to main content

vld_rocket/
lib.rs

1//! # vld-rocket — Rocket integration for `vld`
2//!
3//! Validation extractors for [Rocket](https://rocket.rs/). Validates request
4//! data against `vld` schemas and returns `422 Unprocessable Entity` with
5//! structured JSON errors on failure.
6//!
7//! # Extractors
8//!
9//! | Extractor | Source | Rocket equivalent |
10//! |-----------|--------|-------------------|
11//! | `VldJson<T>` | JSON body | `rocket::serde::json::Json<T>` |
12//! | `VldQuery<T>` | Query string | query params |
13//! | `VldPath<T>` | Path segments | `<param>` segments |
14//! | `VldForm<T>` | Form body | `rocket::form::Form<T>` |
15//! | `VldHeaders<T>` | HTTP headers | manual extraction |
16//! | `VldCookie<T>` | Cookie values | `CookieJar` |
17//!
18//! # Error catcher
19//!
//! Register [`vld_422_catcher()`] (and, optionally, [`vld_400_catcher()`]) to
//! get JSON error responses instead of Rocket's default HTML:
//!
//! ```rust,ignore
//! rocket::build()
//!     .mount("/", routes![...])
//!     .register("/", catchers![vld_rocket::vld_422_catcher, vld_rocket::vld_400_catcher])
27//! ```
28
29use rocket::data::{Data, FromData, Outcome as DataOutcome};
30use rocket::http::Status;
31use rocket::request::{FromRequest, Outcome, Request};
32use rocket::serde::json::Json;
33use std::ops::{Deref, DerefMut};
34use vld::schema::VldParse;
35use vld_http_common::{
36    coerce_value, cookies_to_json, format_vld_error, parse_query_string as parse_query_to_json,
37};
38
39// ---------------------------------------------------------------------------
40// Request-local error storage
41// ---------------------------------------------------------------------------
42
43/// Stored in request local cache so the catcher can read it.
44#[derive(Debug, Clone, Default)]
45pub struct VldErrorCache(pub Option<serde_json::Value>);
46
47fn store_error(req: &Request<'_>, err: serde_json::Value) {
48    // Rocket's local_cache returns &T, setting a value requires the closure pattern
49    let _ = req.local_cache(|| VldErrorCache(Some(err.clone())));
50}
51
52// ---------------------------------------------------------------------------
53// VldJson<T> — validated JSON body
54// ---------------------------------------------------------------------------
55
/// Extractor that decodes a JSON request body and validates it with `vld`.
///
/// The body is decoded into a `serde_json::Value` and handed to
/// `T::vld_parse_value()`. A validation failure yields `422` together with a
/// JSON description of the problems.
#[derive(Debug, Clone)]
pub struct VldJson<T>(pub T);

impl<T> Deref for VldJson<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldJson(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldJson<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldJson(inner) = self;
        inner
    }
}
75
76#[rocket::async_trait]
77impl<'r, T: VldParse + Send + 'static> FromData<'r> for VldJson<T> {
78    type Error = serde_json::Value;
79
80    async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
81        let json_outcome = <Json<serde_json::Value> as FromData<'r>>::from_data(req, data).await;
82        let value = match json_outcome {
83            DataOutcome::Success(json) => json.into_inner(),
84            DataOutcome::Error((status, e)) => {
85                let body = vld_http_common::format_json_parse_error(&format!("{e}"));
86                store_error(req, body.clone());
87                return DataOutcome::Error((status, body));
88            }
89            DataOutcome::Forward(f) => return DataOutcome::Forward(f),
90        };
91
92        match T::vld_parse_value(&value) {
93            Ok(parsed) => DataOutcome::Success(VldJson(parsed)),
94            Err(vld_err) => {
95                let body = format_vld_error(&vld_err);
96                store_error(req, body.clone());
97                DataOutcome::Error((Status::UnprocessableEntity, body))
98            }
99        }
100    }
101}
102
103// ---------------------------------------------------------------------------
104// VldQuery<T> — validated query string
105// ---------------------------------------------------------------------------
106
/// Extractor that validates the request's query string with `vld`.
///
/// Query parameters are turned into a JSON object (string values are coerced)
/// and checked via `T::vld_parse_value()`.
#[derive(Debug, Clone)]
pub struct VldQuery<T>(pub T);

impl<T> Deref for VldQuery<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldQuery(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldQuery<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldQuery(inner) = self;
        inner
    }
}
126
127#[rocket::async_trait]
128impl<'r, T: VldParse + Send + Sync + 'static> FromRequest<'r> for VldQuery<T> {
129    type Error = serde_json::Value;
130
131    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
132        let qs = req.uri().query().map(|q| q.as_str()).unwrap_or("");
133        let map = parse_query_to_json(qs);
134        let value = serde_json::Value::Object(map);
135
136        match T::vld_parse_value(&value) {
137            Ok(parsed) => Outcome::Success(VldQuery(parsed)),
138            Err(vld_err) => {
139                let body = format_vld_error(&vld_err);
140                store_error(req, body.clone());
141                Outcome::Error((Status::UnprocessableEntity, body))
142            }
143        }
144    }
145}
146
147// ---------------------------------------------------------------------------
148// VldForm<T> — validated form body
149// ---------------------------------------------------------------------------
150
/// Extractor that validates an `application/x-www-form-urlencoded` body.
///
/// The body (at most 1 MiB) is parsed into a JSON object with coerced values
/// and checked via `T::vld_parse_value()`.
#[derive(Debug, Clone)]
pub struct VldForm<T>(pub T);

impl<T> Deref for VldForm<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldForm(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldForm<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldForm(inner) = self;
        inner
    }
}
170
171#[rocket::async_trait]
172impl<'r, T: VldParse + Send + 'static> FromData<'r> for VldForm<T> {
173    type Error = serde_json::Value;
174
175    async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
176        use rocket::data::ToByteUnit;
177        let bytes = match data.open(1.mebibytes()).into_bytes().await {
178            Ok(b) if b.is_complete() => b.into_inner(),
179            _ => {
180                let body = vld_http_common::format_payload_too_large();
181                store_error(req, body.clone());
182                return DataOutcome::Error((Status::PayloadTooLarge, body));
183            }
184        };
185
186        let body_str = match String::from_utf8(bytes) {
187            Ok(s) => s,
188            Err(_) => {
189                let body = vld_http_common::format_utf8_error();
190                store_error(req, body.clone());
191                return DataOutcome::Error((Status::BadRequest, body));
192            }
193        };
194
195        let map = parse_query_to_json(&body_str);
196        let value = serde_json::Value::Object(map);
197
198        match T::vld_parse_value(&value) {
199            Ok(parsed) => DataOutcome::Success(VldForm(parsed)),
200            Err(vld_err) => {
201                let body = format_vld_error(&vld_err);
202                store_error(req, body.clone());
203                DataOutcome::Error((Status::UnprocessableEntity, body))
204            }
205        }
206    }
207}
208
209// ---------------------------------------------------------------------------
210// Error catcher
211// ---------------------------------------------------------------------------
212
213/// Catcher for `422 Unprocessable Entity` that returns JSON from the
214/// validation error stored by vld extractors.
215///
216/// Register in your Rocket application:
217///
218/// ```rust,ignore
219/// rocket::build()
220///     .register("/", catchers![vld_rocket::vld_422_catcher])
221/// ```
222#[rocket::catch(422)]
223pub fn vld_422_catcher(req: &Request<'_>) -> (Status, Json<serde_json::Value>) {
224    let cached = req.local_cache(|| VldErrorCache(None));
225    let body = cached
226        .0
227        .clone()
228        .unwrap_or_else(|| vld_http_common::format_generic_error("Unprocessable Entity"));
229    (Status::UnprocessableEntity, Json(body))
230}
231
232/// Catcher for `400 Bad Request` that returns JSON.
233#[rocket::catch(400)]
234pub fn vld_400_catcher(req: &Request<'_>) -> (Status, Json<serde_json::Value>) {
235    let cached = req.local_cache(|| VldErrorCache(None));
236    let body = cached
237        .0
238        .clone()
239        .unwrap_or_else(|| vld_http_common::format_generic_error("Bad Request"));
240    (Status::BadRequest, Json(body))
241}
242
243// ---------------------------------------------------------------------------
244// VldPath<T> — validated path parameters
245// ---------------------------------------------------------------------------
246
/// Extractor that validates a route's named path parameters with `vld`.
///
/// Dynamic segments declared with Rocket's `<param>` syntax are collected
/// under their parameter names and checked via `T::vld_parse_value()`. Values
/// are coerced, e.g. `"42"` → number, `"true"` → bool.
///
/// Struct field names must match the route's parameter names.
#[derive(Debug, Clone)]
pub struct VldPath<T>(pub T);

impl<T> Deref for VldPath<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldPath(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldPath<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldPath(inner) = self;
        inner
    }
}
269
270#[rocket::async_trait]
271impl<'r, T: VldParse + Send + Sync + 'static> FromRequest<'r> for VldPath<T> {
272    type Error = serde_json::Value;
273
274    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
275        let mut map = serde_json::Map::new();
276
277        // Rocket exposes route segments in match_info via routed_segments
278        // We iterate over the raw segments from the uri
279        for (i, seg) in req.routed_segments(0..).enumerate() {
280            // Try to get param name from route if available
281            let key = format!("{}", i);
282            let _ = key; // fallback
283            map.insert(seg.to_string(), coerce_value(seg));
284        }
285
286        // Better approach: use named query params from Rocket's param API
287        // Since Rocket doesn't expose route pattern names easily,
288        // we'll read all segments as positional values and also try
289        // to extract by common names from the route's dynamic segments
290        let mut named_map = serde_json::Map::new();
291
292        // Extract each dynamic param by trying common names
293        // Rocket stores route segments; we can access them by index
294        let segments: Vec<&str> = req.routed_segments(0..).collect();
295        if let Some(route) = req.route() {
296            let uri_str = route.uri.origin.path().as_str();
297            let mut param_idx = 0;
298            for part in uri_str.split('/') {
299                if part.starts_with('<') && part.ends_with('>') {
300                    let name = part
301                        .trim_start_matches('<')
302                        .trim_end_matches('>')
303                        .trim_end_matches("..");
304                    if let Some(&seg_value) = segments.get(param_idx) {
305                        named_map.insert(name.to_string(), coerce_value(seg_value));
306                    }
307                    param_idx += 1;
308                } else if !part.is_empty() {
309                    param_idx += 1;
310                }
311            }
312        }
313
314        let value = serde_json::Value::Object(named_map);
315
316        match T::vld_parse_value(&value) {
317            Ok(parsed) => Outcome::Success(VldPath(parsed)),
318            Err(vld_err) => {
319                let body = format_vld_error(&vld_err);
320                store_error(req, body.clone());
321                Outcome::Error((Status::UnprocessableEntity, body))
322            }
323        }
324    }
325}
326
327// ---------------------------------------------------------------------------
328// VldHeaders<T> — validated HTTP headers
329// ---------------------------------------------------------------------------
330
/// Extractor that validates request headers with `vld`.
///
/// Header names are lower-cased with `-` replaced by `_`, e.g.
/// `Content-Type` → `content_type`. Values are coerced, e.g. `"42"` →
/// number, `"true"` → bool.
#[derive(Debug, Clone)]
pub struct VldHeaders<T>(pub T);

impl<T> Deref for VldHeaders<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldHeaders(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldHeaders<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldHeaders(inner) = self;
        inner
    }
}
350
351#[rocket::async_trait]
352impl<'r, T: VldParse + Send + Sync + 'static> FromRequest<'r> for VldHeaders<T> {
353    type Error = serde_json::Value;
354
355    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
356        let mut map = serde_json::Map::new();
357
358        for header in req.headers().iter() {
359            let key = header.name().as_str().to_lowercase().replace('-', "_");
360            map.insert(key, coerce_value(header.value()));
361        }
362
363        let value = serde_json::Value::Object(map);
364
365        match T::vld_parse_value(&value) {
366            Ok(parsed) => Outcome::Success(VldHeaders(parsed)),
367            Err(vld_err) => {
368                let body = format_vld_error(&vld_err);
369                store_error(req, body.clone());
370                Outcome::Error((Status::UnprocessableEntity, body))
371            }
372        }
373    }
374}
375
376// ---------------------------------------------------------------------------
377// VldCookie<T> — validated cookies
378// ---------------------------------------------------------------------------
379
/// Extractor that validates request cookies with `vld`.
///
/// Cookies are read from the `Cookie` request header and checked against the
/// schema. Cookie names are matched to fields exactly as sent.
#[derive(Debug, Clone)]
pub struct VldCookie<T>(pub T);

impl<T> Deref for VldCookie<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldCookie(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldCookie<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldCookie(inner) = self;
        inner
    }
}
399
400#[rocket::async_trait]
401impl<'r, T: VldParse + Send + Sync + 'static> FromRequest<'r> for VldCookie<T> {
402    type Error = serde_json::Value;
403
404    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
405        let cookie_header = req.headers().get_one("Cookie").unwrap_or("");
406
407        let value = cookies_to_json(cookie_header);
408
409        match T::vld_parse_value(&value) {
410            Ok(parsed) => Outcome::Success(VldCookie(parsed)),
411            Err(vld_err) => {
412                let body = format_vld_error(&vld_err);
413                store_error(req, body.clone());
414                Outcome::Error((Status::UnprocessableEntity, body))
415            }
416        }
417    }
418}
419
420/// Prelude — import everything you need.
421pub mod prelude {
422    pub use crate::{
423        vld_400_catcher, vld_422_catcher, VldCookie, VldForm, VldHeaders, VldJson, VldPath,
424        VldQuery,
425    };
426    pub use vld::prelude::*;
427}