tide_validator/
lib.rs

//! tide-validator is a middleware for [Tide](https://github.com/http-rs/tide), a web framework in Rust,
//! that lets you validate the data coming from a request. You can
//! create custom validators to validate your HTTP path parameters, query parameters,
//! cookies and headers.
5//!
6//! # Features
7//!
//! - __Custom validators:__ you can chain multiple validators, and writing a custom validator is very easy: it's just a closure.
//! - __Validate everything:__ with the enum `HttpField` you can validate different fields like cookies, headers, query parameters and path parameters.
//! - __Your own errors:__ thanks to generics in Rust you can return your own custom error type when the data is invalid.
12//!
13//! # Validators
14//!
//! A validator is just a closure (or function) with the following signature:
16//!
17//! ```rust,no_run,compile_fail
//! // The first argument is the name of the path parameter, query parameter, cookie or header.
//! // The second argument is that field's value; `None` means the field is absent from the request (useful to make specific fields required).
20//! Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static where T: Serialize + Send + Sync + 'static
21//! ```
22//!
23//! # Examples
24//!
25//! __simple validation__
26//! ```rust,no_run,compile_fail
27//! // Our own validator is a simple closure to check if the field is a number
28//! fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
29//!     if let Some(field_value) = field_value {
30//!         if field_value.parse::<i64>().is_err() {
31//!             return Err(format!("field '{}' = '{}' is not a valid number", field_name, field_value));
32//!         }
33//!     }
34//!
35//!     Ok(())
36//! }
37//!
38//! //... in main function
39//! let mut app = tide::new();
40//! let mut validator_middleware = ValidatorMiddleware::new();
41//! // 'age' is the parameter name inside the route '/test/:age'
42//! validator_middleware.add_validator(HttpField::Param("age"), is_number);
//! // Each route can have its own middleware, and therefore its own set of validators
44//! app.at("/test/:age")
45//!     .middleware(validator_middleware)
46//!     .get(|_: tide::Request<()>| async move {
47//!         let cat = Cat {
48//!             name: "Gribouille".into(),
49//!         };
50//!         Ok(tide::Response::new(StatusCode::Ok).body_json(&cat).unwrap())
51//!      });
52//! app.listen("127.0.0.1:8080").await?;
53//! ```
54//!
55//! __chain multiple validators__
56//! ```rust,no_run,compile_fail
//! // This validator forces the element to be present
58//! fn is_required(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
59//!     if field_value.is_none() {
60//!         Err(format!("'{}' is required", field_name))
61//!     } else {
62//!         Ok(())
63//!     }
64//! }
65//!
66//! // ... your main function
67//!
68//! let mut app = tide::new();
69//! let mut validator_middleware = ValidatorMiddleware::new();
//! // Here 'age' is a query parameter; the validator is the same as in the previous example
71//! validator_middleware.add_validator(HttpField::QueryParam("age"), is_number);
72//! // You can also add multiple validators on a single query parameter to check different things
73//! validator_middleware.add_validator(HttpField::QueryParam("age"), is_required);
74//!
//! // Each route can have its own middleware, and therefore its own set of validators
76//! app.at("/test")
77//!     .middleware(validator_middleware)
78//!     .get(|_: tide::Request<()>| async move {
79//!            let cat = Cat {
80//!                 name: "Mozart".into(),
81//!            };
82//!            Ok(tide::Response::new(StatusCode::Ok).body_json(&cat).unwrap())
83//!         },
84//!     );
85//!
86//! app.listen("127.0.0.1:8080").await?;
87//! ```
88//!
89//! __Use your own custom error__
90//! ```rust,no_run,compile_fail
//! // Your custom error type, which your API sends back when validation fails
92//! #[derive(Debug, Serialize)]
93//! struct CustomError {
94//!     status_code: usize,
95//!     message: String,
96//! }
97//!
98//! // Your validator can also return your own error type
99//! fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
100//!     if let Some(field_value) = field_value {
101//!         if field_value.parse::<i64>().is_err() {
102//!             return Err(CustomError {
103//!                 status_code: 400,
104//!                 message: format!(
105//!                     "field '{}' = '{}' is not a valid number",
106//!                     field_name, field_value
107//!                 ),
108//!             });
109//!         }
110//!     }
111//!     Ok(())
112//! }
113//!
114//! // ... your main function
115//! ```
116//!
117//! __Dynamic validators__
118//! ```rust,no_run,compile_fail
//! // A function returning a boxed closure, so the validator can capture a runtime `max_length`
120//! fn is_length_under(
121//!     max_length: usize,
122//! ) -> Box<dyn Fn(&str, Option<&str>) -> Result<(), CustomError> + Send + Sync + 'static> {
123//!     Box::new(
124//!         move |field_name: &str, field_value: Option<&str>| -> Result<(), CustomError> {
125//!             if let Some(field_value) = field_value {
126//!                 if field_value.len() > max_length {
127//!                     let my_error = CustomError {
128//!                         status_code: 400,
129//!                         message: format!(
//!                             "element '{}' which is equals to '{}' have not the maximum length of {}",
131//!                             field_name, field_value, max_length
132//!                         ),
133//!                     };
134//!                     return Err(my_error);
135//!                 }
136//!             }
137//!             Ok(())
138//!         },
139//!     )
140//! }
141//!
142//! // Simply call it on a cookie `session` for example:
143//! validator_middleware.add_validator(HttpField::Cookie("session"), is_length_under(20));
144//!
145//! ```
146//!
//! For more examples, check out [the `examples` directory on GitHub](https://github.com/bnjjj/tide-validator/tree/master/examples)
148
149use std::collections::HashMap;
150use std::str::FromStr;
151use std::{fmt::Debug, sync::Arc};
152
153use futures::future::BoxFuture;
154use serde::Serialize;
155use tide::{http::headers::HeaderName, Middleware, Next, Request, Response, StatusCode};
156// trait Validator = Fn(&str) -> Result<(), String> + Send + Sync + 'static;
157
/// Selects which part of an incoming HTTP request a validator is attached to.
///
/// Every variant carries the name of the field to look up in the request.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum HttpField<'a> {
    /// A path parameter: for the route `/test/:name`, use `HttpField::Param("name")`.
    Param(&'a str),
    /// A query-string parameter: for `/test?name=test`, use `HttpField::QueryParam("name")`.
    QueryParam(&'a str),
    /// A request header, e.g. `HttpField::Header("X-My-Custom-Header")`.
    Header(&'a str),
    /// A cookie, e.g. `HttpField::Cookie("session")`.
    Cookie(&'a str),
}
170
171/// Used as a middleware in your tide framework and add your custom validators
172pub struct ValidatorMiddleware<T>
173where
174    T: Serialize + Send + Sync + 'static,
175{
176    validators: HashMap<
177        HttpField<'static>,
178        Vec<Arc<dyn Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static>>,
179    >,
180}
181impl<T> Debug for ValidatorMiddleware<T>
182where
183    T: Serialize + Send + Sync + 'static,
184{
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.write_fmt(format_args!("validators keys {:?}", self.validators.keys()))
187    }
188}
189
190impl<T> ValidatorMiddleware<T>
191where
192    T: Serialize + Send + Sync + 'static,
193{
194    /// Create a new ValidatorMiddleware to put in your tide configuration.
195    ///
196    /// # Example
197    ///
198    /// ```rust,no_run,compile_fail
199    /// fn main() -> io::Result<()> {
200    ///     task::block_on(async {
201    ///         let mut app = tide::new();
202    ///
203    ///         let mut validator_middleware = ValidatorMiddleware::new();
204    ///         validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number);
205    ///
206    ///         app.at("/test/:n").middleware(validator_middleware).get(
207    ///             |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) },
208    ///         );
209    ///
210    ///         app.listen("127.0.0.1:8080").await?;
211    ///         Ok(())
212    ///     })
213    /// }
214    /// ```
215    pub fn new() -> Self {
216        ValidatorMiddleware {
217            validators: HashMap::new(),
218        }
219    }
220
221    pub fn with_validators<F>(mut self, validators: HashMap<HttpField<'static>, F>) -> Self
222    where
223        F: Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static,
224    {
225        for (param_name, validator) in validators {
226            self.add_validator(param_name, validator);
227        }
228        self
229    }
230
231    /// Add new validator for your middleware
232    ///
233    /// # Example
234    ///
235    /// ```rust,no_run,compile_fail
236    /// fn main() -> io::Result<()> {
237    ///     task::block_on(async {
238    ///         let mut app = tide::new();
239    ///
240    ///         let mut validator_middleware = ValidatorMiddleware::new();
241    ///         validator_middleware.add_validator(HttpField::Header("X-Custom-Header"), is_number);
242    ///         validator_middleware.add_validator(HttpField::QueryParam("myqueryparam"), is_required);
243    ///
244    ///         app.at("/test/:n").middleware(validator_middleware).get(
245    ///             |_: tide::Request<()>| async move { Ok(tide::Response::new(StatusCode::Ok).body_json("test").unwrap()) },
246    ///         );
247    ///
248    ///         app.listen("127.0.0.1:8080").await?;
249    ///         Ok(())
250    ///     })
251    /// }
252    /// ```
253    pub fn add_validator<F>(&mut self, param_name: HttpField<'static>, validator: F)
254    where
255        F: Fn(&str, Option<&str>) -> Result<(), T> + Send + Sync + 'static,
256    {
257        let validator = Arc::new(validator);
258        let validator_moved = Arc::clone(&validator);
259        self.validators
260            .entry(param_name.into())
261            .and_modify(|e| e.push(validator_moved))
262            .or_insert(vec![validator]);
263    }
264}
265
266impl<State, T> Middleware<State> for ValidatorMiddleware<T>
267where
268    State: Send + Sync + 'static,
269    T: Serialize + Send + Sync + 'static,
270{
271    fn handle<'a>(
272        &'a self,
273        ctx: Request<State>,
274        next: Next<'a, State>,
275    ) -> BoxFuture<'a, tide::Result> {
276        Box::pin(async move {
277            let mut query_parameters: Option<HashMap<String, String>> = None;
278
279            for (param_name, validators) in &self.validators {
280                match param_name {
281                    HttpField::Param(param_name) => {
282                        for validator in validators {
283                            let param_found: Result<String, _> = ctx.param(param_name);
284                            if let Err(err) =
285                                validator(param_name, param_found.ok().as_ref().map(|p| &p[..]))
286                            {
287                                return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
288                                        |err| {
289                                            Response::new(StatusCode::InternalServerError).body_string(format!(
290                                                "cannot serialize your parameter validator for '{}' error : {:?}",
291                                                param_name,
292                                                err
293                                            ))
294                                        },
295                                    ));
296                            }
297                        }
298                    }
299                    HttpField::QueryParam(param_name) => {
300                        if query_parameters.is_none() {
301                            match ctx.query::<HashMap<String, String>>() {
302                                Err(err) => {
303                                    return Ok(Response::new(StatusCode::InternalServerError)
304                                        .body_string(format!(
305                                            "cannot read query parameters: {:?}",
306                                            err
307                                        )));
308                                }
309                                Ok(qps) => query_parameters = Some(qps),
310                            }
311                        }
312                        let query_parameters = query_parameters.as_ref().unwrap();
313
314                        for validator in validators {
315                            if let Err(err) = validator(
316                                param_name,
317                                query_parameters.get(&param_name[..]).map(|p| &p[..]),
318                            ) {
319                                return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
320                                        |err| {
321                                            Response::new(StatusCode::InternalServerError).body_string(format!(
322                                                "cannot serialize your query parameter validator for '{}' error : {:?}",
323                                                param_name,
324                                                err
325                                            ))
326                                        },
327                                    ));
328                            }
329                        }
330                    }
331                    HttpField::Header(header_name) => {
332                        for validator in validators {
333                            let header_found: Option<&str> = ctx
334                                .header(&HeaderName::from_str(header_name).unwrap())
335                                .map(|header| header.last().map(|val| val.as_str()).unwrap());
336                            if let Err(err) = validator(header_name, header_found) {
337                                return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
338                                        |err| {
339                                            Response::new(StatusCode::InternalServerError).body_string(format!(
340                                                "cannot serialize your header validator for '{}' error : {:?}",
341                                                header_name,
342                                                err
343                                            ))
344                                        },
345                                    ));
346                            }
347                        }
348                    }
349                    HttpField::Cookie(cookie_name) => {
350                        for validator in validators {
351                            let cookie_found = ctx.cookie(cookie_name);
352                            if let Err(err) =
353                                validator(cookie_name, cookie_found.as_ref().map(|c| c.value()))
354                            {
355                                return Ok(Response::new(StatusCode::BadRequest).body_json(&err).unwrap_or_else(
356                                        |err| {
357                                            Response::new(StatusCode::InternalServerError).body_string(format!(
358                                                "cannot serialize your cookie validator for '{}' error : {:?}",
359                                                cookie_name,
360                                                err
361                                            ))
362                                        },
363                                    ));
364                            }
365                        }
366                    }
367                }
368            }
369            next.run(ctx).await
370        })
371    }
372}
373
#[cfg(test)]
mod tests {
    // Drives the middleware through a mock tide server and checks status codes and bodies.

