1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Byte length of the `Bearer ` authentication-scheme prefix (including the
/// separating space) at the start of an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = "Bearer ".len();
9
10#[must_use]
12pub fn generate_token() -> String {
13 uuid::Uuid::new_v4().to_string()
14}
15
/// Shared configuration for the `require_auth` middleware.
///
/// `token: Some(..)` enables bearer-token authentication against that exact
/// value. `token: None` disables the check and every request is forwarded.
#[derive(Clone)]
pub struct AuthState {
    /// Expected bearer token, or `None` when authentication is disabled.
    pub token: Option<String>,
}

// Hand-written rather than derived so that logging or `dbg!`-ing the state
// can never leak the secret token.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
22
23pub async fn require_auth(
29 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
30 request: Request,
31 next: Next,
32) -> Result<Response, StatusCode> {
33 let Some(expected) = &auth.token else {
34 return Ok(next.run(request).await);
35 };
36
37 let provided = request
38 .headers()
39 .get("authorization")
40 .and_then(|v| v.to_str().ok())
41 .and_then(|v| {
42 let lower = v.to_lowercase();
43 if lower.starts_with("bearer ") {
44 Some(v[BEARER_PREFIX_LEN..].to_string())
45 } else {
46 None
47 }
48 });
49
50 match provided {
51 Some(ref token) if token == expected => Ok(next.run(request).await),
52 _ => Err(StatusCode::UNAUTHORIZED),
53 }
54}
55
/// Lock-free token-bucket rate limiter shared by all requests (see the
/// `rate_limit` middleware).
///
/// The bucket starts full, never holds more than `max_tokens`, and is
/// credited with `refill_rate_per_sec` tokens per second of elapsed time.
pub struct RateLimiterState {
    /// Tokens currently available. Each admitted request consumes one.
    tokens: AtomicU64,
    /// Bucket capacity. Refills never push `tokens` above this value.
    max_tokens: u64,
    /// Nanoseconds after `epoch` up to which elapsed time has already been
    /// converted into tokens. Time that did not yet add up to a whole token
    /// stays unconsumed, so slow or uneven rates do not lose fractions.
    last_refill_ns: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Monotonic reference point. `Instant` is used instead of the wall clock
    /// so a backwards system-clock adjustment cannot stall refills.
    epoch: std::time::Instant,
}

/// Nanoseconds per second, as u128 for overflow-free rate arithmetic.
const NANOS_PER_SEC: u128 = 1_000_000_000;

impl RateLimiterState {
    /// Creates a limiter that allows a burst of `max_requests_per_sec`
    /// requests and refills at `max_requests_per_sec` tokens per second.
    ///
    /// A limit of `0` yields a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ns: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Tries to take one token, returning `true` if the request may proceed.
    ///
    /// Never blocks. Concurrent callers race on atomics only.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update`
        // report failure without modifying the counter.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_sub(1)
            })
            .is_ok()
    }

    /// Nanoseconds elapsed since `epoch` (saturates after roughly 584 years).
    fn now_ns(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Credits whole tokens for the time elapsed since the last refill.
    ///
    /// Relaxed ordering suffices: the two atomics are only ever used as
    /// counters and guard no other memory.
    fn refill(&self) {
        let rate = u128::from(self.refill_rate_per_sec);
        if rate == 0 {
            return;
        }
        let now = self.now_ns();
        let last = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed = u128::from(now.saturating_sub(last));
        // u128 cannot overflow here: both factors fit in u64.
        let add = elapsed * rate / NANOS_PER_SEC;
        if add == 0 {
            return;
        }
        // Advance the clock only by the time the credited tokens cost. Round
        // up so we never over-credit; the sub-token remainder carries over.
        // `spent <= elapsed` because `add * NANOS_PER_SEC <= elapsed * rate`.
        let cost = add * NANOS_PER_SEC;
        let spent = cost / rate + u128::from(cost % rate != 0);
        let next = last.saturating_add(u64::try_from(spent).unwrap_or(u64::MAX));
        if self
            .last_refill_ns
            .compare_exchange(last, next, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another thread already credited this interval.
            return;
        }
        let credit = u64::try_from(add).unwrap_or(u64::MAX);
        let max = self.max_tokens;
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(credit).min(max))
            });
    }
}
133
134pub async fn rate_limit(
140 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
141 request: Request,
142 next: Next,
143) -> Result<Response, StatusCode> {
144 if limiter.try_acquire() {
145 Ok(next.run(request).await)
146 } else {
147 Err(StatusCode::TOO_MANY_REQUESTS)
148 }
149}
150
151const DEFAULT_RATE_LIMIT: u64 = 1000;
152
153#[must_use]
155pub fn default_rate_limiter() -> Arc<RateLimiterState> {
156 Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
157}
158
159pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
169 let host = request
170 .headers()
171 .get("host")
172 .and_then(|v| v.to_str().ok())
173 .unwrap_or("");
174 let host_name = if host.starts_with('[') {
175 host.split(']').next().map_or(host, |s| &s[1..])
176 } else {
177 host.split(':').next().unwrap_or(host)
178 };
179 let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
180 if !is_allowed {
181 tracing::warn!("DNS rebinding attempt blocked: Host={host}");
182 return Err(StatusCode::FORBIDDEN);
183 }
184 Ok(next.run(request).await)
185}
186
187pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
194 if let Some(origin) = request
195 .headers()
196 .get("origin")
197 .and_then(|v| v.to_str().ok())
198 {
199 let allowed = origin.starts_with("http://localhost")
200 || origin.starts_with("https://localhost")
201 || origin.starts_with("http://127.0.0.1")
202 || origin.starts_with("https://127.0.0.1")
203 || origin.starts_with("http://[::1]")
204 || origin.starts_with("https://[::1]")
205 || origin.starts_with("tauri://");
206 if !allowed {
207 tracing::warn!("Cross-origin request blocked: Origin={origin}");
208 return Err(StatusCode::FORBIDDEN);
209 }
210 }
211 Ok(next.run(request).await)
212}
213
214pub async fn security_headers(request: Request, next: Next) -> Response {
216 let mut response = next.run(request).await;
217 let headers = response.headers_mut();
218 headers.insert(
219 axum::http::header::X_CONTENT_TYPE_OPTIONS,
220 axum::http::HeaderValue::from_static("nosniff"),
221 );
222 headers.insert(
223 axum::http::header::CACHE_CONTROL,
224 axum::http::HeaderValue::from_static("no-store"),
225 );
226 headers.insert(
227 axum::http::header::HeaderName::from_static("x-frame-options"),
228 axum::http::HeaderValue::from_static("DENY"),
229 );
230 response
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // The bucket refills at 1000 tokens/s, i.e. at most about one extra
        // token per elapsed millisecond. Bound the overshoot by measured time
        // rather than a fixed slack that a slow CI machine can exceed.
        let started = std::time::Instant::now();
        let limiter = Arc::new(RateLimiterState::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = Arc::clone(&limiter);
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0u64;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        let elapsed_ms = u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX);
        assert!(total >= 1000, "initial burst not fully granted: {total}");
        assert!(
            total <= 1000u64.saturating_add(elapsed_ms).saturating_add(2),
            "granted {total} tokens in {elapsed_ms} ms"
        );
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_token_prefix() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-12"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_empty_bearer_credentials() {
        let app = auth_router(Some("secret-123"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer ")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let state = Arc::new(AuthState {
            token: Some("my-token".into()),
        });
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth));

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "BEARER my-token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Basic c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(
            app3.oneshot(req).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}