vld_tower/lib.rs
1//! # vld-tower — Tower middleware for `vld` validation
2//!
3//! A universal [`tower::Layer`] that validates incoming HTTP JSON request
4//! bodies against a `vld` schema. Works with **any** Tower-compatible
5//! framework: Axum, Hyper, Tonic, Warp, etc.
6//!
7//! On **success** the validated struct is stored in
8//! [`http::Request::extensions`] so downstream handlers can retrieve it
9//! without re-parsing. The original body bytes are forwarded as-is.
10//!
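//! On **failure** a `422 Unprocessable Entity` JSON response is returned
//! immediately — the inner service is never called. A body that is not
//! syntactically valid JSON is rejected the same way, but with
//! `400 Bad Request`.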
13//!
14//! # Quick Start (with Axum)
15//!
16//! ```rust,no_run
17//! use vld::prelude::*;
18//! use vld_tower::ValidateJsonLayer;
19//!
20//! vld::schema! {
21//! #[derive(Debug, Clone)]
22//! pub struct CreateUser {
23//! pub name: String => vld::string().min(2).max(100),
24//! pub email: String => vld::string().email(),
25//! }
26//! }
27//!
28//! // Apply as a layer — works with any Tower-based router
29//! // let app = Router::new()
30//! // .route("/users", post(handler))
31//! // .layer(ValidateJsonLayer::<CreateUser>::new());
32//! ```
33
34use bytes::Bytes;
35use http::{Request, Response, StatusCode};
36use http_body::Body;
37use http_body_util::BodyExt;
38use std::future::Future;
39use std::marker::PhantomData;
40use std::pin::Pin;
41use std::task::{Context, Poll};
42use vld::schema::VldParse;
43
44// ---------------------------------------------------------------------------
45// Layer
46// ---------------------------------------------------------------------------
47
/// A [`tower_layer::Layer`] that validates JSON request bodies with `vld`.
///
/// The type parameter `T` is the validated struct (must implement
/// [`VldParse`] + [`Clone`] + [`Send`] + [`Sync`] + `'static` for the
/// resulting service to be usable).
///
/// # Behaviour
///
/// 1. Reads the full request body.
/// 2. Parses as JSON and validates via `T::vld_parse_value()`.
/// 3. **Valid** — inserts `T` into request extensions, re-attaches the
///    body bytes, and calls the inner service.
/// 4. **Malformed JSON** — returns `400 Bad Request` with a JSON body
///    describing the parse error. The inner service is **not** called.
/// 5. **Invalid** — returns `422 Unprocessable Entity` with a JSON body
///    containing the validation errors. The inner service is **not** called.
///
/// Requests without `Content-Type: application/json` (or missing content
/// type) are **passed through** without validation.
pub struct ValidateJsonLayer<T> {
    /// Ties the layer to `T` without storing a value of it.
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound even
// though no `T` is stored, making the layer needlessly un-clonable for some `T`.
impl<T> Clone for ValidateJsonLayer<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}
68
69impl<T> ValidateJsonLayer<T> {
70 /// Create a new validation layer.
71 pub fn new() -> Self {
72 Self {
73 _marker: PhantomData,
74 }
75 }
76}
77
78impl<T> Default for ValidateJsonLayer<T> {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl<S, T> tower_layer::Layer<S> for ValidateJsonLayer<T> {
85 type Service = ValidateJsonService<S, T>;
86
87 fn layer(&self, inner: S) -> Self::Service {
88 ValidateJsonService {
89 inner,
90 _marker: PhantomData,
91 }
92 }
93}
94
95// ---------------------------------------------------------------------------
96// Service
97// ---------------------------------------------------------------------------
98
/// The middleware [`Service`](tower_service::Service) created by
/// [`ValidateJsonLayer`].
///
/// `S` is the wrapped service; `T` is the type request bodies are
/// validated against.
pub struct ValidateJsonService<S, T> {
    /// The wrapped service that receives requests once they are accepted.
    inner: S,
    /// Ties the service to `T` without storing a value of it.
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would also demand `T: Clone`, but
// only the inner service is actually duplicated.
impl<S: Clone, T> Clone for ValidateJsonService<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: PhantomData,
        }
    }
}
106
107impl<S, T, ReqBody, ResBody> tower_service::Service<Request<ReqBody>> for ValidateJsonService<S, T>
108where
109 S: tower_service::Service<Request<http_body_util::Full<Bytes>>, Response = Response<ResBody>>
110 + Clone
111 + Send
112 + 'static,
113 S::Future: Send + 'static,
114 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
115 ReqBody: Body + Send + 'static,
116 ReqBody::Data: Send,
117 ReqBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
118 ResBody: From<http_body_util::Full<Bytes>> + Send + 'static,
119 T: VldParse + Clone + Send + Sync + 'static,
120{
121 type Response = Response<ResBody>;
122 type Error = Box<dyn std::error::Error + Send + Sync>;
123 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
124
125 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 self.inner.poll_ready(cx).map_err(Into::into)
127 }
128
129 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
130 let mut inner = self.inner.clone();
131 // Swap so `self` is ready for next call (standard Tower pattern)
132 std::mem::swap(&mut self.inner, &mut inner);
133
134 Box::pin(async move {
135 let is_json = req
136 .headers()
137 .get(http::header::CONTENT_TYPE)
138 .and_then(|v| v.to_str().ok())
139 .map(|ct| ct.starts_with("application/json"))
140 .unwrap_or(false);
141
142 if !is_json {
143 // Pass through non-JSON requests untouched
144 let (parts, body) = req.into_parts();
145 let bytes = body
146 .collect()
147 .await
148 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
149 .to_bytes();
150 let new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
151 return inner.call(new_req).await.map_err(Into::into);
152 }
153
154 // Collect body bytes
155 let (parts, body) = req.into_parts();
156 let bytes = body
157 .collect()
158 .await
159 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
160 .to_bytes();
161
162 // Parse JSON
163 let json_value: serde_json::Value = match serde_json::from_slice(&bytes) {
164 Ok(v) => v,
165 Err(e) => {
166 let error_body = serde_json::json!({
167 "error": "Invalid JSON",
168 "message": e.to_string(),
169 });
170 let resp = Response::builder()
171 .status(StatusCode::BAD_REQUEST)
172 .header(http::header::CONTENT_TYPE, "application/json")
173 .body(ResBody::from(http_body_util::Full::new(Bytes::from(
174 serde_json::to_vec(&error_body).unwrap_or_default(),
175 ))))
176 .unwrap();
177 return Ok(resp);
178 }
179 };
180
181 // Validate with vld
182 match T::vld_parse_value(&json_value) {
183 Ok(validated) => {
184 let mut new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
185 // Store validated struct in extensions
186 new_req.extensions_mut().insert(validated);
187 inner.call(new_req).await.map_err(Into::into)
188 }
189 Err(vld_err) => {
190 let issues: Vec<serde_json::Value> = vld_err
191 .issues
192 .iter()
193 .map(|issue| {
194 serde_json::json!({
195 "path": issue.path.iter()
196 .map(|p| p.to_string())
197 .collect::<Vec<_>>()
198 .join("."),
199 "message": issue.message,
200 })
201 })
202 .collect();
203
204 let error_body = serde_json::json!({
205 "error": "Validation failed",
206 "issues": issues,
207 });
208
209 let resp = Response::builder()
210 .status(StatusCode::UNPROCESSABLE_ENTITY)
211 .header(http::header::CONTENT_TYPE, "application/json")
212 .body(ResBody::from(http_body_util::Full::new(Bytes::from(
213 serde_json::to_vec(&error_body).unwrap_or_default(),
214 ))))
215 .unwrap();
216 Ok(resp)
217 }
218 }
219 })
220 }
221}
222
223// ---------------------------------------------------------------------------
224// Helper: extract validated value from request extensions
225// ---------------------------------------------------------------------------
226
227/// Extract the validated value from request extensions.
228///
229/// The [`ValidateJsonService`] middleware stores the parsed and validated
230/// struct in the request's extensions map. Use this function (or
231/// `req.extensions().get::<T>()` directly) to retrieve it.
232///
233/// # Panics
234///
235/// Panics if `T` is not present in extensions (i.e. the middleware was
236/// not applied).
237pub fn validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> T {
238 req.extensions()
239 .get::<T>()
240 .expect(
241 "vld-tower: validated value not found in request extensions. \
242 Make sure ValidateJsonLayer is applied.",
243 )
244 .clone()
245}
246
247/// Try to extract the validated value from request extensions.
248///
249/// Returns `None` if the middleware was not applied or the value type
250/// doesn't match.
251pub fn try_validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> Option<T> {
252 req.extensions().get::<T>().cloned()
253}
254
255/// Prelude — import everything you need.
256pub mod prelude {
257 pub use crate::{try_validated, validated, ValidateJsonLayer, ValidateJsonService};
258}