vld_tower/lib.rs
1//! # vld-tower — Tower middleware for `vld` validation
2//!
//! A universal [`tower_layer::Layer`] that validates incoming HTTP JSON request
4//! bodies against a `vld` schema. Works with **any** Tower-compatible
5//! framework: Axum, Hyper, Tonic, Warp, etc.
6//!
7//! On **success** the validated struct is stored in
8//! [`http::Request::extensions`] so downstream handlers can retrieve it
9//! without re-parsing. The original body bytes are forwarded as-is.
10//!
//! On **failure** a JSON error response is returned immediately — the inner
//! service is never called: `400 Bad Request` if the body is not valid JSON,
//! `422 Unprocessable Entity` if it fails `vld` validation.
13//!
14//! # Quick Start (with Axum)
15//!
16//! ```rust,no_run
17//! use vld::prelude::*;
18//! use vld_tower::ValidateJsonLayer;
19//!
20//! vld::schema! {
21//! #[derive(Debug, Clone)]
22//! pub struct CreateUser {
23//! pub name: String => vld::string().min(2).max(100),
24//! pub email: String => vld::string().email(),
25//! }
26//! }
27//!
28//! // Apply as a layer — works with any Tower-based router
29//! // let app = Router::new()
30//! // .route("/users", post(handler))
31//! // .layer(ValidateJsonLayer::<CreateUser>::new());
32//! ```
33
34use bytes::Bytes;
35use http::{Request, Response, StatusCode};
36use http_body::Body;
37use http_body_util::BodyExt;
38use std::future::Future;
39use std::marker::PhantomData;
40use std::pin::Pin;
41use std::task::{Context, Poll};
42use vld::schema::VldParse;
43
44// ---------------------------------------------------------------------------
45// Layer
46// ---------------------------------------------------------------------------
47
/// A [`tower_layer::Layer`] that validates JSON request bodies with `vld`.
///
/// The type parameter `T` is the validated struct. For the produced
/// service to be usable, `T` must implement [`VldParse`] + [`Clone`] +
/// [`Send`] + [`Sync`] + `'static`; those bounds are required by the
/// service implementation, not by the layer itself.
///
/// # Behaviour
///
/// 1. Reads the full request body.
/// 2. Parses it as JSON. **Malformed JSON** — returns `400 Bad Request`
///    with a JSON error body; the inner service is **not** called.
/// 3. Validates the parsed value via `T::vld_parse_value()`.
/// 4. **Valid** — inserts `T` into request extensions, re-attaches the
///    body bytes, and calls the inner service.
/// 5. **Invalid** — returns `422 Unprocessable Entity` with a JSON body
///    containing the validation errors. The inner service is **not** called.
///
/// Requests that do not carry a JSON `Content-Type` (such as
/// `application/json`), or that have no content type at all, are
/// **passed through** without validation.
pub struct ValidateJsonLayer<T> {
    // `fn() -> T` keeps the layer `Send + Sync` regardless of `T` and does
    // not imply ownership of a `T`.
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would add a spurious `T: Clone`
// bound even though only a `PhantomData` marker is stored.
impl<T> Clone for ValidateJsonLayer<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

// Implemented by hand so that `T` is not required to be `Debug`; the
// schema type is reported by name instead.
impl<T> std::fmt::Debug for ValidateJsonLayer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ValidateJsonLayer")
            .field("validated_type", &std::any::type_name::<T>())
            .finish()
    }
}
68
69impl<T> ValidateJsonLayer<T> {
70 /// Create a new validation layer.
71 pub fn new() -> Self {
72 Self {
73 _marker: PhantomData,
74 }
75 }
76}
77
78impl<T> Default for ValidateJsonLayer<T> {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl<S, T> tower_layer::Layer<S> for ValidateJsonLayer<T> {
85 type Service = ValidateJsonService<S, T>;
86
87 fn layer(&self, inner: S) -> Self::Service {
88 ValidateJsonService {
89 inner,
90 _marker: PhantomData,
91 }
92 }
93}
94
95// ---------------------------------------------------------------------------
96// Service
97// ---------------------------------------------------------------------------
98
/// The middleware [`Service`](tower_service::Service) created by
/// [`ValidateJsonLayer`].
///
/// Wraps the inner service `S`. `T` is only a type-level marker naming the
/// schema to validate against; no value of `T` is stored.
pub struct ValidateJsonService<S, T> {
    inner: S,
    // `fn() -> T` keeps the service's auto traits independent of `T`.
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would also require `T: Clone`,
// although only the inner service is actually cloned.
impl<S: Clone, T> Clone for ValidateJsonService<S, T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _marker: PhantomData,
        }
    }
}

