Skip to main content

vld_rocket/
lib.rs

1//! # vld-rocket — Rocket integration for `vld`
2//!
3//! Validation extractors for [Rocket](https://rocket.rs/). Validates request
4//! data against `vld` schemas and returns `422 Unprocessable Entity` with
5//! structured JSON errors on failure.
6//!
7//! # Extractors
8//!
9//! | Extractor | Source | Rocket equivalent |
10//! |-----------|--------|-------------------|
11//! | `VldJson<T>` | JSON body | `rocket::serde::json::Json<T>` |
12//! | `VldQuery<T>` | Query string | query params |
13//! | `VldForm<T>` | Form body | `rocket::form::Form<T>` |
14//!
15//! # Error catcher
16//!
//! Register [`vld_422_catcher()`] (and optionally [`vld_400_catcher()`]) to get
//! JSON error responses instead of Rocket's default HTML error pages:
//!
//! ```rust,ignore
//! rocket::build()
//!     .mount("/", routes![...])
//!     .register("/", catchers![vld_rocket::vld_422_catcher, vld_rocket::vld_400_catcher])
//! ```
25
26use rocket::data::{Data, FromData, Outcome as DataOutcome};
27use rocket::http::Status;
28use rocket::request::{FromRequest, Outcome, Request};
29use rocket::serde::json::Json;
30use std::ops::{Deref, DerefMut};
31use vld::schema::VldParse;
32
33// ---------------------------------------------------------------------------
34// Request-local error storage
35// ---------------------------------------------------------------------------
36
37/// Stored in request local cache so the catcher can read it.
38#[derive(Debug, Clone, Default)]
39pub struct VldErrorCache(pub Option<serde_json::Value>);
40
41fn store_error(req: &Request<'_>, err: serde_json::Value) {
42    // Rocket's local_cache returns &T, setting a value requires the closure pattern
43    let _ = req.local_cache(|| VldErrorCache(Some(err.clone())));
44}
45
46// ---------------------------------------------------------------------------
47// VldJson<T> — validated JSON body
48// ---------------------------------------------------------------------------
49
/// Extractor for a request body that both parses as JSON and passes `vld`
/// validation.
///
/// The decoded JSON is checked with `T::vld_parse_value()`. When validation
/// fails, the request is rejected with `422` and a JSON description of the
/// errors. Dereferences to the wrapped `T`.
#[derive(Clone, Debug)]
pub struct VldJson<T>(pub T);

impl<T> Deref for VldJson<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<T> DerefMut for VldJson<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
69
70#[rocket::async_trait]
71impl<'r, T: VldParse + Send + 'static> FromData<'r> for VldJson<T> {
72    type Error = serde_json::Value;
73
74    async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
75        let json_outcome = <Json<serde_json::Value> as FromData<'r>>::from_data(req, data).await;
76        let value = match json_outcome {
77            DataOutcome::Success(json) => json.into_inner(),
78            DataOutcome::Error((status, e)) => {
79                let body = serde_json::json!({
80                    "error": "Invalid JSON",
81                    "message": format!("{e}"),
82                });
83                store_error(req, body.clone());
84                return DataOutcome::Error((status, body));
85            }
86            DataOutcome::Forward(f) => return DataOutcome::Forward(f),
87        };
88
89        match T::vld_parse_value(&value) {
90            Ok(parsed) => DataOutcome::Success(VldJson(parsed)),
91            Err(vld_err) => {
92                let body = format_vld_error(&vld_err);
93                store_error(req, body.clone());
94                DataOutcome::Error((Status::UnprocessableEntity, body))
95            }
96        }
97    }
98}
99
100// ---------------------------------------------------------------------------
101// VldQuery<T> — validated query string
102// ---------------------------------------------------------------------------
103
/// Extractor for query-string parameters that pass `vld` validation.
///
/// The query string is turned into a JSON object by the shared query parser,
/// which coerces string values. The object is then checked with
/// `T::vld_parse_value()`. Dereferences to the wrapped `T`.
#[derive(Clone, Debug)]
pub struct VldQuery<T>(pub T);

impl<T> Deref for VldQuery<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        match self {
            Self(inner) => inner,
        }
    }
}

