1use std::sync::Arc;
2
3use actix_web::web::{self, Payload};
4use actix_web::{HttpRequest, HttpResponse, Scope};
5use http::{header, HeaderMap, Uri};
6use rustauth::api::RequestBaseUrl;
7use rustauth::auth::oauth::OAuthBaseUrlOverride;
8use rustauth::error::RustAuthError;
9use rustauth::utils::host::is_loopback_host;
10use rustauth::utils::url::{is_valid_forwarded_host, is_valid_forwarded_proto};
11use rustauth::RustAuth;
12
13use crate::error::{internal_error_response, RustAuthActixWebError};
14use crate::request::to_api_request;
15use crate::response::from_api_response;
16use crate::RustAuthActixWebOptions;
17
/// Application data attached to the scope built by `routes_from_shared` and read by
/// [`route_handler`] on every request.
#[derive(Clone)]
struct RustAuthActixWebState {
    /// Shared RustAuth instance. Cloning the state clones this `Arc` (a refcount bump),
    /// not the underlying `RustAuth`.
    auth: Arc<RustAuth>,
    /// Adapter options; passed by value into [`handle`] for each request.
    options: RustAuthActixWebOptions,
}
23
/// Extension methods that expose a RustAuth instance as an Actix Web [`Scope`].
///
/// Implemented in this module for both `RustAuth` (which is cloned into a new `Arc`)
/// and `Arc<RustAuth>` (which only clones the pointer).
pub trait RustAuthActixWebExt {
    /// Builds a scope with an empty prefix whose default service forwards every
    /// request to RustAuth. Any path prefix is left to the caller.
    ///
    /// # Errors
    /// Fails when a configured `base_url` cannot be parsed, or when its path does not
    /// match the configured `base_path`.
    fn mount_routes(
        &self,
        options: RustAuthActixWebOptions,
    ) -> Result<Scope, RustAuthActixWebError>;

