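//! `victauri_plugin/auth.rs` — request-hardening middleware for the Victauri MCP server:
//! Bearer-token authentication, token-bucket rate limiting, DNS-rebinding and origin guards,
//! and security response headers.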

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use url::Url;
8
/// Length in bytes of the `Bearer ` authorization-scheme prefix, trailing space included.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
10
/// Compares two byte slices without stopping at the first differing byte.
///
/// When the lengths match, every byte pair is XOR-ed into a single accumulator. The loop
/// therefore does not branch on where a guessed token first diverges from the real one, which
/// guards token validation against timing side channels. A length mismatch returns `false`
/// immediately, so only the length itself is observable through timing.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff: u8 = 0;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
21
22/// Generate a random `UUID` v4 token suitable for Bearer authentication.
23#[must_use]
24pub fn generate_token() -> String {
25    uuid::Uuid::new_v4().to_string()
26}
27
/// Authentication settings shared by every request passing through `require_auth`.
#[derive(Clone)]
pub struct AuthState {
    /// Bearer credential that clients must present; `None` turns authentication off entirely.
    pub(crate) token: Option<String>,
}
34
35/// Axum middleware that validates the `Authorization: Bearer <token>` header against [`AuthState`].
36///
37/// # Errors
38///
39/// Returns [`StatusCode::UNAUTHORIZED`] if the token is missing or invalid.
40pub async fn require_auth(
41    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
42    request: Request,
43    next: Next,
44) -> Result<Response, StatusCode> {
45    let Some(expected) = &auth.token else {
46        return Ok(next.run(request).await);
47    };
48
49    let provided = request
50        .headers()
51        .get("authorization")
52        .and_then(|v| v.to_str().ok())
53        .and_then(|v| {
54            let lower = v.to_lowercase();
55            if lower.starts_with("bearer ") {
56                Some(v[BEARER_PREFIX_LEN..].to_string())
57            } else {
58                None
59            }
60        });
61
62    match provided {
63        Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
64            Ok(next.run(request).await)
65        }
66        _ => {
67            tracing::warn!("Victauri: rejected request — invalid or missing auth token");
68            Err(StatusCode::UNAUTHORIZED)
69        }
70    }
71}
72
// ── Rate Limiter ───────────────────────────────────────────────────────────
74
/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full, holds at most `max_tokens`, and refills continuously at
/// `refill_rate_per_sec`. Elapsed time is measured in microseconds on a monotonic clock. Only
/// the time actually converted into whole tokens is consumed from the refill budget, so the
/// configured rate holds even when it does not divide evenly into the interval between checks.
pub struct RateLimiterState {
    /// Tokens currently available; each allowed request consumes one.
    tokens: AtomicU64,
    /// Bucket capacity (maximum burst).
    max_tokens: u64,
    /// Monotonic timestamp (µs, from `now_us`) up to which elapsed time has been credited.
    last_refill_us: AtomicU64,
    /// Tokens added per second of elapsed time.
    refill_rate_per_sec: u64,
}

/// Microseconds per second, the scale of `last_refill_us` timestamps.
const MICROS_PER_SEC: u64 = 1_000_000;

/// Microseconds elapsed on a monotonic clock since the first call to this function.
///
/// Uses `Instant` instead of wall-clock `SystemTime`. A backwards wall-clock adjustment would
/// otherwise put `now` behind the stored timestamp and stall refills until the clock caught up.
fn now_us() -> u64 {
    static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let epoch = *EPOCH.get_or_init(std::time::Instant::now);
    // Saturate instead of truncating the u128; u64 microseconds span ~584,000 years.
    u64::try_from(epoch.elapsed().as_micros()).unwrap_or(u64::MAX)
}

impl RateLimiterState {
    /// Create a rate limiter that allows `max_requests_per_sec` requests per second, with a
    /// burst capacity of the same size. The bucket starts full.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_us: AtomicU64::new(now_us()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Atomically consume one token, returning `true` if the request is allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update` give up → denied.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Credit tokens for the time elapsed since the last refill.
    fn refill(&self) {
        self.refill_at(now_us());
    }

    /// Credit tokens for the time between the stored timestamp and `now` (µs, `now_us` scale).
    ///
    /// Relaxed ordering is sufficient: the atomics are independent counters and nothing else
    /// is published through them. Readers may briefly see the timestamp advanced before the
    /// tokens land, which only makes the limiter momentarily stricter.
    fn refill_at(&self, now: u64) {
        let last = self.last_refill_us.load(Ordering::Relaxed);
        let elapsed = now.saturating_sub(last);
        let add = elapsed.saturating_mul(self.refill_rate_per_sec) / MICROS_PER_SEC;
        if add == 0 {
            // Less than one whole token accrued (or the rate is 0): keep accumulating.
            return;
        }
        // Consume only the time that `add` whole tokens account for, so the fractional
        // remainder carries over to the next refill instead of being dropped.
        // Rounding up always advances the timestamp. Because add * 1e6 <= elapsed * rate, the
        // result never exceeds `elapsed`; `min` only matters if the multiplication saturated.
        let credited = add
            .saturating_mul(MICROS_PER_SEC)
            .div_ceil(self.refill_rate_per_sec)
            .min(elapsed);
        // Only the thread that wins this CAS adds tokens, so each interval is credited once.
        // `last + credited <= now`, so the addition cannot overflow.
        if self
            .last_refill_us
            .compare_exchange(last, last + credited, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            return;
        }
        let max = self.max_tokens;
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(add).min(max))
            });
    }
}
150
151/// Axum middleware that rejects requests with 429 when the token bucket is exhausted.
152///
153/// # Errors
154///
155/// Returns [`StatusCode::TOO_MANY_REQUESTS`] if the token bucket has no remaining capacity.
156pub async fn rate_limit(
157    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
158    request: Request,
159    next: Next,
160) -> Result<
161    Response,
162    (
163        StatusCode,
164        [(axum::http::HeaderName, axum::http::HeaderValue); 1],
165    ),
166> {
167    if limiter.try_acquire() {
168        Ok(next.run(request).await)
169    } else {
170        Err((
171            StatusCode::TOO_MANY_REQUESTS,
172            [(
173                axum::http::header::RETRY_AFTER,
174                axum::http::HeaderValue::from_static("1"),
175            )],
176        ))
177    }
178}
179
180const DEFAULT_RATE_LIMIT: u64 = 1000;
181
182/// Create a rate limiter with the default capacity of 1000 requests per second.
183#[must_use]
184pub fn default_rate_limiter() -> Arc<RateLimiterState> {
185    Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
186}
187
// ── Security Middlewares ──────────────────────────────────────────────────
189
190/// Axum middleware that blocks DNS rebinding attacks.
191///
192/// Rejects any request where the Host header is not a localhost address.
193///
194/// # Errors
195///
196/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`, `127.0.0.1`, or `::1`.
197pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
198    let host = request
199        .headers()
200        .get("host")
201        .and_then(|v| v.to_str().ok())
202        .unwrap_or("");
203    let host_name = if host.starts_with('[') {
204        // Bracketed IPv6: [::1] or [::1]:7373
205        host.split(']').next().map_or(host, |s| &s[1..])
206    } else if host.contains("::") {
207        // Bare IPv6 (no brackets): ::1
208        host
209    } else {
210        // IPv4 or hostname, strip port: 127.0.0.1:7373 → 127.0.0.1
211        host.split(':').next().unwrap_or(host)
212    };
213    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1");
214    if !is_allowed {
215        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
216        return Err(StatusCode::FORBIDDEN);
217    }
218    Ok(next.run(request).await)
219}
220
221/// Axum middleware that blocks cross-origin requests from browsers.
222///
223/// # Errors
224///
225/// Returns [`StatusCode::FORBIDDEN`] if the `Origin` header is present and does not match a
226/// localhost or `tauri://` origin.
227pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
228    if let Some(origin) = request
229        .headers()
230        .get("origin")
231        .and_then(|v| v.to_str().ok())
232        && !is_allowed_origin(origin)
233    {
234        tracing::warn!("Cross-origin request blocked: Origin={origin}");
235        return Err(StatusCode::FORBIDDEN);
236    }
237    Ok(next.run(request).await)
238}
239
240fn is_allowed_origin(origin: &str) -> bool {
241    if origin.starts_with("tauri://") {
242        return true;
243    }
244    let Ok(parsed) = Url::parse(origin) else {
245        return false;
246    };
247    matches!(parsed.scheme(), "http" | "https")
248        && matches!(
249            parsed.host_str(),
250            Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
251        )
252}
253
254/// Axum middleware that sets security-hardening response headers on every response.
255pub async fn security_headers(request: Request, next: Next) -> Response {
256    let mut response = next.run(request).await;
257    let headers = response.headers_mut();
258    headers.insert(
259        axum::http::header::X_CONTENT_TYPE_OPTIONS,
260        axum::http::HeaderValue::from_static("nosniff"),
261    );
262    headers.insert(
263        axum::http::header::CACHE_CONTROL,
264        axum::http::HeaderValue::from_static("no-store"),
265    );
266    headers.insert(
267        axum::http::header::HeaderName::from_static("x-frame-options"),
268        axum::http::HeaderValue::from_static("DENY"),
269    );
270    headers.insert(
271        axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
272        axum::http::HeaderValue::from_static("null"),
273    );
274    headers.insert(
275        axum::http::header::HeaderName::from_static("content-security-policy"),
276        axum::http::HeaderValue::from_static("default-src 'none'"),
277    );
278    response
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt; // for oneshot

    async fn ok_handler() -> &'static str {
        "ok"
    }

    // ── Token generation tests ────────────────────────────────────────────

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36); // hyphenated UUID: 32 hex digits + 4 hyphens
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    // ── Rate limiter unit tests ───────────────────────────────────────────

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // A 1000-token bucket refilling at 1000 tokens/sec. The refill accrued during this
        // short test is small compared to the initial budget, so the total stays near 1000.
        let limiter = Arc::new(RateLimiterState::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // All 1000 initial tokens are dispensed. Time-based refills (1000/sec) add tokens in
        // proportion to wall-clock duration, which varies by machine.
        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost_with_port() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost:7373")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_non_loopback_ipv6() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::2]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Security Headers tests ────────────────────────────────────────────

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    // ── Auth middleware integration tests ─────────────────────────────────

    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    fn raw_auth_request(header_value: &str) -> Request<Body> {
        Request::builder()
            .uri("/test")
            .header("authorization", header_value)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let app = auth_router(Some("my-token"));
        let resp = app
            .oneshot(raw_auth_request("BEARER my-token"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_lowercase_bearer_prefix() {
        let app = auth_router(Some("my-token"));
        let resp = app
            .oneshot(raw_auth_request("bearer my-token"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_scheme_without_credential() {
        let app = auth_router(Some("secret"));
        let resp = app.oneshot(raw_auth_request("Bearer")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(raw_auth_request("Basic c2VjcmV0"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Rate limiter middleware integration test ──────────────────────────

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app3.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(resp.headers().get("retry-after").unwrap(), "1");
    }

    // ── Combined security layer test ─────────────────────────────────────

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        // Each `.layer` wraps everything added before it, so requests pass through the guards
        // outermost-first: DNS rebinding → origin → security headers → rate limit → auth.
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Good request: all guards pass
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        // Bad host: DNS rebinding guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Bad origin: origin guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Missing auth: auth middleware blocks
        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Origin allow-list unit tests ─────────────────────────────────────

    #[test]
    fn origin_guard_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("https://localhost:443"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("https://127.0.0.1"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_guard_rejects_prefix_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://[::1].evil.com"));
    }

    #[test]
    fn origin_guard_rejects_userinfo_trick() {
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://127.0.0.1@evil.com"));
    }

    #[test]
    fn origin_guard_rejects_foreign_and_malformed() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("ftp://localhost"));
    }

    // ── Constant-time comparison tests ───────────────────────────────────

    #[test]
    fn constant_time_eq_equal_strings() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }

    #[test]
    fn constant_time_eq_different_strings() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }

    #[test]
    fn constant_time_eq_single_bit_difference() {
        // 'A' = 0x41, 'B' = 0x42 — the bytes are not equal
        assert!(!constant_time_eq(b"A", b"B"));
    }

    // ── Security headers: CORS + CSP tests ───────────────────────────────

    #[tokio::test]
    async fn security_headers_cors_deny() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            "null"
        );
    }

    #[tokio::test]
    async fn security_headers_csp() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.headers().get("content-security-policy").unwrap(),
            "default-src 'none'"
        );
    }
}