impl<T> DerefMut for VldQuery<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Self(inner) => inner,
        }
    }
}
123
124#[rocket::async_trait]
125impl<'r, T: VldParse + Send + Sync + 'static> FromRequest<'r> for VldQuery<T> {
126    type Error = serde_json::Value;
127
128    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
129        let qs = req.uri().query().map(|q| q.as_str()).unwrap_or("");
130        let map = parse_query_to_json(qs);
131        let value = serde_json::Value::Object(map);
132
133        match T::vld_parse_value(&value) {
134            Ok(parsed) => Outcome::Success(VldQuery(parsed)),
135            Err(vld_err) => {
136                let body = format_vld_error(&vld_err);
137                store_error(req, body.clone());
138                Outcome::Error((Status::UnprocessableEntity, body))
139            }
140        }
141    }
142}
143
144// ---------------------------------------------------------------------------
145// VldForm<T> — validated form body
146// ---------------------------------------------------------------------------
147
/// Extractor for a URL-encoded form body that passes `vld` validation.
///
/// The body is read as text and parsed in
/// `application/x-www-form-urlencoded` syntax into a JSON object, coercing
/// values. The object is then checked with `T::vld_parse_value()`.
/// Dereferences to the wrapped `T`.
#[derive(Clone, Debug)]
pub struct VldForm<T>(pub T);

impl<T> Deref for VldForm<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let VldForm(ref inner) = *self;
        inner
    }
}

impl<T> DerefMut for VldForm<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let VldForm(ref mut inner) = *self;
        inner
    }
}
167
168#[rocket::async_trait]
169impl<'r, T: VldParse + Send + 'static> FromData<'r> for VldForm<T> {
170    type Error = serde_json::Value;
171
172    async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
173        use rocket::data::ToByteUnit;
174        let bytes = match data.open(1.mebibytes()).into_bytes().await {
175            Ok(b) if b.is_complete() => b.into_inner(),
176            _ => {
177                let body = serde_json::json!({"error": "Payload too large"});
178                store_error(req, body.clone());
179                return DataOutcome::Error((Status::PayloadTooLarge, body));
180            }
181        };
182
183        let body_str = match String::from_utf8(bytes) {
184            Ok(s) => s,
185            Err(_) => {
186                let body = serde_json::json!({"error": "Invalid UTF-8"});
187                store_error(req, body.clone());
188                return DataOutcome::Error((Status::BadRequest, body));
189            }
190        };
191
192        let map = parse_query_to_json(&body_str);
193        let value = serde_json::Value::Object(map);
194
195        match T::vld_parse_value(&value) {
196            Ok(parsed) => DataOutcome::Success(VldForm(parsed)),
197            Err(vld_err) => {
198                let body = format_vld_error(&vld_err);
199                store_error(req, body.clone());
200                DataOutcome::Error((Status::UnprocessableEntity, body))
201            }
202        }
203    }
204}
205
206// ---------------------------------------------------------------------------
207// Error catcher
208// ---------------------------------------------------------------------------
209
210/// Catcher for `422 Unprocessable Entity` that returns JSON from the
211/// validation error stored by vld extractors.
212///
213/// Register in your Rocket application:
214///
215/// ```rust,ignore
216/// rocket::build()
217///     .register("/", catchers![vld_rocket::vld_422_catcher])
218/// ```
219#[rocket::catch(422)]
220pub fn vld_422_catcher(req: &Request<'_>) -> (Status, Json<serde_json::Value>) {
221    let cached = req.local_cache(|| VldErrorCache(None));
222    let body = cached
223        .0
224        .clone()
225        .unwrap_or_else(|| serde_json::json!({"error": "Unprocessable Entity"}));
226    (Status::UnprocessableEntity, Json(body))
227}
228
229/// Catcher for `400 Bad Request` that returns JSON.
230#[rocket::catch(400)]
231pub fn vld_400_catcher(req: &Request<'_>) -> (Status, Json<serde_json::Value>) {
232    let cached = req.local_cache(|| VldErrorCache(None));
233    let body = cached
234        .0
235        .clone()
236        .unwrap_or_else(|| serde_json::json!({"error": "Bad Request"}));
237    (Status::BadRequest, Json(body))
238}
239
240// ---------------------------------------------------------------------------
241// Helpers
242// ---------------------------------------------------------------------------
243
244use vld_http_common::{format_vld_error, parse_query_string as parse_query_to_json};
245
246/// Prelude — import everything you need.
247pub mod prelude {
248    pub use crate::{vld_400_catcher, vld_422_catcher, VldForm, VldJson, VldQuery};
249    pub use vld::prelude::*;
250}