1pub trait ParseableHeader: std::str::FromStr
3where
4 Self::Err: std::fmt::Display,
5{
6 fn name() -> &'static http::HeaderName;
7}
8
/// [`axum`] extractors for [`ParseableHeader`](super::ParseableHeader) types.
///
/// * [`Header<T>`] requires the header and rejects the request when it is absent.
/// * `Option<Header<T>>` and [`HeaderOpt<T>`] treat an absent header as `None`,
///   but still reject a header that is present and malformed.
#[cfg(feature = "axum")]
pub mod extract {
    use axum::{
        extract::{FromRequestParts, OptionalFromRequestParts},
        response::{IntoResponse, Response},
    };

    /// Reason a typed header could not be extracted from a request.
    #[derive(Debug, thiserror::Error)]
    pub enum HeaderRejection {
        /// The header is not present in the request.
        #[error("Missing header `{0}`")]
        MissingHeader(&'static http::HeaderName),
        /// The header value's bytes are not valid UTF-8.
        #[error("Invalid header `{0}`: not UTF-8")]
        InvalidUtf8(&'static http::HeaderName),
        /// The header value is text but `FromStr` rejected it; carries the
        /// rendered parse error.
        #[error("Invalid header `{0}`: {1}")]
        InvalidHeaderValue(&'static http::HeaderName, String),
    }

    impl IntoResponse for HeaderRejection {
        /// Every rejection is answered with `400 Bad Request` and the error's
        /// `Display` text as the body.
        fn into_response(self) -> Response {
            (http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
        }
    }

    /// Looks up `T::name()` in `headers` and parses its value into `T`.
    ///
    /// Only the first value is considered when the header is repeated
    /// (`HeaderMap::get` semantics).
    ///
    /// # Errors
    ///
    /// * [`HeaderRejection::MissingHeader`] if the header is absent.
    /// * [`HeaderRejection::InvalidUtf8`] if the value's bytes are not UTF-8.
    /// * [`HeaderRejection::InvalidHeaderValue`] if `T::from_str` fails.
    pub fn parse_header<T>(headers: &http::HeaderMap) -> Result<T, HeaderRejection>
    where
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        let name = T::name();
        let value = headers
            .get(name)
            .ok_or(HeaderRejection::MissingHeader(name))?;
        // `HeaderValue::to_str` only accepts visible ASCII, so it would report
        // perfectly valid UTF-8 (e.g. non-ASCII text) as "not UTF-8". Decode the
        // raw bytes instead so the check matches the error it produces.
        let value_str = std::str::from_utf8(value.as_bytes())
            .map_err(|_| HeaderRejection::InvalidUtf8(name))?;
        value_str
            .parse::<T>()
            .map_err(|e| HeaderRejection::InvalidHeaderValue(name, e.to_string()))
    }

    /// Like [`parse_header`], but maps an absent header to `Ok(None)`.
    ///
    /// A header that is present but malformed is still an error.
    fn parse_optional_header<T>(
        headers: &http::HeaderMap,
    ) -> Result<Option<T>, HeaderRejection>
    where
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        match parse_header(headers) {
            Ok(value) => Ok(Some(value)),
            Err(HeaderRejection::MissingHeader(_)) => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Extractor for a required typed header.
    ///
    /// Rejects with [`HeaderRejection::MissingHeader`] when the header is absent;
    /// use `Option<Header<T>>` or [`HeaderOpt<T>`] to make it optional.
    #[derive(Debug, Clone)]
    pub struct Header<T>(pub T);

    impl<S, T> FromRequestParts<S> for Header<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parse_header(&parts.headers).map(Self)
        }
    }

    impl<S, T> OptionalFromRequestParts<S> for Header<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Option<Self>, Self::Rejection> {
            parse_optional_header(&parts.headers).map(|value| value.map(Self))
        }
    }

    /// Extractor for an optional typed header.
    ///
    /// Behaves like `Option<Header<T>>`: an absent header yields `HeaderOpt(None)`,
    /// while a present but malformed header still rejects the request.
    #[derive(Debug, Clone)]
    pub struct HeaderOpt<T>(pub Option<T>);

    impl<S, T> FromRequestParts<S> for HeaderOpt<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parse_optional_header(&parts.headers).map(Self)
        }
    }
}