    use super::*;
    use async_std::io::prelude::*;
    use futures::executor::block_on;
    use http_service_mock::make_server;
    use serde::{Deserialize, Serialize};
    use tide::http::{Method, Request};

    /// Test validator: rejects values that do not parse as `i64`; absent values pass.
    fn is_number(field_name: &str, field_value: Option<&str>) -> Result<(), String> {
        if let Some(field_value) = field_value {
            if field_value.parse::<i64>().is_err() {
                return Err(format!(
                    "field '{}' = '{}' is not a valid number",
                    field_name, field_value
                ));
            }
        }

        Ok(())
    }

    #[test]
    fn validator_simple() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::Param("bar"), is_number);
        inner
            .at("/foo/:bar")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();
        let req = Request::new(Method::Get, "http://localhost/foo/4".parse().unwrap());
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::Ok);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        buf.clear();
        let req = Request::new(Method::Get, "http://localhost/foo/bar".parse().unwrap());
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();
        // A `String` error is serialized as a JSON string, hence the surrounding quotes.
        assert_eq!(
            String::from_utf8_lossy(&buf[..]),
            String::from(r#""field 'bar' = 'bar' is not a valid number""#)
        );
    }

    /// Custom error type used to check that arbitrary `Serialize` errors reach the client.
    #[derive(Debug, Serialize, Deserialize)]
    struct CustomError {
        status_code: usize,
        message: String,
    }

    /// Builds a validator that rejects values longer than `max_length` bytes.
    fn is_length_under(
        max_length: usize,
    ) -> Box<dyn Fn(&str, Option<&str>) -> Result<(), CustomError> + Send + Sync + 'static> {
        Box::new(
            move |field_name: &str, field_value: Option<&str>| -> Result<(), CustomError> {
                if let Some(field_value) = field_value {
                    if field_value.len() > max_length {
                        let my_error = CustomError {
                            status_code: 400,
                            message: format!(
                                "element '{}' which is equals to '{}' have not the maximum length of {}",
                                field_name, field_value, max_length
                            ),
                        };
                        return Err(my_error);
                    }
                }
                Ok(())
            },
        )
    }

    #[test]
    fn validator_custom() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::QueryParam("test"), is_length_under(10));
        // Registered but not exercised: no request below sets a `session` cookie.
        validators.add_validator(HttpField::Cookie("session"), is_length_under(10));
        inner
            .at("/foo")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();
        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::Ok);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        buf.clear();

        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=blablablablabla".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();

        let err: CustomError = serde_json::from_slice(&buf[..]).unwrap();

        assert_eq!(err.status_code, 400usize);
        assert_eq!(
            err.message,
            String::from("element 'test' which is equals to 'blablablablabla' have not the maximum length of 10")
        );
    }

    /// Test validator: accepts only the literal values `"true"` and `"false"`; absent values pass.
    fn is_bool(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
        if let Some(field_value) = field_value {
            match field_value {
                "true" | "false" => return Ok(()),
                other => {
                    return Err(CustomError {
                        status_code: 400,
                        message: format!(
                            "field '{}' = '{}' is not a valid boolean",
                            field_name, other
                        ),
                    })
                }
            }
        }
        Ok(())
    }

    /// Test validator: rejects absent fields.
    fn is_required(field_name: &str, field_value: Option<&str>) -> Result<(), CustomError> {
        if field_value.is_none() {
            Err(CustomError {
                status_code: 400,
                message: format!("'{}' is mandatory", field_name),
            })
        } else {
            Ok(())
        }
    }

    #[test]
    fn validator_chains() {
        let mut inner = tide::new();
        let mut validators = ValidatorMiddleware::new();
        validators.add_validator(HttpField::QueryParam("test"), is_length_under(10));
        validators.add_validator(HttpField::Header("X-Is-Connected"), is_required);
        validators.add_validator(HttpField::Header("X-Is-Connected"), is_bool);
        inner
            .at("/foo")
            .middleware(validators)
            .get(|_| async { Ok("foo") });

        let mut server = make_server(inner).unwrap();

        let mut buf = Vec::new();

        let mut req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        req.insert_header("X-Is-Connected", "true").unwrap();
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::Ok);
        block_on(res.read_to_end(&mut buf)).unwrap();
        assert_eq!(&*buf, &*b"foo");

        buf.clear();
        // Without the header, the first chained validator (`is_required`) must reject.
        let req = Request::new(
            Method::Get,
            "http://localhost/foo?test=coucou".parse().unwrap(),
        );
        let mut res = server.simulate(req).unwrap();
        assert_eq!(res.status(), StatusCode::BadRequest);
        block_on(res.read_to_end(&mut buf)).unwrap();

        let err: CustomError = serde_json::from_slice(&buf[..]).unwrap();

        assert_eq!(err.status_code, 400usize);
        assert_eq!(err.message, String::from("'X-Is-Connected' is mandatory"));
    }
}