switchgear_service/axum/extract/
scheme.rs1use axum::extract::FromRef;
2use axum::http::header;
3use axum::{extract::FromRequestParts, http::request::Parts};
4use std::convert::Infallible;
5
/// The URL scheme of the incoming request (for example `"http"` or `"https"`).
///
/// Implements axum's `FromRequestParts`, so it can be used directly as a handler
/// argument. The value is taken from the `Forwarded` or `X-Forwarded-Proto` request
/// headers when present, and otherwise from the application state via `FromRef`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Scheme(pub String);
8
9impl<S> FromRequestParts<S> for Scheme
10where
11 S: Send + Sync,
12 Scheme: FromRef<S>,
13{
14 type Rejection = Infallible;
15
16 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
17 if let Some(proto) = parts
18 .headers
19 .get(header::FORWARDED)
20 .and_then(parse_forwarded_proto)
21 {
22 return Ok(Scheme(proto));
23 }
24 if let Some(proto) = parts
25 .headers
26 .get("x-forwarded-proto")
27 .and_then(|value| value.to_str().ok())
28 {
29 return Ok(Scheme(proto.to_string()));
30 }
31
32 Ok(Scheme::from_ref(state))
33 }
34}
35
36fn parse_forwarded_proto(forwarded: &header::HeaderValue) -> Option<String> {
37 forwarded.to_str().ok()?.split(';').find_map(|s| {
38 let s = s.trim().to_lowercase();
39 if s.starts_with("proto=") {
40 s.split('=').next_back().map(|c| c.to_string())
41 } else {
42 None
43 }
44 })
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::FromRef;
    use axum::http::{HeaderMap, HeaderValue, Request};

    /// Application-state stub whose only job is to supply the fallback scheme.
    #[derive(Clone)]
    struct TestState {
        default_scheme: Scheme,
    }

    impl FromRef<TestState> for Scheme {
        fn from_ref(state: &TestState) -> Self {
            state.default_scheme.clone()
        }
    }

    /// Builds a `TestState` whose fallback scheme is `fallback`.
    fn state_with_fallback(fallback: &str) -> TestState {
        TestState {
            default_scheme: Scheme(fallback.to_string()),
        }
    }

    /// Runs the `Scheme` extractor on a request to `/` that carries exactly `headers`.
    async fn extract_scheme(headers: HeaderMap, state: TestState) -> Scheme {
        let (mut parts, _body) = Request::builder()
            .uri("/")
            .body(())
            .unwrap()
            .into_parts();
        parts.headers = headers;

        Scheme::from_request_parts(&mut parts, &state).await.unwrap()
    }

    #[tokio::test]
    async fn test_forwarded_header_with_proto() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("proto=https;host=example.com"),
        );

        let scheme = extract_scheme(headers, state_with_fallback("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_x_forwarded_proto_header() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_with_fallback("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_both_headers_forwarded_takes_precedence() {
        let mut headers = HeaderMap::new();
        headers.insert(header::FORWARDED, HeaderValue::from_static("proto=wss"));
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_with_fallback("http")).await;
        assert_eq!(scheme.0, "wss");
    }

    #[tokio::test]
    async fn test_fallback_to_state() {
        let scheme = extract_scheme(HeaderMap::new(), state_with_fallback("https")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_invalid_forwarded_header_fallback() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("invalid-forwarded-header"),
        );
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));

        let scheme = extract_scheme(headers, state_with_fallback("http")).await;
        assert_eq!(scheme.0, "https");
    }

    #[tokio::test]
    async fn test_forwarded_header_without_proto() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::FORWARDED,
            HeaderValue::from_static("for=192.0.2.60;host=example.com"),
        );

        let scheme = extract_scheme(headers, state_with_fallback("https")).await;
        assert_eq!(scheme.0, "https");
    }
}
162}