Skip to main content

rustauth_actix_web/
router.rs

1use std::sync::Arc;
2
3use actix_web::web::{self, Payload};
4use actix_web::{HttpRequest, HttpResponse, Scope};
5use http::{header, HeaderMap, Uri};
6use rustauth::api::RequestBaseUrl;
7use rustauth::auth::oauth::OAuthBaseUrlOverride;
8use rustauth::error::RustAuthError;
9use rustauth::utils::host::is_loopback_host;
10use rustauth::utils::url::{is_valid_forwarded_host, is_valid_forwarded_proto};
11use rustauth::RustAuth;
12
13use crate::error::{internal_error_response, RustAuthActixWebError};
14use crate::request::to_api_request;
15use crate::response::from_api_response;
16use crate::RustAuthActixWebOptions;
17
18#[derive(Clone)]
19struct RustAuthActixWebState {
20    auth: Arc<RustAuth>,
21    options: RustAuthActixWebOptions,
22}
23
24/// Convenience extension methods for mounting RustAuth into Actix Web.
25///
26/// Implemented for [`RustAuth`] and [`Arc<RustAuth>`](rustauth::RustAuth).
27pub trait RustAuthActixWebExt {
28    /// Return unmounted RustAuth routes for callers that want to nest manually.
29    fn mount_routes(
30        &self,
31        options: RustAuthActixWebOptions,
32    ) -> Result<Scope, RustAuthActixWebError>;
33
34    /// Mount RustAuth nested at `RustAuthOptions.base_path`, defaulting to `/api/auth`.
35    fn mount_at_base_path(
36        &self,
37        options: RustAuthActixWebOptions,
38    ) -> Result<Scope, RustAuthActixWebError>;
39}
40
41impl RustAuthActixWebExt for RustAuth {
42    fn mount_routes(
43        &self,
44        options: RustAuthActixWebOptions,
45    ) -> Result<Scope, RustAuthActixWebError> {
46        routes_from_shared(Arc::new(self.clone()), options)
47    }
48
49    fn mount_at_base_path(
50        &self,
51        options: RustAuthActixWebOptions,
52    ) -> Result<Scope, RustAuthActixWebError> {
53        mount_scope_shared(Arc::new(self.clone()), options)
54    }
55}
56
57impl RustAuthActixWebExt for Arc<RustAuth> {
58    fn mount_routes(
59        &self,
60        options: RustAuthActixWebOptions,
61    ) -> Result<Scope, RustAuthActixWebError> {
62        routes_from_shared(Arc::clone(self), options)
63    }
64
65    fn mount_at_base_path(
66        &self,
67        options: RustAuthActixWebOptions,
68    ) -> Result<Scope, RustAuthActixWebError> {
69        mount_scope_shared(Arc::clone(self), options)
70    }
71}
72
73fn mount_scope_shared(
74    auth: Arc<RustAuth>,
75    options: RustAuthActixWebOptions,
76) -> Result<Scope, RustAuthActixWebError> {
77    validate_base_url_matches_base_path(auth.as_ref())?;
78    let base_path = normalize_base_path(&auth.context().base_path)?;
79    if base_path == "/" {
80        return routes_from_shared(auth, options);
81    }
82    Ok(web::scope(&base_path).service(routes_from_shared(auth, options)?))
83}
84
85/// Validate that `RustAuthOptions::base_url` and `base_path` are consistent.
86///
87/// Call this before manually nesting [`RustAuthActixWebExt::mount_routes`] if you
88/// bypass the fallible mount helpers.
89pub fn validate_mount_config(auth: &RustAuth) -> Result<(), RustAuthActixWebError> {
90    validate_base_url_matches_base_path(auth)
91}
92
93fn routes_from_shared(
94    auth: Arc<RustAuth>,
95    options: RustAuthActixWebOptions,
96) -> Result<Scope, RustAuthActixWebError> {
97    validate_mount_config(auth.as_ref())?;
98    Ok(web::scope("")
99        .app_data(web::Data::new(RustAuthActixWebState { auth, options }))
100        .default_service(web::route().to(route_handler)))
101}
102
103/// Handle a single Actix Web request through RustAuth.
104pub async fn handle(
105    auth: &RustAuth,
106    options: RustAuthActixWebOptions,
107    request: HttpRequest,
108    payload: Payload,
109) -> HttpResponse {
110    match to_api_request(request, payload, options).await {
111        Ok(mut request) => {
112            maybe_insert_base_url(auth, &mut request, options);
113            match auth.handler_async(request).await {
114                Ok(response) => from_api_response(response),
115                Err(error) => {
116                    log_internal_error(auth, &error);
117                    internal_error_response()
118                }
119            }
120        }
121        Err(response) => response,
122    }
123}
124
125async fn route_handler(
126    state: web::Data<RustAuthActixWebState>,
127    request: HttpRequest,
128    payload: Payload,
129) -> HttpResponse {
130    handle(state.auth.as_ref(), state.options, request, payload).await
131}
132
133fn validate_base_url_matches_base_path(auth: &RustAuth) -> Result<(), RustAuthActixWebError> {
134    let base_url = auth.context().base_url.as_str();
135    if base_url.is_empty() {
136        return Ok(());
137    }
138
139    let parsed = url::Url::parse(base_url)
140        .map_err(|_| RustAuthActixWebError::InvalidBaseUrl(base_url.to_owned()))?;
141    let url_path = trim_path_suffix(parsed.path());
142    let base_path = trim_path_suffix(&auth.context().base_path);
143    if url_path == base_path {
144        return Ok(());
145    }
146
147    Err(RustAuthActixWebError::InconsistentBaseUrlPath {
148        url_path,
149        base_path,
150    })
151}
152
/// Strip every trailing `/` from `path`, collapsing an all-slash or empty path to `/`.
fn trim_path_suffix(path: &str) -> String {
    match path.trim_end_matches('/') {
        "" => String::from("/"),
        kept => kept.to_owned(),
    }
}
161
162fn normalize_base_path(base_path: &str) -> Result<String, RustAuthActixWebError> {
163    if base_path.is_empty() {
164        return Ok("/".to_owned());
165    }
166    if !is_valid_base_path(base_path) {
167        return Err(RustAuthActixWebError::InvalidBasePath(base_path.to_owned()));
168    }
169
170    let trimmed = base_path.trim_end_matches('/');
171    if trimmed.is_empty() {
172        Ok("/".to_owned())
173    } else {
174        Ok(trimmed.to_owned())
175    }
176}
177
178fn maybe_insert_base_url(
179    auth: &RustAuth,
180    request: &mut rustauth::api::ApiRequest,
181    options: RustAuthActixWebOptions,
182) {
183    if !options.infer_base_url_from_request
184        || !auth.context().base_url.is_empty()
185        || request.extensions().get::<OAuthBaseUrlOverride>().is_some()
186    {
187        return;
188    }
189
190    if let Some(base_url) = infer_base_url(
191        request.headers(),
192        request.uri(),
193        &auth.context().base_path,
194        options.trust_proxy_headers_for_base_url,
195    ) {
196        request
197            .extensions_mut()
198            .insert(RequestBaseUrl(base_url.clone()));
199        request
200            .extensions_mut()
201            .insert(OAuthBaseUrlOverride(base_url));
202    }
203}
204
205fn infer_base_url(
206    headers: &HeaderMap,
207    uri: &Uri,
208    base_path: &str,
209    trust_proxy_headers: bool,
210) -> Option<String> {
211    let origin = if trust_proxy_headers {
212        forwarded_origin(headers)
213    } else {
214        None
215    }
216    .or_else(|| uri_origin(uri))
217    .or_else(|| host_header_origin(headers))?;
218    Some(with_base_path(origin, base_path))
219}
220
221fn forwarded_origin(headers: &HeaderMap) -> Option<String> {
222    let host = header_str(headers, "x-forwarded-host")?;
223    let proto = header_str(headers, "x-forwarded-proto")?;
224    if !is_valid_forwarded_host(host) || !is_valid_forwarded_proto(proto) {
225        return None;
226    }
227    Some(format!("{}://{}", proto.to_ascii_lowercase(), host))
228}
229
230fn uri_origin(uri: &Uri) -> Option<String> {
231    let scheme = uri.scheme_str()?;
232    if !is_valid_forwarded_proto(scheme) {
233        return None;
234    }
235    let authority = uri.authority()?.as_str();
236    if !is_valid_forwarded_host(authority) {
237        return None;
238    }
239    Some(format!("{}://{}", scheme, authority))
240}
241
242fn host_header_origin(headers: &HeaderMap) -> Option<String> {
243    let host = header_str(headers, header::HOST.as_str())?;
244    if !is_valid_forwarded_host(host) {
245        return None;
246    }
247    let scheme = if is_loopback_host(host) {
248        "http"
249    } else {
250        "https"
251    };
252    Some(format!("{scheme}://{host}"))
253}
254
255fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
256    headers.get(name)?.to_str().ok()
257}
258
/// Append `base_path` to `origin`, producing the full base URL.
///
/// Trailing slashes on `base_path` are ignored, and an empty or root path leaves `origin`
/// unchanged.
///
/// A `/` separator is inserted when `base_path` does not start with one. Without it, a
/// relative path such as `api/auth` would be glued onto the host (`https://hostapi/auth`).
/// `base_path` is not normalized on the `mount_routes` path, so this case is reachable.
fn with_base_path(mut origin: String, base_path: &str) -> String {
    // Trimming all trailing slashes also turns a bare "/" into "", so no separate root check is needed.
    let path = base_path.trim_end_matches('/');
    if path.is_empty() {
        return origin;
    }
    if !path.starts_with('/') {
        origin.push('/');
    }
    origin.push_str(path);
    origin
}
266
/// Whether `base_path` is usable as a literal Actix scope prefix.
///
/// The path must be absolute. It must not contain `?` or `#` (query/fragment delimiters),
/// nor `{`, `}` or `*`, which Actix interprets as route-pattern syntax.
fn is_valid_base_path(base_path: &str) -> bool {
    const RESERVED: [char; 5] = ['?', '#', '{', '}', '*'];
    base_path.starts_with('/') && !base_path.chars().any(|ch| RESERVED.contains(&ch))
}
275
276fn log_internal_error(auth: &RustAuth, error: &RustAuthError) {
277    let message = error.to_string();
278    auth.context()
279        .logger
280        .error("RustAuth Actix Web handler failed", &[message.as_str()]);
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::HeaderValue;

    const SECRET: &str = "test-secret-123456789012345678901234";

    /// Build a `HeaderMap` from static name/value pairs.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in entries {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    /// Build a RustAuth with `base_path = /api/auth` and the given `base_url`.
    async fn auth_with_base_url(base_url: &'static str) -> Result<RustAuth, RustAuthError> {
        let auth = RustAuth::builder()
            .secret(SECRET)
            .base_path("/api/auth")
            .base_url(base_url)
            .build()
            .await?;
        Ok(auth)
    }

    /// Build a RustAuth with default path settings, wrapped for sharing.
    async fn shared_default_auth() -> Result<Arc<RustAuth>, RustAuthError> {
        let auth = RustAuth::builder().secret(SECRET).build().await?;
        Ok(Arc::new(auth))
    }

    #[test]
    fn normalize_base_path_trims_trailing_slashes_except_root() -> Result<(), RustAuthActixWebError>
    {
        let cases = [
            ("", "/"),
            ("/", "/"),
            ("/api/auth/", "/api/auth"),
            ("/api/auth///", "/api/auth"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_base_path(input)?, expected);
        }
        Ok(())
    }

    #[test]
    fn normalize_base_path_rejects_pattern_syntax_and_non_absolute_paths() {
        let rejected = [
            "api/auth",
            "/api/{auth}",
            "/api/*auth",
            "/api/auth?x=1",
            "/api/auth#x",
        ];
        for base_path in rejected {
            let outcome = normalize_base_path(base_path);
            assert!(matches!(
                outcome,
                Err(RustAuthActixWebError::InvalidBasePath(_))
            ));
        }
    }

    #[test]
    fn infer_base_url_rejects_malicious_forwarded_headers_and_falls_back_to_host() {
        let headers = header_map(&[
            ("x-forwarded-host", "javascript:alert(1)"),
            ("x-forwarded-proto", "http"),
            ("host", "app.example.com"),
        ]);
        let uri = Uri::from_static("/api/auth/ok");

        let base = infer_base_url(&headers, &uri, "/api/auth", true);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_forwarded_headers_when_trusted_and_valid() {
        let headers = header_map(&[
            ("x-forwarded-host", "public.example.com"),
            ("x-forwarded-proto", "https"),
            ("host", "internal.local"),
        ]);
        let uri = Uri::from_static("/ok");

        let base = infer_base_url(&headers, &uri, "/api/auth", true);
        assert_eq!(base.as_deref(), Some("https://public.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_absolute_request_uri_origin() {
        let headers = HeaderMap::new();
        let uri = Uri::from_static("https://app.example.com/api/auth/sign-in/social");

        let base = infer_base_url(&headers, &uri, "/api/auth", false);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_http_for_loopback_host_header() {
        let headers = header_map(&[("host", "127.0.0.1:3000")]);
        let uri = Uri::from_static("/ok");

        let base = infer_base_url(&headers, &uri, "/api/auth", false);
        assert_eq!(base.as_deref(), Some("http://127.0.0.1:3000/api/auth"));
    }

    #[tokio::test]
    async fn validate_base_url_accepts_matching_pathname() -> Result<(), RustAuthError> {
        let auth = auth_with_base_url("http://localhost:3000/api/auth/").await?;
        assert!(validate_base_url_matches_base_path(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_mismatched_pathname() -> Result<(), RustAuthError> {
        let auth = auth_with_base_url("http://localhost:3000/wrong").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(
            outcome,
            Err(RustAuthActixWebError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_invalid_absolute_url() -> Result<(), RustAuthError> {
        let auth = auth_with_base_url("not-a-url").await?;
        let outcome = validate_base_url_matches_base_path(&auth);
        assert!(matches!(
            outcome,
            Err(RustAuthActixWebError::InvalidBaseUrl(_))
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_rejects_mismatched_base_url_path() -> Result<(), RustAuthError> {
        let auth = Arc::new(auth_with_base_url("http://localhost:3000/wrong").await?);
        let outcome = auth.mount_routes(RustAuthActixWebOptions::default());
        assert!(matches!(
            outcome,
            Err(RustAuthActixWebError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_keeps_shared_auth_available() -> Result<(), RustAuthError> {
        let auth = shared_default_auth().await?;
        let routes = auth
            .mount_routes(RustAuthActixWebOptions::default())
            .map_err(|error| RustAuthError::Api(error.to_string()))?;
        drop(routes);

        // Dropping the scope must release the reference it held.
        assert_eq!(Arc::strong_count(&auth), 1);
        assert!(validate_mount_config(auth.as_ref()).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn actix_state_clones_only_the_shared_auth_pointer() -> Result<(), RustAuthError> {
        let auth = shared_default_auth().await?;
        let state = RustAuthActixWebState {
            auth: Arc::clone(&auth),
            options: RustAuthActixWebOptions::default(),
        };

        let cloned = state.clone();

        // `auth`, `state.auth` and `cloned.auth` all point at the same instance.
        assert_eq!(Arc::strong_count(&auth), 3);
        drop(cloned);
        assert_eq!(Arc::strong_count(&auth), 2);
        Ok(())
    }
}