1use http::header::InvalidHeaderName;
208use http::{header, header::HeaderName, Request, Response, StatusCode};
209use mime::{Mime, MimeIter};
210use pin_project_lite::pin_project;
211use std::{
212 fmt,
213 future::Future,
214 marker::PhantomData,
215 pin::Pin,
216 sync::Arc,
217 task::{Context, Poll},
218};
219use tower_layer::Layer;
220use tower_service::Service;
221
/// [`Layer`] that wraps services in [`ValidateRequestHeader`], running the
/// validator `T` on every request before it reaches the inner service.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeaderLayer<T> {
    // The validator; cloned into each service produced by this layer.
    validate: T,
}
229
230impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
231 pub fn accept(value: &str) -> Self
253 where
254 ResBody: Default,
255 {
256 Self::custom(AcceptHeader::new(value))
257 }
258}
259
260impl<ResBody> ValidateRequestHeaderLayer<RequiredHeaderValue<ResBody>> {
261 pub fn has_header_value(
306 expected_header_name: &str,
307 expected_header_value: &str,
308 ) -> Result<Self, InvalidHeaderName>
309 where
310 ResBody: Default,
311 {
312 Ok(Self::custom(RequiredHeaderValue::new(
313 expected_header_name.parse::<HeaderName>()?,
314 expected_header_value,
315 )))
316 }
317}
318
319impl<T> ValidateRequestHeaderLayer<T> {
320 pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
322 Self { validate }
323 }
324}
325
326impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
327where
328 T: Clone,
329{
330 type Service = ValidateRequestHeader<S, T>;
331
332 fn layer(&self, inner: S) -> Self::Service {
333 ValidateRequestHeader::new(inner, self.validate.clone())
334 }
335}
336
/// Middleware that runs a validator on each request and either forwards it to
/// `inner` or short-circuits with the validator's rejection response.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
    // Wrapped service that receives requests which pass validation.
    inner: S,
    // Validator consulted before every call.
    validate: T,
}
345
346impl<S, T> ValidateRequestHeader<S, T> {
347 fn new(inner: S, validate: T) -> Self {
348 Self::custom(inner, validate)
349 }
350
351 define_inner_service_accessors!();
352}
353
354impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
355 pub fn accept(inner: S, value: &str) -> Self
364 where
365 ResBody: Default,
366 {
367 Self::custom(inner, AcceptHeader::new(value))
368 }
369}
370
371impl<S, T> ValidateRequestHeader<S, T> {
372 pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
374 Self { inner, validate }
375 }
376}
377
378impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
379where
380 V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
381 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
382{
383 type Response = Response<ResBody>;
384 type Error = S::Error;
385 type Future = ResponseFuture<S::Future, ResBody>;
386
387 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
388 self.inner.poll_ready(cx)
389 }
390
391 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
392 match self.validate.validate(&mut req) {
393 Ok(_) => ResponseFuture::future(self.inner.call(req)),
394 Err(res) => ResponseFuture::invalid_header_value(res),
395 }
396 }
397}
398
pin_project! {
    /// Response future for [`ValidateRequestHeader`].
    ///
    /// Either drives the inner service's future or, when validation failed,
    /// resolves immediately to the validator's rejection response.
    pub struct ResponseFuture<F, B> {
        // Which of the two cases this future is in; structurally pinned so the
        // inner future can be polled in place.
        #[pin]
        kind: Kind<F, B>,
    }
}
406
407impl<F, B> ResponseFuture<F, B> {
408 fn future(future: F) -> Self {
409 Self {
410 kind: Kind::Future { future },
411 }
412 }
413
414 fn invalid_header_value(res: Response<B>) -> Self {
415 Self {
416 kind: Kind::Error {
417 response: Some(res),
418 },
419 }
420 }
421}
422
pin_project! {
    // Internal state of a `ResponseFuture`; `KindProj` is the pinned projection
    // generated by `pin_project!` and matched on in `ResponseFuture::poll`.
    #[project = KindProj]
    enum Kind<F, B> {
        // Validation passed: poll the inner service's future.
        Future {
            #[pin]
            future: F,
        },
        // Validation failed: the rejection response, handed out once via
        // `Option::take` and `None` afterwards.
        Error {
            response: Option<Response<B>>,
        },
    }
}
435
436impl<F, B, E> Future for ResponseFuture<F, B>
437where
438 F: Future<Output = Result<Response<B>, E>>,
439{
440 type Output = F::Output;
441
442 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
443 match self.project().kind.project() {
444 KindProj::Future { future } => future.poll(cx),
445 KindProj::Error { response } => {
446 let response = response.take().expect("future polled after completion");
447 Poll::Ready(Ok(response))
448 }
449 }
450 }
451}
452
453pub trait ValidateRequest<B> {
455 type ResponseBody;
457
458 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
462}
463
464impl<B, F, ResBody> ValidateRequest<B> for F
465where
466 F: FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>,
467{
468 type ResponseBody = ResBody;
469
470 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
471 self(request)
472 }
473}
474
475pub struct AcceptHeader<ResBody> {
477 header_value: Arc<Mime>,
478 _ty: PhantomData<fn() -> ResBody>,
479}
480
481impl<ResBody> AcceptHeader<ResBody> {
482 fn new(header_value: &str) -> Self
488 where
489 ResBody: Default,
490 {
491 Self {
492 header_value: Arc::new(
493 header_value
494 .parse::<Mime>()
495 .expect("value is not a valid header value"),
496 ),
497 _ty: PhantomData,
498 }
499 }
500}
501
502impl<ResBody> Clone for AcceptHeader<ResBody> {
503 fn clone(&self) -> Self {
504 Self {
505 header_value: self.header_value.clone(),
506 _ty: PhantomData,
507 }
508 }
509}
510
511impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
512 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
513 f.debug_struct("AcceptHeader")
514 .field("header_value", &self.header_value)
515 .finish()
516 }
517}
518
519impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
520where
521 ResBody: Default,
522{
523 type ResponseBody = ResBody;
524
525 fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
526 if !req.headers().contains_key(header::ACCEPT) {
527 return Ok(());
528 }
529 if req
530 .headers()
531 .get_all(header::ACCEPT)
532 .into_iter()
533 .filter_map(|header| header.to_str().ok())
534 .any(|h| {
535 MimeIter::new(h)
536 .map(|mim| {
537 if let Ok(mim) = mim {
538 let typ = self.header_value.type_();
539 let subtype = self.header_value.subtype();
540 match (mim.type_(), mim.subtype()) {
541 (t, s) if t == typ && s == subtype => true,
542 (t, mime::STAR) if t == typ => true,
543 (mime::STAR, mime::STAR) => true,
544 _ => false,
545 }
546 } else {
547 false
548 }
549 })
550 .reduce(|acc, mim| acc || mim)
551 .unwrap_or(false)
552 })
553 {
554 return Ok(());
555 }
556 let mut res = Response::new(ResBody::default());
557 *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
558 Err(res)
559 }
560}
561
562pub struct RequiredHeaderValue<ResBody> {
564 expected_header_name: HeaderName,
565 expected_header_value: Arc<str>,
566 _ty: PhantomData<fn() -> ResBody>,
567}
568
569impl<ResBody> RequiredHeaderValue<ResBody> {
570 fn new(expected_header_name: HeaderName, expected_header_value: &str) -> Self
571 where
572 ResBody: Default,
573 {
574 Self {
575 expected_header_name,
576 expected_header_value: expected_header_value.into(),
577 _ty: PhantomData,
578 }
579 }
580}
581
582impl<ResBody> Clone for RequiredHeaderValue<ResBody> {
583 fn clone(&self) -> Self {
584 Self {
585 expected_header_name: self.expected_header_name.clone(),
586 expected_header_value: self.expected_header_value.clone(),
587 _ty: PhantomData,
588 }
589 }
590}
591
592impl<ResBody> fmt::Debug for RequiredHeaderValue<ResBody> {
593 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
594 f.debug_struct("RequiredHeaderValue")
595 .field("expected_header_name", &self.expected_header_name)
596 .field("expected_header_value", &self.expected_header_value)
597 .finish()
598 }
599}
600
601impl<B, ResBody> ValidateRequest<B> for RequiredHeaderValue<ResBody>
602where
603 ResBody: Default,
604{
605 type ResponseBody = ResBody;
606
607 fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
608 let request_header_value = req
609 .headers()
610 .get(&self.expected_header_name)
611 .and_then(|v| v.to_str().ok());
612
613 if request_header_value != Some(&*self.expected_header_value) {
614 let mut res = Response::new(ResBody::default());
615 *res.status_mut() = StatusCode::FORBIDDEN;
616 return Err(res);
617 }
618
619 Ok(())
620 }
621}
622
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Echo service guarded by `ValidateRequestHeaderLayer::accept(mime)`.
    fn accept_service(
        mime: &str,
    ) -> impl Service<Request<Body>, Response = Response<Body>, Error = BoxError> {
        ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::<AcceptHeader<Body>>::accept(mime))
            .service_fn(echo)
    }

    /// Echo service guarded by `has_header_value("x-custom-header", expected)`.
    fn custom_header_service(
        expected: &str,
    ) -> impl Service<Request<Body>, Response = Response<Body>, Error = BoxError> {
        let layer = ValidateRequestHeaderLayer::<RequiredHeaderValue<Body>>::has_header_value(
            "x-custom-header",
            expected,
        )
        .expect("invalid validate header");
        ServiceBuilder::new().layer(layer).service_fn(echo)
    }

    /// `GET /` carrying a single `x-custom-header: value`.
    fn custom_header_request(value: &str) -> Request<Body> {
        Request::get("/")
            .header("x-custom-header", value)
            .body(Body::empty())
            .unwrap()
    }

    /// Waits for readiness, sends `request`, and returns the response status.
    async fn status_of<S>(service: &mut S, request: Request<Body>) -> StatusCode
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = BoxError>,
    {
        service.ready().await.unwrap().call(request).await.unwrap().status()
    }

    /// Status of a `GET /` with one `Accept: accept` header sent to `accept(expected_mime)`.
    async fn accept_status(expected_mime: &str, accept: &str) -> StatusCode {
        let mut service = accept_service(expected_mime);
        let request = Request::get("/")
            .header(header::ACCEPT, accept)
            .body(Body::empty())
            .unwrap();
        status_of(&mut service, request).await
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = accept_status("application/json", "application/json").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = accept_status("application/json", "application/*").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = accept_status("application/json", "*/*").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = accept_status("application/json", "invalid").await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = accept_status("application/json", "application/strings").await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = accept_status("application/json", "text/strings").await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let mut service = accept_service("application/json");
        // The match is in the second `Accept` header, after an unparsable entry.
        let request = Request::get("/")
            .header(header::ACCEPT, "text/strings")
            .header(header::ACCEPT, "invalid, application/json")
            .body(Body::empty())
            .unwrap();

        assert_eq!(status_of(&mut service, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let status =
            accept_status("application/json", "text/strings, invalid, application/json").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        // The comma inside the quoted parameter must not split the media range.
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = accept_status("application/xml", value).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        // `text/html` only appears inside a quoted parameter value, so it must not match.
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = accept_status("text/html", value).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn valid_custom_header() {
        let mut service = custom_header_service("random-value-1234567890");
        let request = custom_header_request("random-value-1234567890");

        assert_eq!(status_of(&mut service, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_custom_header() {
        let mut service = custom_header_service("random-value-1234567890");
        let request = custom_header_request("wrong-value");

        assert_eq!(status_of(&mut service, request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn missing_custom_header() {
        let mut service = custom_header_service("random-value-1234567890");
        let request = Request::get("/").body(Body::empty()).unwrap();

        assert_eq!(status_of(&mut service, request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn custom_header_multiple_values_uses_first() {
        let mut service = custom_header_service("correct-value");

        // First value matches: accepted even though a later value differs.
        let first_matches = Request::get("/")
            .header("x-custom-header", "correct-value")
            .header("x-custom-header", "other-value")
            .body(Body::empty())
            .unwrap();
        assert_eq!(status_of(&mut service, first_matches).await, StatusCode::OK);

        // First value is wrong: rejected even though a later value matches.
        let later_matches = Request::get("/")
            .header("x-custom-header", "wrong-value")
            .header("x-custom-header", "correct-value")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            status_of(&mut service, later_matches).await,
            StatusCode::FORBIDDEN
        );
    }

    #[test]
    fn invalid_header_name_returns_error() {
        let result = ValidateRequestHeaderLayer::<RequiredHeaderValue<Body>>::has_header_value(
            "invalid header name with spaces",
            "value",
        );
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn custom_header_non_utf8_value_rejects() {
        let mut service = custom_header_service("expected-value");
        let request = Request::get("/")
            .header("x-custom-header", b"\xff\xfe".as_slice())
            .body(Body::empty())
            .unwrap();

        assert_eq!(status_of(&mut service, request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn custom_header_name_is_case_insensitive() {
        let mut service = custom_header_service("my-value");
        let request = Request::get("/")
            .header("X-Custom-Header", "my-value")
            .body(Body::empty())
            .unwrap();

        assert_eq!(status_of(&mut service, request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_header_value_is_case_sensitive() {
        let mut service = custom_header_service("My-Value");
        let request = custom_header_request("my-value");

        assert_eq!(status_of(&mut service, request).await, StatusCode::FORBIDDEN);
    }

    /// Inner service: returns the request body unchanged with `200 OK`.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}