s2_common/
http.rs

1/// An HTTP header that can be parsed from a UTF8 string value.
2pub trait ParseableHeader: std::str::FromStr
3where
4    Self::Err: std::fmt::Display,
5{
6    fn name() -> &'static http::HeaderName;
7}
8
/// Axum extractors for headers implementing [`ParseableHeader`](super::ParseableHeader).
#[cfg(feature = "axum")]
pub mod extract {
    use axum::{
        extract::{FromRequestParts, OptionalFromRequestParts},
        response::{IntoResponse, Response},
    };

    /// Reasons a typed header could not be extracted from a request.
    ///
    /// Converted into a `400 Bad Request` response whose body is this error's
    /// [`Display`](std::fmt::Display) text.
    #[derive(Debug, thiserror::Error)]
    pub enum HeaderRejection {
        /// The header is absent from the request.
        #[error("Missing header `{0}`")]
        MissingHeader(&'static http::HeaderName),
        /// The header value is not valid UTF-8.
        #[error("Invalid header `{0}`: not UTF-8")]
        InvalidUtf8(&'static http::HeaderName),
        /// The header value is UTF-8 but failed to parse; carries the parse error's text.
        #[error("Invalid header `{0}`: {1}")]
        InvalidHeaderValue(&'static http::HeaderName, String),
    }

    impl IntoResponse for HeaderRejection {
        fn into_response(self) -> Response {
            (http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
        }
    }

    /// Parses header `T` from `headers`.
    ///
    /// If the header occurs more than once, only the first value (as returned by
    /// `HeaderMap::get`) is considered.
    ///
    /// # Errors
    ///
    /// - [`HeaderRejection::MissingHeader`] if the header is absent.
    /// - [`HeaderRejection::InvalidUtf8`] if its value is not valid UTF-8.
    /// - [`HeaderRejection::InvalidHeaderValue`] if `T`'s `FromStr` implementation fails.
    pub fn parse_header<T>(headers: &http::HeaderMap) -> Result<T, HeaderRejection>
    where
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        let name = T::name();
        let value = headers
            .get(name)
            .ok_or(HeaderRejection::MissingHeader(name))?;
        // `HeaderValue::to_str` only accepts visible ASCII, so it would reject valid non-ASCII
        // UTF-8 and mislabel it "not UTF-8". Decode the raw bytes instead so the check matches
        // the UTF-8 contract of `ParseableHeader` and `InvalidUtf8`.
        let text = std::str::from_utf8(value.as_bytes())
            .map_err(|_| HeaderRejection::InvalidUtf8(name))?;
        text.parse::<T>()
            .map_err(|e| HeaderRejection::InvalidHeaderValue(name, e.to_string()))
    }

    /// Like [`parse_header`], but maps an absent header to `Ok(None)`.
    ///
    /// A present but invalid header is still an error.
    fn parse_optional_header<T>(headers: &http::HeaderMap) -> Result<Option<T>, HeaderRejection>
    where
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        match parse_header(headers) {
            Ok(value) => Ok(Some(value)),
            Err(HeaderRejection::MissingHeader(_)) => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Extractor requiring header `T` to be present and parseable.
    ///
    /// When extracted as `Option<Header<T>>`, an absent header yields `None`, while a present
    /// but invalid header still rejects the request.
    #[derive(Debug, Clone)]
    pub struct Header<T>(pub T);

    impl<S, T> FromRequestParts<S> for Header<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parse_header(&parts.headers).map(Self)
        }
    }

    impl<S, T> OptionalFromRequestParts<S> for Header<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Option<Self>, Self::Rejection> {
            parse_optional_header(&parts.headers).map(|value| value.map(Header))
        }
    }

    /// Extractor yielding `Some(T)` when header `T` is present and `None` when it is absent.
    /// A present but invalid header still rejects the request.
    ///
    /// Workaround for <https://github.com/tokio-rs/axum/issues/3623>.
    #[derive(Debug, Clone)]
    pub struct HeaderOpt<T>(pub Option<T>);

    impl<S, T> FromRequestParts<S> for HeaderOpt<T>
    where
        S: Send + Sync,
        T: super::ParseableHeader,
        T::Err: std::fmt::Display,
    {
        type Rejection = HeaderRejection;

        async fn from_request_parts(
            parts: &mut http::request::Parts,
            _state: &S,
        ) -> Result<Self, Self::Rejection> {
            parse_optional_header(&parts.headers).map(Self)
        }
    }
}
109        }
110    }
111}