tide_validator/lib.rs
//! tide-validator is a middleware for [Tide](https://github.com/http-rs/tide), a web framework in Rust,
//! which lets you validate the data coming from a request. You can
//! create custom validators to validate your HTTP path parameters, query parameters,
//! cookies and headers.
//!
//! # Features
//!
//! - __Custom validators:__ you can chain multiple validators, and developing a custom validator is very easy: it's just a closure.
//! - __Validate everything:__ with the enum `HttpField` you can validate different fields like cookies, headers, query parameters and path parameters.
//! - __Your own errors:__ thanks to generics in Rust, your validators can return whatever custom error type you
//! need when the data is invalid.
12//!
13//! # Validators
14//!
//! A validator is just a closure (or function) with this signature:
//!
//! ```rust,no_run,compile_fail
//! // The first parameter is the name of the path parameter, query parameter, cookie or header being validated.
//! // The second parameter is its value in the request; `None` means the field is absent from the request (useful to make specific fields required).
20//! Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static where T: Serialize + Send + Sync + 'static
21//! ```
22//!
23//! # Examples
24//!
25//! __simple validation__
26//! ```rust,no_run,compile_fail
27//! // Our own validator is a simple closure to check if the field is a number
28//! fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
29//! if let Some(field_value) = field_value {
30//! if field_value.parse::<i64>().is_err() {
31//! return Err(format!("field '{}' = '{}' is not a valid number", field_name, field_value));
32//! }
33//! }
34//!
35//! Ok(())
36//! }
37//!
38//! //... in main function
39//! let mut app = tide::new();
40//! let mut validator_middleware = ValidatorMiddleware::new();
41//! // 'age' is the parameter name inside the route '/test/:age'
42//! validator_middleware.add_validator(HttpField::Param("age"), is_number);
//! // You can assign a different middleware to each route, and therefore different validators to each route
44//! app.at("/test/:age")
45//! .middleware(validator_middleware)
46//! .get(|_: tide::Request<()>| async move {
47//! let cat = Cat {
48//! name: "Gribouille".into(),
49//! };
50//! Ok(tide::Response::new(StatusCode::Ok).body_json(&cat).unwrap())
51//! });
52//! app.listen("127.0.0.1:8080").await?;
53//! ```
54//!
55//! __chain multiple validators__
56//! ```rust,no_run,compile_fail
//! // This validator makes the field required
58//! fn is_required(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
59//! if field_value.is_none() {
60//! Err(format!("'{}' is required", field_name))
61//! } else {
62//! Ok(())
63//! }
64//! }
65//!
66//! // ... your main function
67//!
68//! let mut app = tide::new();
69//! let mut validator_middleware = ValidatorMiddleware::new();
//! // Here 'age' is a query parameter; the validator stays the same as in the previous example
71//! validator_middleware.add_validator(HttpField::QueryParam("age"), is_number);
72//! // You can also add multiple validators on a single query parameter to check different things
73//! validator_middleware.add_validator(HttpField::QueryParam("age"), is_required);
74//!
//! // You can assign a different middleware to each route, and therefore different validators to each route
76//! app.at("/test")
77//! .middleware(validator_middleware)
78//! .get(|_: tide::Request<()>| async move {
79//! let cat = Cat {
80//! name: "Mozart".into(),
81//! };
82//! Ok(tide::Response::new(StatusCode::Ok).body_json(&cat).unwrap())
83//! },
84//! );
85//!
86//! app.listen("127.0.0.1:8080").await?;
87//! ```
88//!
89//! __Use your own custom error__
90//! ```rust,no_run,compile_fail
//! // Your custom error, which your API will send when validation fails
92//! #[derive(Debug, Serialize)]
93//! struct CustomError {
94//! status_code: usize,
95//! message: String,
96//! }
97//!
98//! // Your validator can also return your own error type
99//! fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
100//! if let Some(field_value) = field_value {
101//! if field_value.parse::<i64>().is_err() {
102//! return Err(CustomError {
103//! status_code: 400,
104//! message: format!(
105//! "field '{}' = '{}' is not a valid number",
106//! field_name, field_value
107//! ),
108//! });
109//! }
110//! }
111//! Ok(())
112//! }
113//!
114//! // ... your main function
115//! ```
116//!
117//! __Dynamic validators__
118//! ```rust,no_run,compile_fail
//! // A function returning the validator as a boxed closure, so that `max_length` can be chosen when registering it
120//! fn is_length_under(
121//! max_length: usize,
122//! ) -> Box<dyn Fn(&str, Option<&str>) -> Result<(), CustomError> + Send + Sync + 'static> {
123//! Box::new(
124//! move |field_name: &str, field_value: Option<&str>| -> Result<(), CustomError> {
125//! if let Some(field_value) = field_value {
126//! if field_value.len() > max_length {
127//! let my_error = CustomError {
128//! status_code: 400,
129//! message: format!(
//! "element '{}' which is equals to '{}' have not the maximum length of {}",
131//! field_name, field_value, max_length
132//! ),
133//! };
134//! return Err(my_error);
135//! }
136//! }
137//! Ok(())
138//! },
139//! )
140//! }
141//!
//! // Then use it, for example, on a cookie named `session`:
143//! validator_middleware.add_validator(HttpField::Cookie("session"), is_length_under(20));
144//!
145//! ```
146//!
147//! For more details about examples check out [the `examples` directory on GitHub](https://github.com/bnjjj/tide-validator/tree/master/examples)
148
149use std::collections::HashMap;
150use std::str::FromStr;
151use std::{fmt::Debug, sync::Arc};
152
153use futures::future::BoxFuture;
154use serde::Serialize;
155use tide::{http::headers::HeaderName, Middleware, Next, Request, Response, StatusCode};
156// trait Validator = Fn(&str) -> Result<(), String> + Send + Sync + 'static;
157
/// Identifies the part of an HTTP request a validator is attached to.
///
/// The wrapped `&str` is the name under which the value is looked up in the request.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum HttpField<'a> {
    /// A path parameter: for the route `/test/:name`, use `HttpField::Param("name")`.
    Param(&'a str),
    /// A query-string parameter: for the URL `/test?name=test`, use `HttpField::QueryParam("name")`.
    QueryParam(&'a str),
    /// A request header, e.g. `HttpField::Header("X-My-Custom-Header")`.
    Header(&'a str),
    /// A cookie, e.g. `HttpField::Cookie("session")`.
    Cookie(&'a str),
}
170
171/// Used as a middleware in your tide framework and add your custom validators
172pub struct ValidatorMiddleware<T>
173where
174 T: Serialize + Send + Sync + 'static,
175{
176 validators: HashMap<
177 HttpField<'static>,
178 Vec<Arc<dyn Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static>>,
179 >,
180}
181impl<T> Debug for ValidatorMiddleware<T>
182where
183 T: Serialize + Send + Sync + 'static,
184{
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.write_fmt(format_args!("validators keys {:?}", self.validators.keys()))
187 }
188}
189
190impl<T> ValidatorMiddleware<T>
191where
192 T: Serialize + Send + Sync + 'static,
193{
194 /// Create a new ValidatorMiddleware to put in your tide configuration.
195 ///
196 /// # Example
197 ///
198 /// ```rust,no_run,compile_fail
199 /// fn main() -> io::Result<()> {
200 /// task::block_on(async {
201 /// let mut app = tide::new();
202 ///
203 /// let mut validator_middleware = ValidatorMiddleware::new();
204 /// validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number);
205 ///
206 /// app.at("/test/:n").middleware(validator_middleware).get(
207 /// |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) },
208 /// );
209 ///
210 /// app.listen("127.0.0.1:8080").await?;
211 /// Ok(())
212 /// })
213 /// }
214 /// ```
215 pub fn new() -> Self {
216 ValidatorMiddleware {
217 validators: HashMap::new(),
218 }
219 }
220
221 pub fn with_validators<F>(mut self, validators: HashMap<HttpField<'static>, F>) -> Self
222 where
223 F: Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static,
224 {
225 for (param_name, validator) in validators {
226 self.add_validator(param_name, validator);
227 }
228 self
229 }
230
231 /// Add new validator for your middleware
232 ///
233 /// # Example
234 ///
235 /// ```rust,no_run,compile_fail
236 /// fn main() -> io::Result<()> {
237 /// task::block_on(async {
238 /// let mut app = tide::new();
239 ///
240 /// let mut validator_middleware = ValidatorMiddleware::new();
241 /// validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number);
242 /// validator_middleware.add_validator(HttpField::QueryParam("myqueryparam"), is_required);
243 ///
244 /// app.at("/test/:n").middleware(validator_middleware).get(
245 /// |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) },
246 /// );
247 ///
248 /// app.listen("127.0.0.1:8080").await?;
249 /// Ok(())
250 /// })
251 /// }
252 /// ```
253 pub fn add_validator<F>(&mut self, param_name: HttpField<'static>, validator: F)
254 where
255 F: Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static,
256 {
257 let validator = Arc::new(validator);
258 let validator_moved = Arc::clone(&validator);
259 self.validators
260 .entry(param_name.into())
261 .and_modify(|e| e.push(validator_moved))
262 .or_insert(vec![validator]);
263 }
264}
265
266impl<State, T> Middleware<State> for ValidatorMiddleware<T>
267where
268 State: Send + Sync + 'static,
269 T: Serialize + Send + Sync + 'static,
270{
271 fn handle<'a>(
272 &'a self,
273 ctx: Request<State>,
274 next: Next<'a, State>,
275 ) -> BoxFuture<'a, tide::Result> {
276 Box::pin(async move {
277 let mut query_parameters: Option<HashMap<String, String>> = None;
278
279 for (param_name, validators) in &self.validators {
280 match param_name {
281 HttpField::Param(param_name) => {
282 for validator in validators {
283 let param_found: Result<String, _> = ctx.param(param_name);
284 if let Err(err) =
285 validator(param_name, param_found.ok().as_ref().map(|p| &p[..]))
286 {
287 return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
288 |err| {
289 Response::new(StatusCode::InternalServerError).body_string(format!(
290 "cannot serialize your parameter validator for '{}' error : {:?}",
291 param_name,
292 err
293 ))
294 },
295 ));
296 }
297 }
298 }
299 HttpField::QueryParam(param_name) => {
300 if query_parameters.is_none() {
301 match ctx.query::<HashMap<String, String>>() {
302 Err(err) => {
303 return Ok(Response::new(StatusCode::InternalServerError)
304 .body_string(format!(
305 "cannot read query parameters: {:?}",
306 err
307 )));
308 }
309 Ok(qps) => query_parameters = Some(qps),
310 }
311 }
312 let query_parameters = query_parameters.as_ref().unwrap();
313
314 for validator in validators {
315 if let Err(err) = validator(
316 param_name,
317 query_parameters.get(¶m_name[..]).map(|p| &p[..]),
318 ) {
319 return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
320 |err| {
321 Response::new(StatusCode::InternalServerError).body_string(format!(
322 "cannot serialize your query parameter validator for '{}' error : {:?}",
323 param_name,
324 err
325 ))
326 },
327 ));
328 }
329 }
330 }
331 HttpField::Header(header_name) => {
332 for validator in validators {
333 let header_found: Option<&str> = ctx
334 .header(&HeaderName::from_str(header_name).unwrap())
335 .map(|header| header.last().map(|val| val.as_str()).unwrap());
336 if let Err(err) = validator(header_name, header_found) {
337 return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
338 |err| {
339 Response::new(StatusCode::InternalServerError).body_string(format!(
340 "cannot serialize your header validator for '{}' error : {:?}",
341 header_name,
342 err
343 ))
344 },
345 ));
346 }
347 }
348 }
349 HttpField::Cookie(cookie_name) => {
350 for validator in validators {
351 let cookie_found = ctx.cookie(cookie_name);
352 if let Err(err) =
353 validator(cookie_name, cookie_found.as_ref().map(|c| c.value()))
354 {
355 return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
356 |err| {
357 Response::new(StatusCode::InternalServerError).body_string(format!(
358 "cannot serialize your cookie validator for '{}' error : {:?}",
359 cookie_name,
360 err
361 ))
362 },
363 ));
364 }
365 }
366 }
367 }
368 }
369 next.run(ctx).await
370 })
371 }
372}
373
#[cfg(test)]
mod tests {
    // The glob brings in `HttpField`, `ValidatorMiddleware` and `StatusCode`;
    // the explicit `tide::http::Request` import below shadows `tide::Request`.
    use super::*;
    use async_std::io::prelude::*;
    use futures::executor::block_on;
    use http_service_mock::make_server;
    use serde::{Deserialize, Serialize};
    use tide::http::{Method, Request};