// Implemented by hand so that only the inner service must be `Debug`;
// the schema type is reported by name.
impl<S: std::fmt::Debug, T> std::fmt::Debug for ValidateJsonService<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ValidateJsonService")
            .field("inner", &self.inner)
            .field("validated_type", &std::any::type_name::<T>())
            .finish()
    }
}
106
107impl<S, T, ReqBody, ResBody> tower_service::Service<Request<ReqBody>> for ValidateJsonService<S, T>
108where
109 S: tower_service::Service<Request<http_body_util::Full<Bytes>>, Response = Response<ResBody>>
110 + Clone
111 + Send
112 + 'static,
113 S::Future: Send + 'static,
114 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
115 ReqBody: Body + Send + 'static,
116 ReqBody::Data: Send,
117 ReqBody::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
118 ResBody: From<http_body_util::Full<Bytes>> + Send + 'static,
119 T: VldParse + Clone + Send + Sync + 'static,
120{
121 type Response = Response<ResBody>;
122 type Error = Box<dyn std::error::Error + Send + Sync>;
123 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
124
125 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 self.inner.poll_ready(cx).map_err(Into::into)
127 }
128
129 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
130 let mut inner = self.inner.clone();
131 // Swap so `self` is ready for next call (standard Tower pattern)
132 std::mem::swap(&mut self.inner, &mut inner);
133
134 Box::pin(async move {
135 let is_json = req
136 .headers()
137 .get(http::header::CONTENT_TYPE)
138 .and_then(|v| v.to_str().ok())
139 .map(|ct| ct.starts_with("application/json"))
140 .unwrap_or(false);
141
142 if !is_json {
143 // Pass through non-JSON requests untouched
144 let (parts, body) = req.into_parts();
145 let bytes = body
146 .collect()
147 .await
148 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
149 .to_bytes();
150 let new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
151 return inner.call(new_req).await.map_err(Into::into);
152 }
153
154 // Collect body bytes
155 let (parts, body) = req.into_parts();
156 let bytes = body
157 .collect()
158 .await
159 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { e.into() })?
160 .to_bytes();
161
162 // Parse JSON
163 let json_value: serde_json::Value = match serde_json::from_slice(&bytes) {
164 Ok(v) => v,
165 Err(e) => {
166 let error_body = vld_http_common::format_json_parse_error(&e.to_string());
167 let resp = Response::builder()
168 .status(StatusCode::BAD_REQUEST)
169 .header(http::header::CONTENT_TYPE, "application/json")
170 .body(ResBody::from(http_body_util::Full::new(Bytes::from(
171 serde_json::to_vec(&error_body).unwrap_or_default(),
172 ))))
173 .unwrap();
174 return Ok(resp);
175 }
176 };
177
178 // Validate with vld
179 match T::vld_parse_value(&json_value) {
180 Ok(validated) => {
181 let mut new_req = Request::from_parts(parts, http_body_util::Full::new(bytes));
182 // Store validated struct in extensions
183 new_req.extensions_mut().insert(validated);
184 inner.call(new_req).await.map_err(Into::into)
185 }
186 Err(vld_err) => {
187 let error_body = vld_http_common::format_vld_error(&vld_err);
188
189 let resp = Response::builder()
190 .status(StatusCode::UNPROCESSABLE_ENTITY)
191 .header(http::header::CONTENT_TYPE, "application/json")
192 .body(ResBody::from(http_body_util::Full::new(Bytes::from(
193 serde_json::to_vec(&error_body).unwrap_or_default(),
194 ))))
195 .unwrap();
196 Ok(resp)
197 }
198 }
199 })
200 }
201}
202
203// ---------------------------------------------------------------------------
204// Helper: extract validated value from request extensions
205// ---------------------------------------------------------------------------
206
207/// Extract the validated value from request extensions.
208///
209/// The [`ValidateJsonService`] middleware stores the parsed and validated
210/// struct in the request's extensions map. Use this function (or
211/// `req.extensions().get::<T>()` directly) to retrieve it.
212///
213/// # Panics
214///
215/// Panics if `T` is not present in extensions (i.e. the middleware was
216/// not applied).
217pub fn validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> T {
218 req.extensions()
219 .get::<T>()
220 .expect(
221 "vld-tower: validated value not found in request extensions. \
222 Make sure ValidateJsonLayer is applied.",
223 )
224 .clone()
225}
226
227/// Try to extract the validated value from request extensions.
228///
229/// Returns `None` if the middleware was not applied or the value type
230/// doesn't match.
231pub fn try_validated<T: Clone + Send + Sync + 'static>(req: &Request<impl Body>) -> Option<T> {
232 req.extensions().get::<T>().cloned()
233}
234
235/// Prelude — import everything you need.
236pub mod prelude {
237 pub use crate::{try_validated, validated, ValidateJsonLayer, ValidateJsonService};
238}