Skip to main content

rustauth_axum/
router.rs

1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::{header, HeaderMap, Request, Uri};
6use axum::response::IntoResponse;
7use axum::routing::any;
8use axum::Router;
9use rustauth::api::RequestBaseUrl;
10use rustauth::auth::oauth::OAuthBaseUrlOverride;
11use rustauth::error::RustAuthError;
12use rustauth::utils::host::is_loopback_host;
13use rustauth::utils::url::{is_valid_forwarded_host, is_valid_forwarded_proto};
14use rustauth::RustAuth;
15
16use crate::error::{internal_error_response, RustAuthAxumError};
17use crate::request::to_api_request;
18use crate::response::from_api_response;
19use crate::RustAuthAxumOptions;
20
21#[derive(Clone)]
22struct RustAuthAxumState {
23    auth: Arc<RustAuth>,
24    options: RustAuthAxumOptions,
25}
26
27/// Convenience extension methods for mounting RustAuth into Axum.
28///
29/// Implemented for [`RustAuth`] and [`Arc<RustAuth>`](rustauth::RustAuth).
30pub trait RustAuthAxumExt {
31    /// Return unmounted RustAuth routes for callers that want to nest manually.
32    fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError>;
33
34    /// Mount RustAuth nested at `RustAuthOptions.base_path`, defaulting to `/api/auth`.
35    fn mount_at_base_path(&self, options: RustAuthAxumOptions)
36        -> Result<Router, RustAuthAxumError>;
37}
38
39impl RustAuthAxumExt for RustAuth {
40    fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError> {
41        routes_from_shared(Arc::new(self.clone()), options)
42    }
43
44    fn mount_at_base_path(
45        &self,
46        options: RustAuthAxumOptions,
47    ) -> Result<Router, RustAuthAxumError> {
48        mount_router_shared(Arc::new(self.clone()), options)
49    }
50}
51
52impl RustAuthAxumExt for Arc<RustAuth> {
53    fn mount_routes(&self, options: RustAuthAxumOptions) -> Result<Router, RustAuthAxumError> {
54        routes_from_shared(Arc::clone(self), options)
55    }
56
57    fn mount_at_base_path(
58        &self,
59        options: RustAuthAxumOptions,
60    ) -> Result<Router, RustAuthAxumError> {
61        mount_router_shared(Arc::clone(self), options)
62    }
63}
64
65fn mount_router_shared(
66    auth: Arc<RustAuth>,
67    options: RustAuthAxumOptions,
68) -> Result<Router, RustAuthAxumError> {
69    validate_base_url_matches_base_path(auth.as_ref())?;
70    let base_path = normalize_base_path(&auth.context().base_path)?;
71    if base_path == "/" {
72        return routes_from_shared(auth, options);
73    }
74    Ok(Router::new().nest(&base_path, routes_from_shared(auth, options)?))
75}
76
77/// Validate that `RustAuthOptions::base_url` and `base_path` are consistent.
78///
79/// Call this before manually nesting [`RustAuthAxumExt::mount_routes`] if you
80/// bypass the fallible mount helpers.
81pub fn validate_mount_config(auth: &RustAuth) -> Result<(), RustAuthAxumError> {
82    validate_base_url_matches_base_path(auth)
83}
84
85fn routes_from_shared(
86    auth: Arc<RustAuth>,
87    options: RustAuthAxumOptions,
88) -> Result<Router, RustAuthAxumError> {
89    validate_mount_config(auth.as_ref())?;
90    Ok(Router::new()
91        .route("/", any(route_handler))
92        .route("/{*path}", any(route_handler))
93        .with_state(RustAuthAxumState { auth, options }))
94}
95
96/// Handle a single Axum request through RustAuth.
97pub async fn handle(
98    auth: &RustAuth,
99    options: RustAuthAxumOptions,
100    request: Request<Body>,
101) -> axum::response::Response {
102    match to_api_request(request, options).await {
103        Ok(mut request) => {
104            maybe_insert_base_url(auth, &mut request, options);
105            match auth.handler_async(request).await {
106                Ok(response) => from_api_response(response),
107                Err(error) => {
108                    log_internal_error(auth, &error);
109                    internal_error_response()
110                }
111            }
112        }
113        Err(response) => response,
114    }
115}
116
117async fn route_handler(
118    State(state): State<RustAuthAxumState>,
119    request: Request<Body>,
120) -> impl IntoResponse {
121    handle(state.auth.as_ref(), state.options, request).await
122}
123
124fn validate_base_url_matches_base_path(auth: &RustAuth) -> Result<(), RustAuthAxumError> {
125    let base_url = auth.context().base_url.as_str();
126    if base_url.is_empty() {
127        return Ok(());
128    }
129
130    let parsed = url::Url::parse(base_url)
131        .map_err(|_| RustAuthAxumError::InvalidBaseUrl(base_url.to_owned()))?;
132    let url_path = trim_path_suffix(parsed.path());
133    let base_path = trim_path_suffix(&auth.context().base_path);
134    if url_path == base_path {
135        return Ok(());
136    }
137
138    Err(RustAuthAxumError::InconsistentBaseUrlPath {
139        url_path,
140        base_path,
141    })
142}
143
/// Strip trailing `/` characters from `path`.
///
/// An empty or all-slash path maps to `/`, so the root still compares equal
/// to itself.
fn trim_path_suffix(path: &str) -> String {
    match path.trim_end_matches('/') {
        "" => String::from("/"),
        rest => String::from(rest),
    }
}
152
153fn normalize_base_path(base_path: &str) -> Result<String, RustAuthAxumError> {
154    if base_path.is_empty() {
155        return Ok("/".to_owned());
156    }
157    if !is_valid_base_path(base_path) {
158        return Err(RustAuthAxumError::InvalidBasePath(base_path.to_owned()));
159    }
160
161    let trimmed = base_path.trim_end_matches('/');
162    if trimmed.is_empty() {
163        Ok("/".to_owned())
164    } else {
165        Ok(trimmed.to_owned())
166    }
167}
168
169fn maybe_insert_base_url(
170    auth: &RustAuth,
171    request: &mut rustauth::api::ApiRequest,
172    options: RustAuthAxumOptions,
173) {
174    if !options.infer_base_url_from_request
175        || !auth.context().base_url.is_empty()
176        || request.extensions().get::<OAuthBaseUrlOverride>().is_some()
177    {
178        return;
179    }
180
181    if let Some(base_url) = infer_base_url(
182        request.headers(),
183        request.uri(),
184        &auth.context().base_path,
185        options.trust_proxy_headers_for_base_url,
186    ) {
187        request
188            .extensions_mut()
189            .insert(RequestBaseUrl(base_url.clone()));
190        request
191            .extensions_mut()
192            .insert(OAuthBaseUrlOverride(base_url));
193    }
194}
195
196fn infer_base_url(
197    headers: &HeaderMap,
198    uri: &Uri,
199    base_path: &str,
200    trust_proxy_headers: bool,
201) -> Option<String> {
202    let origin = if trust_proxy_headers {
203        forwarded_origin(headers)
204    } else {
205        None
206    }
207    .or_else(|| uri_origin(uri))
208    .or_else(|| host_header_origin(headers))?;
209    Some(with_base_path(origin, base_path))
210}
211
212fn forwarded_origin(headers: &HeaderMap) -> Option<String> {
213    let host = header_str(headers, "x-forwarded-host")?;
214    let proto = header_str(headers, "x-forwarded-proto")?;
215    if !is_valid_forwarded_host(host) || !is_valid_forwarded_proto(proto) {
216        return None;
217    }
218    Some(format!("{}://{}", proto.to_ascii_lowercase(), host))
219}
220
221fn uri_origin(uri: &Uri) -> Option<String> {
222    let scheme = uri.scheme_str()?;
223    if !is_valid_forwarded_proto(scheme) {
224        return None;
225    }
226    let authority = uri.authority()?.as_str();
227    if !is_valid_forwarded_host(authority) {
228        return None;
229    }
230    Some(format!("{}://{}", scheme, authority))
231}
232
233fn host_header_origin(headers: &HeaderMap) -> Option<String> {
234    let host = header_str(headers, header::HOST.as_str())?;
235    if !is_valid_forwarded_host(host) {
236        return None;
237    }
238    let scheme = if is_loopback_host(host) {
239        "http"
240    } else {
241        "https"
242    };
243    Some(format!("{scheme}://{host}"))
244}
245
246fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
247    headers.get(name)?.to_str().ok()
248}
249
/// Append `base_path` to `origin`, e.g. `https://host` + `/api/auth`.
///
/// Trailing slashes on `base_path` are dropped, and a root or empty base path
/// leaves `origin` unchanged. A base path missing its leading `/` gets one
/// inserted. Gluing `api/auth` directly onto `https://host` would produce
/// `https://hostapi/auth`, which silently names a different host.
fn with_base_path(mut origin: String, base_path: &str) -> String {
    let base_path = base_path.trim_end_matches('/');
    // After trimming, a root path ("/") is empty, so this covers both cases.
    if base_path.is_empty() {
        return origin;
    }
    if !base_path.starts_with('/') {
        origin.push('/');
    }
    origin.push_str(base_path);
    origin
}
257
/// Return whether `base_path` is safe to use as an Axum nest prefix.
///
/// The path must be absolute and must not contain:
/// * query (`?`) or fragment (`#`) delimiters;
/// * Axum route-pattern syntax (`{`, `}`, `*`);
/// * any segment starting with `:`.
///
/// Axum 0.8 panics on `:`-prefixed segments (the legacy 0.7 capture syntax)
/// unless `without_v07_checks` is used. Accepting them here would turn a
/// configuration mistake into a panic at mount time instead of an
/// `InvalidBasePath` error.
fn is_valid_base_path(base_path: &str) -> bool {
    const FORBIDDEN: &[char] = &['?', '#', '{', '}', '*'];
    base_path.starts_with('/')
        && !base_path.contains(FORBIDDEN)
        && !base_path.split('/').any(|segment| segment.starts_with(':'))
}
266
267fn log_internal_error(auth: &RustAuth, error: &RustAuthError) {
268    let message = error.to_string();
269    auth.context()
270        .logger
271        .error("RustAuth Axum handler failed", &[message.as_str()]);
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    const SECRET: &str = "test-secret-123456789012345678901234";

    /// Build a RustAuth instance with base path `/api/auth` and the given `base_url`.
    async fn auth_at_api_auth(base_url: &'static str) -> Result<RustAuth, RustAuthError> {
        Ok(RustAuth::builder()
            .secret(SECRET)
            .base_path("/api/auth")
            .base_url(base_url)
            .build()
            .await?)
    }

    /// Build a default-configured RustAuth instance behind an `Arc`.
    async fn shared_default_auth() -> Result<Arc<RustAuth>, RustAuthError> {
        Ok(Arc::new(RustAuth::builder().secret(SECRET).build().await?))
    }

    /// Collect static `(name, value)` pairs into a header map.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn normalize_base_path_trims_trailing_slashes_except_root() -> Result<(), RustAuthAxumError> {
        let cases = [
            ("", "/"),
            ("/", "/"),
            ("/api/auth/", "/api/auth"),
            ("/api/auth///", "/api/auth"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_base_path(input)?, expected);
        }
        Ok(())
    }

    #[test]
    fn normalize_base_path_rejects_axum_pattern_syntax_and_non_absolute_paths() {
        let rejected = [
            "api/auth",
            "/api/{auth}",
            "/api/*auth",
            "/api/auth?x=1",
            "/api/auth#x",
        ];
        for path in rejected {
            let outcome = normalize_base_path(path);
            assert!(matches!(outcome, Err(RustAuthAxumError::InvalidBasePath(_))));
        }
    }

    #[test]
    fn infer_base_url_rejects_malicious_forwarded_headers_and_falls_back_to_host() {
        let headers = headers_from(&[
            ("x-forwarded-host", "javascript:alert(1)"),
            ("x-forwarded-proto", "http"),
            ("host", "app.example.com"),
        ]);
        let uri = Uri::from_static("/api/auth/ok");

        let inferred = infer_base_url(&headers, &uri, "/api/auth", true);
        assert_eq!(inferred.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_forwarded_headers_when_trusted_and_valid() {
        let headers = headers_from(&[
            ("x-forwarded-host", "public.example.com"),
            ("x-forwarded-proto", "https"),
            ("host", "internal.local"),
        ]);
        let uri = Uri::from_static("/ok");

        let inferred = infer_base_url(&headers, &uri, "/api/auth", true);
        assert_eq!(inferred.as_deref(), Some("https://public.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_absolute_request_uri_origin() {
        let headers = HeaderMap::new();
        let uri = Uri::from_static("https://app.example.com/api/auth/sign-in/social");

        let inferred = infer_base_url(&headers, &uri, "/api/auth", false);
        assert_eq!(inferred.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_http_for_loopback_host_header() {
        let headers = headers_from(&[("host", "127.0.0.1:3000")]);
        let uri = Uri::from_static("/ok");

        let inferred = infer_base_url(&headers, &uri, "/api/auth", false);
        assert_eq!(inferred.as_deref(), Some("http://127.0.0.1:3000/api/auth"));
    }

    #[tokio::test]
    async fn validate_base_url_accepts_matching_pathname() -> Result<(), RustAuthError> {
        let auth = auth_at_api_auth("http://localhost:3000/api/auth/").await?;
        assert!(validate_base_url_matches_base_path(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_mismatched_pathname() -> Result<(), RustAuthError> {
        let auth = auth_at_api_auth("http://localhost:3000/wrong").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(
            outcome,
            Err(RustAuthAxumError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_invalid_absolute_url() -> Result<(), RustAuthError> {
        let auth = auth_at_api_auth("not-a-url").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(outcome, Err(RustAuthAxumError::InvalidBaseUrl(_))));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_rejects_mismatched_base_url_path() -> Result<(), RustAuthError> {
        let auth = Arc::new(auth_at_api_auth("http://localhost:3000/wrong").await?);
        let outcome = auth.mount_routes(RustAuthAxumOptions::default());
        assert!(matches!(
            outcome,
            Err(RustAuthAxumError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_keeps_shared_auth_available() -> Result<(), RustAuthError> {
        let auth = shared_default_auth().await?;
        let routes = auth
            .mount_routes(RustAuthAxumOptions::default())
            .map_err(|error| RustAuthError::Api(error.to_string()))?;
        drop(routes);

        // Dropping the router must release its clone of the shared pointer.
        assert_eq!(Arc::strong_count(&auth), 1);
        assert!(validate_mount_config(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn axum_state_clones_only_the_shared_auth_pointer() -> Result<(), RustAuthError> {
        let auth = shared_default_auth().await?;
        let state = RustAuthAxumState {
            auth: Arc::clone(&auth),
            options: RustAuthAxumOptions::default(),
        };

        let cloned = state.clone();

        // One reference each for `auth`, `state.auth`, and `cloned.auth`.
        assert_eq!(Arc::strong_count(&auth), 3);
        drop(cloned);
        assert_eq!(Arc::strong_count(&auth), 2);
        Ok(())
    }
}