    /// Rejects values that are present but do not parse as an `i64`.
    fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
        if let Some(field_value) = field_value {
            if field_value.parse::<i64>().is_err() {
                return Err(format!(
                    "field '{}' = '{}' is not a valid number",
                    field_name, field_value
                ));
            }
        }

        Ok(())
    }

    #[test]
    fn validator_simple() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::Param("bar"), is_number);
        inner
            .at("/foo/:bar")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();
        let req = Request::new(Method::Get, "http://localhost/foo/4".parse().unwrap());
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), 200);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        buf.clear();
        let req = Request::new(Method::Get, "http://localhost/foo/bar".parse().unwrap());
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(
            String::from_utf8_lossy(&buf[..]),
            String::from(r#""field 'bar' = 'bar' is not a valid number""#)
        );
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct CustomError {
        status_code: usize,
        message: String,
    }

    /// Builds a validator rejecting values longer than `max_length` bytes.
    fn is_length_under(
        max_length: usize,
    ) -> Box<dyn Fn(&str, Option<&str>) -> Result<(), CustomError> + Send + Sync + 'static> {
        Box::new(
            move |field_name: &str, field_value: Option<&str>| -> Result<(), CustomError> {
                if let Some(field_value) = field_value {
                    if field_value.len() > max_length {
                        let my_error = CustomError {
                            status_code: 400,
                            message: format!(
                                "element '{}' which is equals to '{}' have not the maximum length of {}",
                                field_name, field_value, max_length
                            ),
                        };
                        return Err(my_error);
                    }
                }
                Ok(())
            },
        )
    }

    #[test]
    fn validator_custom() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::QueryParam("test"), is_length_under(10));
        validators.add_validator(HttpField::Cookie("session"), is_length_under(10));
        inner
            .at("/foo")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();
        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), 200);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        buf.clear();

        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=blablablablabla".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();

        let err: CustomError = serde_json::from_slice(&buf[..]).unwrap();

        assert_eq!(err.status_code, 400usize);
        assert_eq!(
            err.message,
            String::from("element 'test' which is equals to 'blablablablabla' have not the maximum length of 10")
        );
    }

    /// Rejects values that are present but are neither `"true"` nor `"false"`.
    fn is_bool(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
        if let Some(field_value) = field_value {
            match field_value {
                "true" | "false" => return Ok(()),
                other => {
                    return Err(CustomError {
                        status_code: 400,
                        message: format!(
                            "field '{}' = '{}' is not a valid boolean",
                            field_name, other
                        ),
                    })
                }
            }
        }
        Ok(())
    }

    /// Rejects a field that is absent from the request.
    fn is_required(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
        if field_value.is_none() {
            Err(CustomError {
                status_code: 400,
                message: format!("'{}' is mandatory", field_name),
            })
        } else {
            Ok(())
        }
    }

    #[test]
    fn validator_chains() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::QueryParam("test"), is_length_under(10));
        validators.add_validator(HttpField::Header("X-Is-Connected"), is_required);
        validators.add_validator(HttpField::Header("X-Is-Connected"), is_bool);
        inner
            .at("/foo")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();

        let mut req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        req.insert_header("X-Is-Connected", "true").unwrap();
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), 200);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        // Missing header: the first chained validator (`is_required`) rejects it.
        buf.clear();
        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();

        let err: CustomError = serde_json::from_slice(&buf[..]).unwrap();

        assert_eq!(err.status_code, 400usize);
        assert_eq!(err.message, String::from("'X-Is-Connected' is mandatory"));

        // Present but non-boolean header: `is_required` passes and the second
        // chained validator (`is_bool`) rejects it.
        buf.clear();
        let mut req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        req.insert_header("X-Is-Connected", "maybe").unwrap();
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();

        let err: CustomError = serde_json::from_slice(&buf[..]).unwrap();

        assert_eq!(err.status_code, 400usize);
        assert_eq!(
            err.message,
            String::from("field 'X-Is-Connected' = 'maybe' is not a valid boolean")
        );
    }
}