tower_async_http/
validate_request.rs

1//! Middleware that validates requests.
2//!
3//! # Example
4//!
5//! ```
6//! use tower_async_http::validate_request::ValidateRequestHeaderLayer;
7//! use http::{Request, Response, StatusCode, header::ACCEPT};
8//! use http_body_util::Full;
9//! use bytes::Bytes;
10//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//!
12//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
13//!     Ok(Response::new(Full::default()))
14//! }
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), BoxError> {
18//! let mut service = ServiceBuilder::new()
//!     // Require the `Accept` header, when present, to be `application/json`, `*/*` or `application/*`
20//!     .layer(ValidateRequestHeaderLayer::accept("application/json"))
21//!     .service_fn(handle);
22//!
23//! // Requests with the correct value are allowed through
24//! let request = Request::builder()
25//!     .header(ACCEPT, "application/json")
26//!     .body(Full::default())
27//!     .unwrap();
28//!
29//! let response = service
30//!     .call(request)
31//!     .await?;
32//!
33//! assert_eq!(StatusCode::OK, response.status());
34//!
35//! // Requests with an invalid value get a `406 Not Acceptable` response
36//! let request = Request::builder()
37//!     .header(ACCEPT, "text/strings")
38//!     .body(Full::default())
39//!     .unwrap();
40//!
41//! let response = service
42//!     .call(request)
43//!     .await?;
44//!
45//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status());
46//! # Ok(())
47//! # }
48//! ```
49//!
//! Custom validation logic can be provided by implementing [`ValidateRequest`]:
51//!
52//! ```
53//! use tower_async_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
54//! use http::{Request, Response, StatusCode, header::ACCEPT};
55//! use http_body_util::Full;
56//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
57//! use bytes::Bytes;
58//!
59//! #[derive(Clone, Copy)]
60//! pub struct MyHeader { /* ...  */ }
61//!
62//! impl<B> ValidateRequest<B> for MyHeader {
63//!     type ResponseBody = Full<Bytes>;
64//!
65//!     fn validate(
66//!         &self,
67//!         request: &mut Request<B>,
68//!     ) -> Result<(), Response<Self::ResponseBody>> {
69//!         // validate the request...
70//!         # unimplemented!()
71//!     }
72//! }
73//!
74//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
75//!     Ok(Response::new(Full::default()))
76//! }
77//!
78//!
79//! # #[tokio::main]
80//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
81//! let service = ServiceBuilder::new()
82//!     // Validate requests using `MyHeader`
83//!     .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ }))
84//!     .service_fn(handle);
85//! # Ok(())
86//! # }
87//! ```
88//!
89//! Or using a closure:
90//!
91//! ```
92//! use tower_async_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
93//! use http::{Request, Response, StatusCode, header::ACCEPT};
94//! use bytes::Bytes;
95//! use http_body_util::Full;
96//! use tower_async::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
97//!
98//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
99//!     # todo!();
100//!     // ...
101//! }
102//!
103//! # #[tokio::main]
104//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
105//! let service = ServiceBuilder::new()
106//!     .layer(ValidateRequestHeaderLayer::custom(|request: &mut Request<Full<Bytes>>| {
107//!         // Validate the request
108//!         # Ok::<_, Response<Full<Bytes>>>(())
109//!     }))
110//!     .service_fn(handle);
111//! # Ok(())
112//! # }
113//! ```
114
115use http::{header, Request, Response, StatusCode};
116use http_body::Body;
117use mime::{Mime, MimeIter};
118use std::{fmt, marker::PhantomData, sync::Arc};
119use tower_async_layer::Layer;
120use tower_async_service::Service;
121
/// [`Layer`] that wraps services in [`ValidateRequestHeader`], so that every
/// request is checked by the configured validator before reaching the inner service.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeaderLayer<T> {
    validate: T,
}
129
130impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
131    /// Validate requests have the required Accept header.
132    ///
133    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
134    /// as configured.
135    ///
136    /// # Panics
137    ///
138    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
139    /// See `AcceptHeader::new` for when this method panics.
140    ///
141    /// # Example
142    ///
143    /// ```
144    /// use http_body_util::Full;
145    /// use bytes::Bytes;
146    /// use tower_async_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer};
147    ///
148    /// let layer = ValidateRequestHeaderLayer::<AcceptHeader<Full<Bytes>>>::accept("application/json");
149    /// ```
150    ///
151    /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
152    pub fn accept(value: &str) -> Self
153    where
154        ResBody: Body + Default,
155    {
156        Self::custom(AcceptHeader::new(value))
157    }
158}
159
160impl<T> ValidateRequestHeaderLayer<T> {
161    /// Validate requests using a custom method.
162    pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
163        Self { validate }
164    }
165}
166
167impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
168where
169    T: Clone,
170{
171    type Service = ValidateRequestHeader<S, T>;
172
173    fn layer(&self, inner: S) -> Self::Service {
174        ValidateRequestHeader::new(inner, self.validate.clone())
175    }
176}
177
/// Middleware that runs a validator on each incoming request and forwards only
/// the requests that pass to the inner service.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
    inner: S,
    validate: T,
}
186
187impl<S, T> ValidateRequestHeader<S, T> {
188    fn new(inner: S, validate: T) -> Self {
189        Self::custom(inner, validate)
190    }
191
192    define_inner_service_accessors!();
193}
194
195impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
196    /// Validate requests have the required Accept header.
197    ///
198    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
199    /// as configured.
200    ///
201    /// # Panics
202    ///
203    /// See `AcceptHeader::new` for when this method panics.
204    pub fn accept(inner: S, value: &str) -> Self
205    where
206        ResBody: Body + Default,
207    {
208        Self::custom(inner, AcceptHeader::new(value))
209    }
210}
211
212impl<S, T> ValidateRequestHeader<S, T> {
213    /// Validate requests using a custom method.
214    pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
215        Self { inner, validate }
216    }
217}
218
219impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
220where
221    V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
222    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
223{
224    type Response = Response<ResBody>;
225    type Error = S::Error;
226
227    async fn call(&self, mut req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
228        match self.validate.validate(&mut req) {
229            Ok(_) => self.inner.call(req).await,
230            Err(res) => Ok(res),
231        }
232    }
233}
234
235/// Trait for validating requests.
236pub trait ValidateRequest<B> {
237    /// The body type used for responses to unvalidated requests.
238    type ResponseBody;
239
240    /// Validate the request.
241    ///
242    /// If `Ok(())` is returned then the request is allowed through, otherwise not.
243    fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
244}
245
246impl<B, F, ResBody> ValidateRequest<B> for F
247where
248    F: Fn(&mut Request<B>) -> Result<(), Response<ResBody>>,
249{
250    type ResponseBody = ResBody;
251
252    fn validate(&self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
253        self(request)
254    }
255}
256
257/// Type that performs validation of the Accept header.
258pub struct AcceptHeader<ResBody> {
259    header_value: Arc<Mime>,
260    _ty: PhantomData<fn() -> ResBody>,
261}
262
263impl<ResBody> AcceptHeader<ResBody> {
264    /// Create a new `AcceptHeader`.
265    ///
266    /// # Panics
267    ///
268    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
269    fn new(header_value: &str) -> Self
270    where
271        ResBody: Body + Default,
272    {
273        Self {
274            header_value: Arc::new(
275                header_value
276                    .parse::<Mime>()
277                    .expect("value is not a valid header value"),
278            ),
279            _ty: PhantomData,
280        }
281    }
282}
283
284impl<ResBody> Clone for AcceptHeader<ResBody> {
285    fn clone(&self) -> Self {
286        Self {
287            header_value: self.header_value.clone(),
288            _ty: PhantomData,
289        }
290    }
291}
292
293impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
294    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295        f.debug_struct("AcceptHeader")
296            .field("header_value", &self.header_value)
297            .finish()
298    }
299}
300
301impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
302where
303    ResBody: Body + Default,
304{
305    type ResponseBody = ResBody;
306
307    fn validate(&self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
308        if !req.headers().contains_key(header::ACCEPT) {
309            return Ok(());
310        }
311        if req
312            .headers()
313            .get_all(header::ACCEPT)
314            .into_iter()
315            .filter_map(|header| header.to_str().ok())
316            .any(|h| {
317                MimeIter::new(h)
318                    .map(|mim| {
319                        if let Ok(mim) = mim {
320                            let typ = self.header_value.type_();
321                            let subtype = self.header_value.subtype();
322                            match (mim.type_(), mim.subtype()) {
323                                (t, s) if t == typ && s == subtype => true,
324                                (t, mime::STAR) if t == typ => true,
325                                (mime::STAR, mime::STAR) => true,
326                                _ => false,
327                            }
328                        } else {
329                            false
330                        }
331                    })
332                    .reduce(|acc, mim| acc || mim)
333                    .unwrap_or(false)
334            })
335        {
336            return Ok(());
337        }
338        let mut res = Response::new(ResBody::default());
339        *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
340        Err(res)
341    }
342}
343
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::test_helpers::Body;

    use http::{header, StatusCode};
    use tower_async::{BoxError, ServiceBuilder};

    /// Build a service with `ValidateRequestHeaderLayer::accept(configured)`, send it
    /// a request carrying one `Accept` header per entry of `accept_values`, and
    /// return the response status.
    async fn status_for(configured: &str, accept_values: &[&str]) -> StatusCode {
        let service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept(configured))
            .service_fn(echo);

        // `Builder::header` appends, so repeated entries become repeated headers.
        let request = accept_values
            .iter()
            .fold(Request::get("/"), |builder, value| {
                builder.header(header::ACCEPT, *value)
            })
            .body(Body::empty())
            .unwrap();

        service.call(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn missing_accept_header_is_allowed() {
        assert_eq!(status_for("application/json", &[]).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header() {
        assert_eq!(
            status_for("application/json", &["application/json"]).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        assert_eq!(
            status_for("application/json", &["application/*"]).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        assert_eq!(
            status_for("application/json", &["*/*"]).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        assert_eq!(
            status_for("application/json", &["invalid"]).await,
            StatusCode::NOT_ACCEPTABLE
        );
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        assert_eq!(
            status_for("application/json", &["application/strings"]).await,
            StatusCode::NOT_ACCEPTABLE
        );
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        assert_eq!(
            status_for("application/json", &["text/strings"]).await,
            StatusCode::NOT_ACCEPTABLE
        );
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        assert_eq!(
            status_for(
                "application/json",
                &["text/strings", "invalid, application/json"]
            )
            .await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        assert_eq!(
            status_for(
                "application/json",
                &["text/strings, invalid, application/json"]
            )
            .await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        assert_eq!(
            status_for("application/xml", &[value]).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        assert_eq!(
            status_for("text/html", &[value]).await,
            StatusCode::NOT_ACCEPTABLE
        );
    }

    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}