Skip to main content

tower_http/
validate_request.rs

1//! Middleware that validates requests.
2//!
3//! # Example
4//!
//! The [`Accept`] header can be validated with [`ValidateRequestHeaderLayer::accept()`]:
6//!
7//! ```
8//! use tower_http::validate_request::ValidateRequestHeaderLayer;
9//! use http::{Request, Response, StatusCode, header::ACCEPT};
10//! use http_body_util::Full;
11//! use bytes::Bytes;
12//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
13//!
14//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
15//!     Ok(Response::new(Full::default()))
16//! }
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), BoxError> {
20//! let mut service = ServiceBuilder::new()
21//!     // Require the `Accept` header to be `application/json`, `*/*` or `application/*`
22//!     .layer(ValidateRequestHeaderLayer::accept("application/json"))
23//!     .service_fn(handle);
24//!
25//! // Requests with the correct value are allowed through
26//! let request = Request::builder()
27//!     .header(ACCEPT, "application/json")
28//!     .body(Full::default())
29//!     .unwrap();
30//!
31//! let response = service
32//!     .ready()
33//!     .await?
34//!     .call(request)
35//!     .await?;
36//!
37//! assert_eq!(StatusCode::OK, response.status());
38//!
39//! // Requests with an invalid value get a `406 Not Acceptable` response
40//! let request = Request::builder()
41//!     .header(ACCEPT, "text/strings")
42//!     .body(Full::default())
43//!     .unwrap();
44//!
45//! let response = service
46//!     .ready()
47//!     .await?
48//!     .call(request)
49//!     .await?;
50//!
51//! assert_eq!(StatusCode::NOT_ACCEPTABLE, response.status());
52//! # Ok(())
53//! # }
54//! ```
55//!
//! A custom header can be required to carry a specific value with [`ValidateRequestHeaderLayer::has_header_value()`]:
57//!
58//! ```
59//! use tower_http::validate_request::ValidateRequestHeaderLayer;
60//! use http::{Request, Response, StatusCode};
61//! use http_body_util::Full;
62//! use bytes::Bytes;
63//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
64//!
65//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
66//!     Ok(Response::new(Full::default()))
67//! }
68//!
69//! # #[tokio::main]
70//! # async fn main() -> Result<(), BoxError> {
71//! let mut service = ServiceBuilder::new()
72//!     // Require a `X-Custom-Header` header to have the value `random-value-1234567890` or reject with a `403 Forbidden` response
73//!     .layer(ValidateRequestHeaderLayer::has_header_value(
74//!         "x-custom-header",
75//!         "random-value-1234567890",
76//!     ).expect("invalid validate header"))
77//!     .service_fn(handle);
78//!
79//! // Requests with the correct value are allowed through
80//! let request = Request::builder()
81//!     .header("x-custom-header", "random-value-1234567890")
82//!     .body(Full::default())
83//!     .unwrap();
84//!
85//! let response = service
86//!     .ready()
87//!     .await?
88//!     .call(request)
89//!     .await?;
90//!
91//! assert_eq!(StatusCode::OK, response.status());
92//!
93//! // Requests with an invalid value get a `403 Forbidden` response
94//! let request = Request::builder()
95//!     .header("x-custom-header", "wrong-value")
96//!     .body(Full::default())
97//!     .unwrap();
98//!
99//! let response = service
100//!     .ready()
101//!     .await?
102//!     .call(request)
103//!     .await?;
104//!
105//! assert_eq!(StatusCode::FORBIDDEN, response.status());
106//! # Ok(())
107//! # }
108//! ```
109//!
110//! To require only that a header is present, use [`ValidateRequestHeaderLayer::custom()`]:
111//!
112//! ```
113//! use tower_http::validate_request::ValidateRequestHeaderLayer;
114//! use http::{Request, Response, StatusCode};
115//! use http_body_util::Full;
116//! use bytes::Bytes;
117//! use tower::{ServiceBuilder, service_fn, BoxError};
118//!
119//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
120//!     Ok(Response::new(Full::default()))
121//! }
122//!
123//! # fn main() {
124//! let service = ServiceBuilder::new()
125//!     .layer(ValidateRequestHeaderLayer::custom(|req: &mut Request<Full<Bytes>>| {
126//!         if req.headers().contains_key("x-custom-header") {
127//!             Ok(())
128//!         } else {
129//!             let mut res = Response::new(Full::<Bytes>::default());
130//!             *res.status_mut() = StatusCode::FORBIDDEN;
131//!             Err(res)
132//!         }
133//!     }))
134//!     .service_fn(handle);
135//! # }
136//! ```
137//!
138//! To serve a custom response when validation fails, also use [`ValidateRequestHeaderLayer::custom()`]:
139//!
140//! ```
141//! use tower_http::validate_request::ValidateRequestHeaderLayer;
142//! use http::{Request, Response, StatusCode};
143//! use http_body_util::Full;
144//! use bytes::Bytes;
145//! use tower::{ServiceBuilder, service_fn, BoxError};
146//!
147//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
148//!     Ok(Response::new(Full::default()))
149//! }
150//!
151//! # fn main() {
152//! let service = ServiceBuilder::new()
153//!     .layer(ValidateRequestHeaderLayer::custom(|req: &mut Request<Full<Bytes>>| {
154//!         match req.headers().get("x-custom-header").map(|v| v.as_bytes()) {
155//!             Some(b"random-value-1234567890") => Ok(()),
156//!             _ => Err(Response::builder()
157//!                 .status(StatusCode::FORBIDDEN)
158//!                 .body(Full::<Bytes>::default())
159//!                 .unwrap()),
160//!         }
161//!     }))
162//!     .service_fn(handle);
163//! # }
164//! ```
165//!
166//! Custom validation can be made by implementing [`ValidateRequest`]:
167//!
168//! ```
169//! use tower_http::validate_request::{ValidateRequestHeaderLayer, ValidateRequest};
170//! use http::{Request, Response, StatusCode, header::ACCEPT};
171//! use http_body_util::Full;
172//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
173//! use bytes::Bytes;
174//!
175//! #[derive(Clone, Copy)]
176//! pub struct MyHeader { /* ...  */ }
177//!
178//! impl<B> ValidateRequest<B> for MyHeader {
179//!     type ResponseBody = Full<Bytes>;
180//!
181//!     fn validate(
182//!         &mut self,
183//!         request: &mut Request<B>,
184//!     ) -> Result<(), Response<Self::ResponseBody>> {
185//!         // validate the request...
186//!         # unimplemented!()
187//!     }
188//! }
189//!
190//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
191//!     Ok(Response::new(Full::default()))
192//! }
193//!
194//!
195//! # #[tokio::main]
196//! # async fn main() -> Result<(), BoxError> {
197//! let service = ServiceBuilder::new()
198//!     // Validate requests using `MyHeader`
199//!     .layer(ValidateRequestHeaderLayer::custom(MyHeader { /* ... */ }))
200//!     .service_fn(handle);
201//! # Ok(())
202//! # }
203//! ```
204//!
205//! [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
206
207use http::header::InvalidHeaderName;
208use http::{header, header::HeaderName, Request, Response, StatusCode};
209use mime::{Mime, MimeIter};
210use pin_project_lite::pin_project;
211use std::{
212    fmt,
213    future::Future,
214    marker::PhantomData,
215    pin::Pin,
216    sync::Arc,
217    task::{Context, Poll},
218};
219use tower_layer::Layer;
220use tower_service::Service;
221
/// Layer that applies [`ValidateRequestHeader`] which validates all requests.
///
/// The validator `T` is cloned into every service this layer produces.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeaderLayer<T> {
    // Validator handed (as a clone) to each wrapped service in `Layer::layer`.
    validate: T,
}
229
230impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
231    /// Validate requests have the required Accept header.
232    ///
233    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
234    /// as configured.
235    ///
236    /// # Panics
237    ///
238    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
239    /// See `AcceptHeader::new` for when this method panics.
240    ///
241    /// # Example
242    ///
243    /// ```
244    /// use http_body_util::Full;
245    /// use bytes::Bytes;
246    /// use tower_http::validate_request::{AcceptHeader, ValidateRequestHeaderLayer};
247    ///
248    /// let layer = ValidateRequestHeaderLayer::<AcceptHeader<Full<Bytes>>>::accept("application/json");
249    /// ```
250    ///
251    /// [`Accept`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept
252    pub fn accept(value: &str) -> Self
253    where
254        ResBody: Default,
255    {
256        Self::custom(AcceptHeader::new(value))
257    }
258}
259
260impl<ResBody> ValidateRequestHeaderLayer<RequiredHeaderValue<ResBody>> {
261    /// Validate requests have a required header with a specific value.
262    ///
263    /// Rejects with `403 Forbidden` if the header is missing or does not have the expected value.
264    /// Header values that are not valid UTF-8 are treated as non-matching.
265    ///
266    /// If the request contains multiple values for the header, only the first occurrence is
267    /// checked.
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if `expected_header_name` is not a valid HTTP header name per RFC 7230
272    /// (non-empty, at most 32,768 bytes, containing only valid token characters).
273    ///
274    /// # Example
275    ///
276    /// ```
277    /// use http::{Request, Response, StatusCode};
278    /// use http_body_util::Full;
279    /// use bytes::Bytes;
280    /// use tower::{Service, ServiceBuilder, ServiceExt, service_fn};
281    /// use tower_http::validate_request::ValidateRequestHeaderLayer;
282    ///
283    /// async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, std::convert::Infallible> {
284    ///     Ok(Response::new(request.into_body()))
285    /// }
286    ///
287    /// # #[tokio::main]
288    /// # async fn main() {
289    /// let mut service = ServiceBuilder::new()
290    ///     .layer(ValidateRequestHeaderLayer::has_header_value(
291    ///         "x-custom-header",
292    ///         "random-value-1234567890",
293    ///     ).expect("invalid validate header"))
294    ///     .service_fn(handle);
295    ///
296    /// let request = Request::builder()
297    ///     .header("x-custom-header", "random-value-1234567890")
298    ///     .body(Full::default())
299    ///     .unwrap();
300    ///
301    /// let response = service.ready().await.unwrap().call(request).await.unwrap();
302    /// assert_eq!(response.status(), StatusCode::OK);
303    /// # }
304    /// ```
305    pub fn has_header_value(
306        expected_header_name: &str,
307        expected_header_value: &str,
308    ) -> Result<Self, InvalidHeaderName>
309    where
310        ResBody: Default,
311    {
312        Ok(Self::custom(RequiredHeaderValue::new(
313            expected_header_name.parse::<HeaderName>()?,
314            expected_header_value,
315        )))
316    }
317}
318
319impl<T> ValidateRequestHeaderLayer<T> {
320    /// Validate requests using a custom method.
321    pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
322        Self { validate }
323    }
324}
325
326impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
327where
328    T: Clone,
329{
330    type Service = ValidateRequestHeader<S, T>;
331
332    fn layer(&self, inner: S) -> Self::Service {
333        ValidateRequestHeader::new(inner, self.validate.clone())
334    }
335}
336
/// Middleware that validates requests.
///
/// Every request is passed to the validator `T` before it reaches the inner
/// service `S`; requests the validator rejects never reach `S`.
///
/// See the [module docs](crate::validate_request) for an example.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeader<S, T> {
    // Wrapped service that receives requests which passed validation.
    inner: S,
    // Validator consulted on every `call`.
    validate: T,
}
345
346impl<S, T> ValidateRequestHeader<S, T> {
347    fn new(inner: S, validate: T) -> Self {
348        Self::custom(inner, validate)
349    }
350
351    define_inner_service_accessors!();
352}
353
354impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
355    /// Validate requests have the required Accept header.
356    ///
357    /// The `Accept` header is required to be `*/*`, `type/*` or `type/subtype`,
358    /// as configured.
359    ///
360    /// # Panics
361    ///
362    /// See `AcceptHeader::new` for when this method panics.
363    pub fn accept(inner: S, value: &str) -> Self
364    where
365        ResBody: Default,
366    {
367        Self::custom(inner, AcceptHeader::new(value))
368    }
369}
370
371impl<S, T> ValidateRequestHeader<S, T> {
372    /// Validate requests using a custom method.
373    pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
374        Self { inner, validate }
375    }
376}
377
378impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
379where
380    V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
381    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
382{
383    type Response = Response<ResBody>;
384    type Error = S::Error;
385    type Future = ResponseFuture<S::Future, ResBody>;
386
387    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
388        self.inner.poll_ready(cx)
389    }
390
391    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
392        match self.validate.validate(&mut req) {
393            Ok(_) => ResponseFuture::future(self.inner.call(req)),
394            Err(res) => ResponseFuture::invalid_header_value(res),
395        }
396    }
397}
398
pin_project! {
    /// Response future for [`ValidateRequestHeader`].
    pub struct ResponseFuture<F, B> {
        // Either the inner service's future or an already-built rejection response.
        #[pin]
        kind: Kind<F, B>,
    }
}
406
407impl<F, B> ResponseFuture<F, B> {
408    fn future(future: F) -> Self {
409        Self {
410            kind: Kind::Future { future },
411        }
412    }
413
414    fn invalid_header_value(res: Response<B>) -> Self {
415        Self {
416            kind: Kind::Error {
417                response: Some(res),
418            },
419        }
420    }
421}
422
pin_project! {
    // Internal state of a `ResponseFuture`.
    #[project = KindProj]
    enum Kind<F, B> {
        // The request passed validation; the inner service's future is polled.
        Future {
            #[pin]
            future: F,
        },
        // The request was rejected; the response is taken out on the first poll,
        // leaving `None` behind.
        Error {
            response: Option<Response<B>>,
        },
    }
}
435
436impl<F, B, E> Future for ResponseFuture<F, B>
437where
438    F: Future<Output = Result<Response<B>, E>>,
439{
440    type Output = F::Output;
441
442    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
443        match self.project().kind.project() {
444            KindProj::Future { future } => future.poll(cx),
445            KindProj::Error { response } => {
446                let response = response.take().expect("future polled after completion");
447                Poll::Ready(Ok(response))
448            }
449        }
450    }
451}
452
/// Trait for validating requests.
///
/// Implemented for closures of the form `FnMut(&mut Request<B>) -> Result<(), Response<_>>`
/// as well as for the validators provided by this module.
pub trait ValidateRequest<B> {
    /// The body type used for responses to unvalidated requests.
    type ResponseBody;