    /// Like [`RustAuthActixWebExt::mount_routes`], but nests the routes under the
    /// normalized `base_path`. No extra scope is added when that path is `/`.
    ///
    /// # Errors
    /// Same as `mount_routes`, plus [`RustAuthActixWebError::InvalidBasePath`] when
    /// `base_path` is not a plain absolute path.
    fn mount_at_base_path(
        &self,
        options: RustAuthActixWebOptions,
    ) -> Result<Scope, RustAuthActixWebError>;
}
40
41impl RustAuthActixWebExt for RustAuth {
42 fn mount_routes(
43 &self,
44 options: RustAuthActixWebOptions,
45 ) -> Result<Scope, RustAuthActixWebError> {
46 routes_from_shared(Arc::new(self.clone()), options)
47 }
48
49 fn mount_at_base_path(
50 &self,
51 options: RustAuthActixWebOptions,
52 ) -> Result<Scope, RustAuthActixWebError> {
53 mount_scope_shared(Arc::new(self.clone()), options)
54 }
55}
56
57impl RustAuthActixWebExt for Arc<RustAuth> {
58 fn mount_routes(
59 &self,
60 options: RustAuthActixWebOptions,
61 ) -> Result<Scope, RustAuthActixWebError> {
62 routes_from_shared(Arc::clone(self), options)
63 }
64
65 fn mount_at_base_path(
66 &self,
67 options: RustAuthActixWebOptions,
68 ) -> Result<Scope, RustAuthActixWebError> {
69 mount_scope_shared(Arc::clone(self), options)
70 }
71}
72
73fn mount_scope_shared(
74 auth: Arc<RustAuth>,
75 options: RustAuthActixWebOptions,
76) -> Result<Scope, RustAuthActixWebError> {
77 validate_base_url_matches_base_path(auth.as_ref())?;
78 let base_path = normalize_base_path(&auth.context().base_path)?;
79 if base_path == "/" {
80 return routes_from_shared(auth, options);
81 }
82 Ok(web::scope(&base_path).service(routes_from_shared(auth, options)?))
83}
84
85pub fn validate_mount_config(auth: &RustAuth) -> Result<(), RustAuthActixWebError> {
90 validate_base_url_matches_base_path(auth)
91}
92
93fn routes_from_shared(
94 auth: Arc<RustAuth>,
95 options: RustAuthActixWebOptions,
96) -> Result<Scope, RustAuthActixWebError> {
97 validate_mount_config(auth.as_ref())?;
98 Ok(web::scope("")
99 .app_data(web::Data::new(RustAuthActixWebState { auth, options }))
100 .default_service(web::route().to(route_handler)))
101}
102
103pub async fn handle(
105 auth: &RustAuth,
106 options: RustAuthActixWebOptions,
107 request: HttpRequest,
108 payload: Payload,
109) -> HttpResponse {
110 match to_api_request(request, payload, options).await {
111 Ok(mut request) => {
112 maybe_insert_base_url(auth, &mut request, options);
113 match auth.handler_async(request).await {
114 Ok(response) => from_api_response(response),
115 Err(error) => {
116 log_internal_error(auth, &error);
117 internal_error_response()
118 }
119 }
120 }
121 Err(response) => response,
122 }
123}
124
125async fn route_handler(
126 state: web::Data<RustAuthActixWebState>,
127 request: HttpRequest,
128 payload: Payload,
129) -> HttpResponse {
130 handle(state.auth.as_ref(), state.options, request, payload).await
131}
132
133fn validate_base_url_matches_base_path(auth: &RustAuth) -> Result<(), RustAuthActixWebError> {
134 let base_url = auth.context().base_url.as_str();
135 if base_url.is_empty() {
136 return Ok(());
137 }
138
139 let parsed = url::Url::parse(base_url)
140 .map_err(|_| RustAuthActixWebError::InvalidBaseUrl(base_url.to_owned()))?;
141 let url_path = trim_path_suffix(parsed.path());
142 let base_path = trim_path_suffix(&auth.context().base_path);
143 if url_path == base_path {
144 return Ok(());
145 }
146
147 Err(RustAuthActixWebError::InconsistentBaseUrlPath {
148 url_path,
149 base_path,
150 })
151}
152
/// Removes every trailing `/` from `path`; a path that becomes empty is returned as `/`.
fn trim_path_suffix(path: &str) -> String {
    match path.trim_end_matches('/') {
        "" => String::from("/"),
        rest => rest.to_string(),
    }
}
161
162fn normalize_base_path(base_path: &str) -> Result<String, RustAuthActixWebError> {
163 if base_path.is_empty() {
164 return Ok("/".to_owned());
165 }
166 if !is_valid_base_path(base_path) {
167 return Err(RustAuthActixWebError::InvalidBasePath(base_path.to_owned()));
168 }
169
170 let trimmed = base_path.trim_end_matches('/');
171 if trimmed.is_empty() {
172 Ok("/".to_owned())
173 } else {
174 Ok(trimmed.to_owned())
175 }
176}
177
178fn maybe_insert_base_url(
179 auth: &RustAuth,
180 request: &mut rustauth::api::ApiRequest,
181 options: RustAuthActixWebOptions,
182) {
183 if !options.infer_base_url_from_request
184 || !auth.context().base_url.is_empty()
185 || request.extensions().get::<OAuthBaseUrlOverride>().is_some()
186 {
187 return;
188 }
189
190 if let Some(base_url) = infer_base_url(
191 request.headers(),
192 request.uri(),
193 &auth.context().base_path,
194 options.trust_proxy_headers_for_base_url,
195 ) {
196 request
197 .extensions_mut()
198 .insert(RequestBaseUrl(base_url.clone()));
199 request
200 .extensions_mut()
201 .insert(OAuthBaseUrlOverride(base_url));
202 }
203}
204
205fn infer_base_url(
206 headers: &HeaderMap,
207 uri: &Uri,
208 base_path: &str,
209 trust_proxy_headers: bool,
210) -> Option<String> {
211 let origin = if trust_proxy_headers {
212 forwarded_origin(headers)
213 } else {
214 None
215 }
216 .or_else(|| uri_origin(uri))
217 .or_else(|| host_header_origin(headers))?;
218 Some(with_base_path(origin, base_path))
219}
220
221fn forwarded_origin(headers: &HeaderMap) -> Option<String> {
222 let host = header_str(headers, "x-forwarded-host")?;
223 let proto = header_str(headers, "x-forwarded-proto")?;
224 if !is_valid_forwarded_host(host) || !is_valid_forwarded_proto(proto) {
225 return None;
226 }
227 Some(format!("{}://{}", proto.to_ascii_lowercase(), host))
228}
229
230fn uri_origin(uri: &Uri) -> Option<String> {
231 let scheme = uri.scheme_str()?;
232 if !is_valid_forwarded_proto(scheme) {
233 return None;
234 }
235 let authority = uri.authority()?.as_str();
236 if !is_valid_forwarded_host(authority) {
237 return None;
238 }
239 Some(format!("{}://{}", scheme, authority))
240}
241
242fn host_header_origin(headers: &HeaderMap) -> Option<String> {
243 let host = header_str(headers, header::HOST.as_str())?;
244 if !is_valid_forwarded_host(host) {
245 return None;
246 }
247 let scheme = if is_loopback_host(host) {
248 "http"
249 } else {
250 "https"
251 };
252 Some(format!("{scheme}://{host}"))
253}
254
255fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
256 headers.get(name)?.to_str().ok()
257}
258
/// Appends the mount `base_path` to a bare `scheme://host[:port]` origin.
///
/// Trailing slashes are trimmed, and an empty or root (`/`) path leaves the origin
/// unchanged. If the path lacks a leading `/`, one is inserted. `base_path` is not
/// validated on every call path (e.g. [`handle`] called directly), and without this
/// the path would be glued onto the host, e.g. `https://a.com` + `api` would become
/// `https://a.comapi`.
fn with_base_path(mut origin: String, base_path: &str) -> String {
    let path = base_path.trim_end_matches('/');
    if path.is_empty() {
        return origin;
    }
    if !path.starts_with('/') {
        origin.push('/');
    }
    origin.push_str(path);
    origin
}
266
/// Returns `true` when `base_path` is absolute (starts with `/`) and contains none of
/// the characters that would be read as query, fragment, or Actix route-pattern
/// syntax.
fn is_valid_base_path(base_path: &str) -> bool {
    const DISALLOWED: [char; 5] = ['?', '#', '{', '}', '*'];
    base_path.starts_with('/') && !base_path.contains(&DISALLOWED[..])
}
275
276fn log_internal_error(auth: &RustAuth, error: &RustAuthError) {
277 let message = error.to_string();
278 auth.context()
279 .logger
280 .error("RustAuth Actix Web handler failed", &[message.as_str()]);
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::HeaderValue;

    // Secret passed to `RustAuth::builder()` in the async tests below.
    const SECRET: &str = "test-secret-123456789012345678901234";

    #[test]
    fn normalize_base_path_trims_trailing_slashes_except_root() -> Result<(), RustAuthActixWebError>
    {
        assert_eq!(normalize_base_path("")?, "/");
        assert_eq!(normalize_base_path("/")?, "/");
        assert_eq!(normalize_base_path("/api/auth/")?, "/api/auth");
        assert_eq!(normalize_base_path("/api/auth///")?, "/api/auth");
        Ok(())
    }

    #[test]
    fn normalize_base_path_rejects_pattern_syntax_and_non_absolute_paths() {
        for base_path in [
            "api/auth",
            "/api/{auth}",
            "/api/*auth",
            "/api/auth?x=1",
            "/api/auth#x",
        ] {
            assert!(matches!(
                normalize_base_path(base_path),
                Err(RustAuthActixWebError::InvalidBasePath(_))
            ));
        }
    }

    #[test]
    fn trim_path_suffix_collapses_trailing_slashes_to_root() {
        assert_eq!(trim_path_suffix(""), "/");
        assert_eq!(trim_path_suffix("///"), "/");
        assert_eq!(trim_path_suffix("/api/auth//"), "/api/auth");
        assert_eq!(trim_path_suffix("/api/auth"), "/api/auth");
    }

    #[test]
    fn with_base_path_appends_trimmed_path_and_skips_root() {
        assert_eq!(
            with_base_path("https://app.example.com".to_owned(), "/api/auth/"),
            "https://app.example.com/api/auth"
        );
        assert_eq!(
            with_base_path("https://app.example.com".to_owned(), "/"),
            "https://app.example.com"
        );
        assert_eq!(
            with_base_path("https://app.example.com".to_owned(), ""),
            "https://app.example.com"
        );
    }

    #[test]
    fn infer_base_url_rejects_malicious_forwarded_headers_and_falls_back_to_host() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-host",
            HeaderValue::from_static("javascript:alert(1)"),
        );
        headers.insert("x-forwarded-proto", HeaderValue::from_static("http"));
        headers.insert(header::HOST, HeaderValue::from_static("app.example.com"));

        let base = infer_base_url(
            &headers,
            &Uri::from_static("/api/auth/ok"),
            "/api/auth",
            true,
        );
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_forwarded_headers_when_trusted_and_valid() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-host",
            HeaderValue::from_static("public.example.com"),
        );
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));
        headers.insert(header::HOST, HeaderValue::from_static("internal.local"));

        let base = infer_base_url(&headers, &Uri::from_static("/ok"), "/api/auth", true);
        assert_eq!(base.as_deref(), Some("https://public.example.com/api/auth"));
    }

    // Proxy headers must be ignored entirely unless the caller opts in to trusting them.
    #[test]
    fn infer_base_url_ignores_forwarded_headers_when_untrusted() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-host",
            HeaderValue::from_static("public.example.com"),
        );
        headers.insert("x-forwarded-proto", HeaderValue::from_static("https"));
        headers.insert(header::HOST, HeaderValue::from_static("app.example.com"));

        let base = infer_base_url(&headers, &Uri::from_static("/ok"), "/api/auth", false);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_absolute_request_uri_origin() {
        let headers = HeaderMap::new();
        let uri = Uri::from_static("https://app.example.com/api/auth/sign-in/social");
        let base = infer_base_url(&headers, &uri, "/api/auth", false);
        assert_eq!(base.as_deref(), Some("https://app.example.com/api/auth"));
    }

    #[test]
    fn infer_base_url_uses_http_for_loopback_host_header() {
        let mut headers = HeaderMap::new();
        headers.insert(header::HOST, HeaderValue::from_static("127.0.0.1:3000"));

        let base = infer_base_url(&headers, &Uri::from_static("/ok"), "/api/auth", false);
        assert_eq!(base.as_deref(), Some("http://127.0.0.1:3000/api/auth"));
    }

    #[tokio::test]
    async fn validate_base_url_accepts_matching_pathname() -> Result<(), RustAuthError> {
        let auth = RustAuth::builder()
            .secret(SECRET)
            .base_path("/api/auth")
            .base_url("http://localhost:3000/api/auth/")
            .build()
            .await?;
        assert!(validate_base_url_matches_base_path(&auth).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_mismatched_pathname() -> Result<(), RustAuthError> {
        let auth = RustAuth::builder()
            .secret(SECRET)
            .base_path("/api/auth")
            .base_url("http://localhost:3000/wrong")
            .build()
            .await?;
        assert!(matches!(
            validate_base_url_matches_base_path(&auth),
            Err(RustAuthActixWebError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn validate_base_url_rejects_invalid_absolute_url() -> Result<(), RustAuthError> {
        let auth = RustAuth::builder()
            .secret(SECRET)
            .base_path("/api/auth")
            .base_url("not-a-url")
            .build()
            .await?;
        assert!(matches!(
            validate_base_url_matches_base_path(&auth),
            Err(RustAuthActixWebError::InvalidBaseUrl(_))
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_rejects_mismatched_base_url_path() -> Result<(), RustAuthError> {
        let auth = Arc::new(
            RustAuth::builder()
                .secret(SECRET)
                .base_path("/api/auth")
                .base_url("http://localhost:3000/wrong")
                .build()
                .await?,
        );
        assert!(matches!(
            auth.mount_routes(RustAuthActixWebOptions::default()),
            Err(RustAuthActixWebError::InconsistentBaseUrlPath { .. })
        ));
        Ok(())
    }

    #[tokio::test]
    async fn mount_routes_keeps_shared_auth_available() -> Result<(), RustAuthError> {
        let auth = Arc::new(RustAuth::builder().secret(SECRET).build().await?);
        let routes = auth
            .mount_routes(RustAuthActixWebOptions::default())
            .map_err(|error| RustAuthError::Api(error.to_string()))?;
        drop(routes);

        // Dropping the scope must release the Arc clone it held in its app data.
        assert_eq!(Arc::strong_count(&auth), 1);
        assert!(validate_mount_config(auth.as_ref()).is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn actix_state_clones_only_the_shared_auth_pointer() -> Result<(), RustAuthError> {
        let auth = Arc::new(RustAuth::builder().secret(SECRET).build().await?);
        let state = RustAuthActixWebState {
            auth: Arc::clone(&auth),
            options: RustAuthActixWebOptions::default(),
        };

        let cloned = state.clone();

        // Three owners: `auth`, `state.auth`, and `cloned.auth`.
        assert_eq!(Arc::strong_count(&auth), 3);
        drop(cloned);
        assert_eq!(Arc::strong_count(&auth), 2);
        Ok(())
    }
}
453}