1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::{header, HeaderMap, Request, Uri};
6use axum::response::IntoResponse;
7use axum::routing::any;
8use axum::Router;
9use rustauth::api::RequestBaseUrl;
10use rustauth::auth::oauth::OAuthBaseUrlOverride;
11use rustauth::error::RustAuthError;
12use rustauth::utils::host::is_loopback_host;
13use rustauth::utils::url::{is_valid_forwarded_host, is_valid_forwarded_proto};
14use rustauth::RustAuth;
15
16use crate::error::{internal_error_response, RustAuthAxumError};
17use crate::request::to_api_request;
18use crate::response::from_api_response;
19use crate::RustAuthAxumOptions;
20
21#[derive(Clone)]
22struct RustAuthAxumState {
23 auth: Arc<RustAuth>,
24 options: RustAuthAxumOptions,
25}
26
27pub trait RustAuthAxumExt {
31 fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError>;
33
34 fn mount_at_base_path(&self, options: RustAuthAxumOptions)
36 -> Result<Router, RustAuthAxumError>;
37}
38
39impl RustAuthAxumExt for RustAuth {
40 fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError> {
41 routes_from_shared(Arc::new(self.clone()), options)
42 }
43
44 fn mount_at_base_path(
45 &self,
46 options: RustAuthAxumOptions,
47 ) -> Result<Router, RustAuthAxumError> {
48 mount_router_shared(Arc::new(self.clone()), options)
49 }
50}
51
52impl RustAuthAxumExt for Arc<RustAuth> {
53 fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError> {
54 routes_from_shared(Arc::clone(self), options)
55 }
56
57 fn mount_at_base_path(
58 &self,
59 options: RustAuthAxumOptions,
60 ) -> Result<Router, RustAuthAxumError> {
61 mount_router_shared(Arc::clone(self), options)
62 }
63}
64
65fn mount_router_shared(
66 auth: Arc<RustAuth>,
67 options: RustAuthAxumOptions,
68) -> Result<Router, RustAuthAxumError> {
69 validate_base_url_matches_base_path(auth.as_ref())?;
70 let base_path = normalize_base_path(&auth.context().base_path)?;
71 if base_path == "/" {
72 return routes_from_shared(auth, options);
73 }
74 Ok(Router::new().nest(&base_path, routes_from_shared(auth, options)?))
75}
76
77pub fn validate_mount_config(auth: &RustAuth) -> Result<(), RustAuthAxumError> {
82 validate_base_url_matches_base_path(auth)
83}
84
85fn routes_from_shared(
86 auth: Arc<RustAuth>,
87 options: RustAuthAxumOptions,
88) -> Result<Router, RustAuthAxumError> {
89 validate_mount_config(auth.as_ref())?;
90 Ok(Router::new()
91 .route("/", any(route_handler))
92 .route("/{*path}", any(route_handler))
93 .with_state(RustAuthAxumState { auth, options }))
94}
95
96pub async fn handle(
98 auth: &RustAuth,
99 options: RustAuthAxumOptions,
100 request: Request<Body>,
101) -> axum::response::Response {
102 match to_api_request(request, options).await {
103 Ok(mut request) => {
104 maybe_insert_base_url(auth, &mut request, options);
105 match auth.handler_async(request).await {
106 Ok(response) => from_api_response(response),
107 Err(error) => {
108 log_internal_error(auth, &error);
109 internal_error_response()
110 }
111 }
112 }
113 Err(response) => response,
114 }
115}
116
117async fn route_handler(
118 State(state): State<RustAuthAxumState>,
119 request: Request<Body>,
120) -> impl IntoResponse {
121 handle(state.auth.as_ref(), state.options, request).await
122}
123
124fn validate_base_url_matches_base_path(auth: &RustAuth) -> Result<(), RustAuthAxumError> {
125 let base_url = auth.context().base_url.as_str();
126 if base_url.is_empty() {
127 return Ok(());
128 }
129
130 let parsed = url::Url::parse(base_url)
131 .map_err(|_| RustAuthAxumError::InvalidBaseUrl(base_url.to_owned()))?;
132 let url_path = trim_path_suffix(parsed.path());
133 let base_path = trim_path_suffix(&auth.context().base_path);
134 if url_path == base_path {
135 return Ok(());
136 }
137
138 Err(RustAuthAxumError::InconsistentBaseUrlPath {
139 url_path,
140 base_path,
141 })
142}
143
/// Removes every trailing `/` from `path`, mapping a path that becomes empty
/// (including `""` and `"/"`) to the root `"/"`.
fn trim_path_suffix(path: &str) -> String {
    match path.trim_end_matches('/') {
        "" => String::from("/"),
        stripped => stripped.to_owned(),
    }
}
152
153fn normalize_base_path(base_path: &str) -> Result<String, RustAuthAxumError> {
154 if base_path.is_empty() {
155 return Ok("/".to_owned());
156 }
157 if !is_valid_base_path(base_path) {
158 return Err(RustAuthAxumError::InvalidBasePath(base_path.to_owned()));
159 }
160
161 let trimmed = base_path.trim_end_matches('/');
162 if trimmed.is_empty() {
163 Ok("/".to_owned())
164 } else {
165 Ok(trimmed.to_owned())
166 }
167}
168
169fn maybe_insert_base_url(
170 auth: &RustAuth,
171 request: &mut rustauth::api::ApiRequest,
172 options: RustAuthAxumOptions,
173) {
174 if !options.infer_base_url_from_request
175 || !auth.context().base_url.is_empty()
176 || request.extensions().get::<OAuthBaseUrlOverride>().is_some()
177 {
178 return;
179 }
180
181 if let Some(base_url) = infer_base_url(
182 request.headers(),
183 request.uri(),
184 &auth.context().base_path,
185 options.trust_proxy_headers_for_base_url,
186 ) {
187 request
188 .extensions_mut()
189 .insert(RequestBaseUrl(base_url.clone()));
190 request
191 .extensions_mut()
192 .insert(OAuthBaseUrlOverride(base_url));
193 }
194}
195
196fn infer_base_url(
197 headers: &HeaderMap,
198 uri: &Uri,
199 base_path: &str,
200 trust_proxy_headers: bool,
201) -> Option<String> {
202 let origin = if trust_proxy_headers {
203 forwarded_origin(headers)
204 } else {
205 None
206 }
207 .or_else(|| uri_origin(uri))
208 .or_else(|| host_header_origin(headers))?;
209 Some(with_base_path(origin, base_path))
210}
211
212fn forwarded_origin(headers: &HeaderMap) -> Option<String> {
213 let host = header_str(headers, "x-forwarded-host")?;
214 let proto = header_str(headers, "x-forwarded-proto")?;
215 if !is_valid_forwarded_host(host) || !is_valid_forwarded_proto(proto) {
216 return None;
217 }
218 Some(format!("{}://{}", proto.to_ascii_lowercase(), host))
219}
220
221fn uri_origin(uri: &Uri) -> Option<String> {
222 let scheme = uri.scheme_str()?;
223 if !is_valid_forwarded_proto(scheme) {
224 return None;
225 }
226 let authority = uri.authority()?.as_str();
227 if !is_valid_forwarded_host(authority) {
228 return None;
229 }
230 Some(format!("{}://{}", scheme, authority))
231}
232
233fn host_header_origin(headers: &HeaderMap) -> Option<String> {
234 let host = header_str(headers, header::HOST.as_str())?;
235 if !is_valid_forwarded_host(host) {
236 return None;
237 }
238 let scheme = if is_loopback_host(host) {
239 "http"
240 } else {
241 "https"
242 };
243 Some(format!("{scheme}://{host}"))
244}
245
246fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
247 headers.get(name)?.to_str().ok()
248}
249
/// Appends the auth base path to `origin` (e.g. `https://host:port`).
///
/// Trailing slashes are stripped from `base_path`. An empty or root base path
/// (`""`, `"/"`, `"///"`) leaves `origin` unchanged. A base path that lacks a
/// leading `/` gets one, so the path can never be glued directly onto the host or
/// port (which would produce something like `https://example.comapi/auth`).
fn with_base_path(mut origin: String, base_path: &str) -> String {
    let path = base_path.trim_end_matches('/');
    if path.is_empty() {
        return origin;
    }
    if !path.starts_with('/') {
        origin.push('/');
    }
    origin.push_str(path);
    origin
}
257
/// Returns whether `base_path` can be used verbatim as an axum nest prefix.
///
/// The path must start with `/` and must not contain:
/// - `?` or `#`, which would turn part of it into a query string or fragment;
/// - `{`, `}` or `*`, axum's capture and wildcard syntax;
/// - a segment that starts with `:`. That is the pre-0.8 capture syntax, and axum
///   0.8 panics when such a segment appears in a route or nest path. Rejecting it
///   here turns that panic into an `InvalidBasePath` error. A `:` elsewhere in a
///   segment (e.g. `/v1:auth`) is still allowed.
fn is_valid_base_path(base_path: &str) -> bool {
    let has_reserved_char =
        base_path.contains(|c: char| matches!(c, '?' | '#' | '{' | '}' | '*'));
    let has_legacy_capture = base_path
        .split('/')
        .any(|segment| segment.starts_with(':'));
    base_path.starts_with('/') && !has_reserved_char && !has_legacy_capture
}
266
267fn log_internal_error(auth: &RustAuth, error: &RustAuthError) {
268 let message = error.to_string();
269 auth.context()
270 .logger
271 .error("RustAuth Axum handler failed", &[message.as_str()]);
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    const SECRET: &str = "test-secret-123456789012345678901234";

    /// Builds a header map from static `(name, value)` pairs.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Builds an auth instance with an explicit base path and base URL.
    async fn auth_with_base(
        base_path: &'static str,
        base_url: &'static str,
    ) -> Result<RustAuth, RustAuthError> {
        let auth = RustAuth::builder()
            .secret(SECRET)
            .base_path(base_path)
            .base_url(base_url)
            .build()
            .await?;
        Ok(auth)
    }

    #[test]
    fn normalize_base_path_trims_trailing_slashes_except_root() -> Result<(), RustAuthAxumError> {
        let cases = [
            ("", "/"),
            ("/", "/"),
            ("/api/auth/", "/api/auth"),
            ("/api/auth///", "/api/auth"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_base_path(input)?, expected);
        }
        Ok(())
    }

    #[test]
    fn normalize_base_path_rejects_axum_pattern_syntax_and_non_absolute_paths() {
        let rejected = [
            "api/auth",
            "/api/{auth}",
            "/api/*auth",
            "/api/auth?x=1",
            "/api/auth#x",
        ];
        for candidate in rejected {
            let outcome = normalize_base_path(candidate);
            assert!(
                matches!(outcome, Err(RustAuthAxumError::InvalidBasePath(_))),
                "expected {candidate:?} to be rejected"
            );
        }
    }

    #[test]
    fn infer_base_url_rejects_malicious_forwarded_headers_and_falls_back_to_host() {
        let headers = header_map(&[
            ("x-forwarded-host", "javascript:alert(1)"),
            ("x-forwarded-proto", "http"),
            ("host", "app.example.com"),
        ]);
        let uri = Uri::from_static("/api/auth/ok");

        let base = infer_base_url(&headers, &uri, "/api/auth", true);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_forwarded_headers_when_trusted_and_valid() {
        let headers = header_map(&[
            ("x-forwarded-host", "public.example.com"),
            ("x-forwarded-proto", "https"),
            ("host", "internal.local"),
        ]);

        let base = infer_base_url(&headers, &Uri::from_static("/ok"), "/api/auth", true);
        assert_eq!(base.as_deref(), Some("https://public.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_absolute_request_uri_origin() {
        let uri = Uri::from_static("https://app.example.com/api/auth/sign-in/social");

        let base = infer_base_url(&HeaderMap::new(), &uri, "/api/auth", false);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_http_for_loopback_host_header() {
        let headers = header_map(&[("host", "127.0.0.1:3000")]);

        let base = infer_base_url(&headers, &Uri::from_static("/ok"), "/api/auth", false);
        assert_eq!(base.as_deref(), Some("http://127.0.0.1:3000/api/auth"));
    }

    #[tokio::test]
    async fn validate_base_url_accepts_matching_pathname() -> Result<(), RustAuthError> {
        let auth = auth_with_base("/api/auth", "http://localhost:3000/api/auth/").await?;
        assert!(validate_base_url_matches_base_path(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_mismatched_pathname() -> Result<(), RustAuthError> {
        let auth = auth_with_base("/api/auth", "http://localhost:3000/wrong").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(
            outcome,
            Err(RustAuthAxumError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_invalid_absolute_url() -> Result<(), RustAuthError> {
        let auth = auth_with_base("/api/auth", "not-a-url").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(outcome, Err(RustAuthAxumError::InvalidBaseUrl(_))));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_rejects_mismatched_base_url_path() -> Result<(), RustAuthError> {
        let auth = Arc::new(auth_with_base("/api/auth", "http://localhost:3000/wrong").await?);
        let outcome = auth.mount_routes(RustAuthAxumOptions::default());
        assert!(matches!(
            outcome,
            Err(RustAuthAxumError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_keeps_shared_auth_available() -> Result<(), RustAuthError> {
        let auth = Arc::new(RustAuth::builder().secret(SECRET).build().await?);
        let routes = auth
            .mount_routes(RustAuthAxumOptions::default())
            .map_err(|error| RustAuthError::Api(error.to_string()))?;
        drop(routes);

        // Dropping the router must release its reference to the shared instance.
        assert_eq!(Arc::strong_count(&auth), 1);
        assert!(validate_mount_config(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn axum_state_clones_only_the_shared_auth_pointer() -> Result<(), RustAuthError> {
        let auth = Arc::new(RustAuth::builder().secret(SECRET).build().await?);
        let state = RustAuthAxumState {
            auth: Arc::clone(&auth),
            options: RustAuthAxumOptions::default(),
        };

        let duplicate = state.clone();
        assert_eq!(Arc::strong_count(&auth), 3);

        drop(duplicate);
        assert_eq!(Arc::strong_count(&auth), 2);
        Ok(())
    }
}