    /// Validate the request.
    ///
    /// If `Ok(())` is returned then the request is allowed through, otherwise not.
    /// On `Err`, the contained response is sent back instead of calling the inner
    /// service. The request is borrowed mutably, so a validator may also modify it
    /// before it is forwarded.
    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
}
463
464impl<B, F, ResBody> ValidateRequest<B> for F
465where
466    F: FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>,
467{
468    type ResponseBody = ResBody;
469
470    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
471        self(request)
472    }
473}
474
/// Type that performs validation of the Accept header.
///
/// Constructed by [`ValidateRequestHeaderLayer::accept`] and [`ValidateRequestHeader::accept`].
pub struct AcceptHeader<ResBody> {
    // Media type that requests must accept; wrapped in `Arc` so clones share it.
    header_value: Arc<Mime>,
    // Carries the rejection body type without storing a value of it.
    _ty: PhantomData<fn() -> ResBody>,
}
480
481impl<ResBody> AcceptHeader<ResBody> {
482    /// Create a new `AcceptHeader`.
483    ///
484    /// # Panics
485    ///
486    /// Panics if `header_value` is not in the form: `type/subtype`, such as `application/json`
487    fn new(header_value: &str) -> Self
488    where
489        ResBody: Default,
490    {
491        Self {
492            header_value: Arc::new(
493                header_value
494                    .parse::<Mime>()
495                    .expect("value is not a valid header value"),
496            ),
497            _ty: PhantomData,
498        }
499    }
500}
501
502impl<ResBody> Clone for AcceptHeader<ResBody> {
503    fn clone(&self) -> Self {
504        Self {
505            header_value: self.header_value.clone(),
506            _ty: PhantomData,
507        }
508    }
509}
510
511impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
512    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513        f.debug_struct("AcceptHeader")
514            .field("header_value", &self.header_value)
515            .finish()
516    }
517}
518
519impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
520where
521    ResBody: Default,
522{
523    type ResponseBody = ResBody;
524
525    fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
526        if !req.headers().contains_key(header::ACCEPT) {
527            return Ok(());
528        }
529        if req
530            .headers()
531            .get_all(header::ACCEPT)
532            .into_iter()
533            .filter_map(|header| header.to_str().ok())
534            .any(|h| {
535                MimeIter::new(h)
536                    .map(|mim| {
537                        if let Ok(mim) = mim {
538                            let typ = self.header_value.type_();
539                            let subtype = self.header_value.subtype();
540                            match (mim.type_(), mim.subtype()) {
541                                (t, s) if t == typ && s == subtype => true,
542                                (t, mime::STAR) if t == typ => true,
543                                (mime::STAR, mime::STAR) => true,
544                                _ => false,
545                            }
546                        } else {
547                            false
548                        }
549                    })
550                    .reduce(|acc, mim| acc || mim)
551                    .unwrap_or(false)
552            })
553        {
554            return Ok(());
555        }
556        let mut res = Response::new(ResBody::default());
557        *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
558        Err(res)
559    }
560}
561
/// Type that rejects requests if a header is not present or does not have an expected value.
///
/// Constructed by [`ValidateRequestHeaderLayer::has_header_value`].
pub struct RequiredHeaderValue<ResBody> {
    // Name of the header that must be present.
    expected_header_name: HeaderName,
    // Value the header must carry; `Arc<str>` so clones share one allocation.
    expected_header_value: Arc<str>,
    // Carries the rejection body type without storing a value of it.
    _ty: PhantomData<fn() -> ResBody>,
}
568
569impl<ResBody> RequiredHeaderValue<ResBody> {
570    fn new(expected_header_name: HeaderName, expected_header_value: &str) -> Self
571    where
572        ResBody: Default,
573    {
574        Self {
575            expected_header_name,
576            expected_header_value: expected_header_value.into(),
577            _ty: PhantomData,
578        }
579    }
580}
581
582impl<ResBody> Clone for RequiredHeaderValue<ResBody> {
583    fn clone(&self) -> Self {
584        Self {
585            expected_header_name: self.expected_header_name.clone(),
586            expected_header_value: self.expected_header_value.clone(),
587            _ty: PhantomData,
588        }
589    }
590}
591
592impl<ResBody> fmt::Debug for RequiredHeaderValue<ResBody> {
593    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
594        f.debug_struct("RequiredHeaderValue")
595            .field("expected_header_name", &self.expected_header_name)
596            .field("expected_header_value", &self.expected_header_value)
597            .finish()
598    }
599}
600
601impl<B, ResBody> ValidateRequest<B> for RequiredHeaderValue<ResBody>
602where
603    ResBody: Default,
604{
605    type ResponseBody = ResBody;
606
607    fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
608        let request_header_value = req
609            .headers()
610            .get(&self.expected_header_name)
611            .and_then(|v| v.to_str().ok());
612
613        if request_header_value != Some(&*self.expected_header_value) {
614            let mut res = Response::new(ResBody::default());
615            *res.status_mut() = StatusCode::FORBIDDEN;
616            return Err(res);
617        }
618
619        Ok(())
620    }
621}
622
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Sends a `GET /` request with one `Accept` header per entry of `accept_values`
    /// through a service that requires `required`, and returns the response status.
    async fn accept_status(required: &str, accept_values: &[&str]) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept(required))
            .service_fn(echo);

        let request = accept_values
            .iter()
            .fold(Request::get("/"), |builder, value| {
                builder.header(header::ACCEPT, *value)
            })
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        res.status()
    }

    /// Sends `request` through a service that requires `x-custom-header` to equal
    /// `expected`, and returns the response status.
    async fn custom_header_status(expected: &str, request: Request<Body>) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(
                ValidateRequestHeaderLayer::has_header_value("x-custom-header", expected)
                    .expect("invalid validate header"),
            )
            .service_fn(echo);

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        res.status()
    }

    /// Builds a `GET /` request with an empty body and the given headers, appended in order.
    fn get_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(Request::get("/"), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = accept_status("application/json", &["application/json"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = accept_status("application/json", &["application/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = accept_status("application/json", &["*/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = accept_status("application/json", &["invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = accept_status("application/json", &["application/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = accept_status("application/json", &["text/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let status = accept_status(
            "application/json",
            &["text/strings", "invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let status = accept_status(
            "application/json",
            &["text/strings, invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = accept_status("application/xml", &[value]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = accept_status("text/html", &[value]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn valid_custom_header() {
        let request = get_with_headers(&[("x-custom-header", "random-value-1234567890")]);
        let status = custom_header_status("random-value-1234567890", request).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_custom_header() {
        let request = get_with_headers(&[("x-custom-header", "wrong-value")]);
        let status = custom_header_status("random-value-1234567890", request).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn missing_custom_header() {
        let request = get_with_headers(&[]);
        let status = custom_header_status("random-value-1234567890", request).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn custom_header_multiple_values_uses_first() {
        // First value matches: should pass
        let request = get_with_headers(&[
            ("x-custom-header", "correct-value"),
            ("x-custom-header", "other-value"),
        ]);
        assert_eq!(
            custom_header_status("correct-value", request).await,
            StatusCode::OK
        );

        // First value does not match: should reject even if second matches
        let request = get_with_headers(&[
            ("x-custom-header", "wrong-value"),
            ("x-custom-header", "correct-value"),
        ]);
        assert_eq!(
            custom_header_status("correct-value", request).await,
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn invalid_header_name_returns_error() {
        let result = ValidateRequestHeaderLayer::<RequiredHeaderValue<Body>>::has_header_value(
            "invalid header name with spaces",
            "value",
        );
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn custom_header_non_utf8_value_rejects() {
        let request = Request::get("/")
            .header("x-custom-header", b"\xff\xfe".as_slice())
            .body(Body::empty())
            .unwrap();
        let status = custom_header_status("expected-value", request).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn custom_header_name_is_case_insensitive() {
        let request = get_with_headers(&[("X-Custom-Header", "my-value")]);
        let status = custom_header_status("my-value", request).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_header_value_is_case_sensitive() {
        let request = get_with_headers(&[("x-custom-header", "my-value")]);
        let status = custom_header_status("My-Value", request).await;
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}