switchgear_service/axum/extract/scheme.rs

1use axum::extract::FromRef;
2use axum::http::header;
3use axum::{extract::FromRequestParts, http::request::Parts};
4use std::convert::Infallible;
5
/// The URI scheme of a request, such as `http` or `https`, as an owned string.
///
/// Usable as an axum extractor. The value is taken from proxy headers when present, and
/// otherwise from the application state via `FromRef`.
#[derive(Clone, Debug)]
pub struct Scheme(pub String);
8
9impl<S> FromRequestParts<S> for Scheme
10where
11    S: Send + Sync,
12    Scheme: FromRef<S>,
13{
14    type Rejection = Infallible;
15
16    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
17        if let Some(proto) = parts
18            .headers
19            .get(header::FORWARDED)
20            .and_then(parse_forwarded_proto)
21        {
22            return Ok(Scheme(proto));
23        }
24        if let Some(proto) = parts
25            .headers
26            .get("x-forwarded-proto")
27            .and_then(|value| value.to_str().ok())
28        {
29            return Ok(Scheme(proto.to_string()));
30        }
31
32        Ok(Scheme::from_ref(state))
33    }
34}
35
36fn parse_forwarded_proto(forwarded: &header::HeaderValue) -> Option<String> {
37    forwarded.to_str().ok()?.split(';').find_map(|s| {
38        let s = s.trim().to_lowercase();
39        if s.starts_with("proto=") {
40            s.split('=').next_back().map(|c| c.to_string())
41        } else {
42            None
43        }
44    })
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::FromRef;
    use axum::http::{HeaderMap, HeaderValue, Request};

    /// Stand-in application state that only carries the fallback scheme.
    #[derive(Clone)]
    struct TestState {
        default_scheme: Scheme,
    }

    impl FromRef<TestState> for Scheme {
        fn from_ref(state: &TestState) -> Self {
            state.default_scheme.clone()
        }
    }

    /// Builds a [`TestState`] whose fallback scheme is `scheme`.
    fn state_defaulting_to(scheme: &str) -> TestState {
        TestState {
            default_scheme: Scheme(scheme.to_string()),
        }
    }

    /// Runs the [`Scheme`] extractor on a bare request to `/` carrying `headers`.
    async fn extract_scheme(headers: HeaderMap, state: TestState) -> Scheme {
        let (mut parts, ()) = Request::builder().uri("/").body(()).unwrap().into_parts();
        parts.headers = headers;
        Scheme::from_request_parts(&mut parts, &state).await.unwrap()
    }

    #[tokio::test]
    async fn test_forwarded_header_with_proto() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("proto=https;host=example.com"),
        );

        let scheme = extract_scheme(headers, state_defaulting_to("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_x_forwarded_proto_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_defaulting_to("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_both_headers_forwarded_takes_precedence() {
        let mut headers = HeaderMap::new();
        headers.insert(header::FORWARDED, HeaderValue::from_static("proto=wss"));
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_defaulting_to("http")).await;
        assert_eq!(scheme.0, "wss");
    }

    #[tokio::test]
    async fn test_fallback_to_state() {
        let scheme = extract_scheme(HeaderMap::new(), state_defaulting_to("https")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_invalid_forwarded_header_fallback() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("invalid-forwarded-header"),
        );
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_defaulting_to("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_forwarded_header_without_proto() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("for=192.0.2.60;host=example.com"),
        );

        let scheme = extract_scheme(headers, state_defaulting_to("https")).await;
        assert_eq!(scheme.0, "https");
    }
}