Skip to main content

victauri_plugin/
auth.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use url::Url;
8
/// Byte length of the `"Bearer "` scheme prefix (scheme name plus the single separating
/// space) that precedes the token inside an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = "Bearer ".len();
10
/// Compare two byte slices without short-circuiting on the first differing byte.
///
/// Slices of different length are rejected immediately, so only the *content* of equally
/// long inputs is protected from timing side channels; the length itself is observable.
/// For equal lengths every byte pair is XOR-ed and OR-ed into one accumulator, which is
/// zero only when all pairs matched.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut difference = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        difference |= lhs ^ rhs;
    }
    difference == 0
}
21
22/// Generate a random `UUID` v4 token suitable for Bearer authentication.
23#[must_use]
24pub fn generate_token() -> String {
25    uuid::Uuid::new_v4().to_string()
26}
27
/// Shared authentication state holding the optional Bearer token for the MCP server.
///
/// `Debug` is implemented by hand rather than derived so that the secret token can never
/// end up in logs or panic messages: the output only reveals whether a token is configured.
#[derive(Clone)]
pub struct AuthState {
    /// The expected Bearer token, or `None` if authentication is disabled.
    pub(crate) token: Option<String>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the token itself, only a placeholder describing its presence.
        let token = if self.token.is_some() {
            "<redacted>"
        } else {
            "<none>"
        };
        f.debug_struct("AuthState")
            .field("token", &format_args!("{token}"))
            .finish()
    }
}
34
35/// Axum middleware that validates the `Authorization: Bearer <token>` header against [`AuthState`].
36///
37/// # Errors
38///
39/// Returns [`StatusCode::UNAUTHORIZED`] if the token is missing or invalid.
40pub async fn require_auth(
41    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
42    request: Request,
43    next: Next,
44) -> Result<Response, StatusCode> {
45    let Some(expected) = &auth.token else {
46        return Ok(next.run(request).await);
47    };
48
49    let provided = request
50        .headers()
51        .get("authorization")
52        .and_then(|v| v.to_str().ok())
53        .and_then(|v| {
54            let lower = v.to_lowercase();
55            if lower.starts_with("bearer ") {
56                Some(v[BEARER_PREFIX_LEN..].to_string())
57            } else {
58                None
59            }
60        });
61
62    match provided {
63        Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
64            Ok(next.run(request).await)
65        }
66        _ => {
67            tracing::warn!("Victauri: rejected request — invalid or missing auth token");
68            Err(StatusCode::UNAUTHORIZED)
69        }
70    }
71}
72
73// ── Rate Limiter ───────────────────────────────────────────────────────────
74
/// Lock-free token-bucket rate limiter using millisecond-precision timestamps for smooth refill.
///
/// The bucket starts full, never holds more than `max_tokens`, and earns
/// `refill_rate_per_sec` tokens per second of elapsed time. All mutable state is kept in
/// atomics and every method takes `&self`.
pub struct RateLimiterState {
    /// Tokens currently available for [`RateLimiterState::try_acquire`].
    tokens: AtomicU64,
    /// Bucket capacity; a refill never raises `tokens` above this value.
    max_tokens: u64,
    /// Reading of [`now_ms`] taken by the most recent refill that credited tokens.
    last_refill_ms: AtomicU64,
    /// Tokens earned per second of elapsed time.
    refill_rate_per_sec: u64,
}

/// Milliseconds elapsed on the monotonic clock since this function was first called.
///
/// Only the difference between two readings is meaningful. A monotonic clock is used instead
/// of wall-clock time so that system clock adjustments cannot stall the refill (clock stepped
/// backwards) or hand out a burst of tokens (clock stepped forwards).
fn now_ms() -> u64 {
    static ORIGIN: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
    let elapsed = ORIGIN.get_or_init(std::time::Instant::now).elapsed();
    // `as_millis` yields u128; saturate instead of silently truncating.
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}

impl RateLimiterState {
    /// Create a rate limiter with the given maximum requests per second.
    ///
    /// The same value is used as bucket capacity and as refill rate, and the bucket starts
    /// full. A value of `0` yields a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Atomically consume one token, returning `true` if the request is allowed.
    ///
    /// Tokens earned since the last refill are credited first. Returns `false` when the
    /// bucket is empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Credit the tokens earned since `last_refill_ms`, capped at `max_tokens`.
    ///
    /// When less than one whole token has been earned the timestamp is left untouched so the
    /// elapsed time keeps accumulating. When tokens are credited the timestamp jumps to the
    /// current reading, so any fractional remainder of a token is dropped.
    fn refill(&self) {
        let now = now_ms();
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let elapsed_ms = now.saturating_sub(last);
        let add = Self::tokens_earned(elapsed_ms, self.refill_rate_per_sec);
        if add == 0 {
            return;
        }
        // Only the caller that wins this exchange credits tokens for the interval; a loser
        // saw a stale `last` and another caller has already accounted for that time.
        if self
            .last_refill_ms
            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            return;
        }
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                Some(available.saturating_add(add).min(self.max_tokens))
            });
    }

    /// Whole tokens earned over `elapsed_ms` at `rate_per_sec`, saturating at `u64::MAX`.
    ///
    /// The product is formed in `u128`, which cannot overflow for two `u64` factors.
    fn tokens_earned(elapsed_ms: u64, rate_per_sec: u64) -> u64 {
        let earned = u128::from(elapsed_ms) * u128::from(rate_per_sec) / 1000;
        u64::try_from(earned).unwrap_or(u64::MAX)
    }
}
150
151/// Axum middleware that rejects requests with 429 when the token bucket is exhausted.
152///
153/// # Errors
154///
155/// Returns [`StatusCode::TOO_MANY_REQUESTS`] if the token bucket has no remaining capacity.
156pub async fn rate_limit(
157    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
158    request: Request,
159    next: Next,
160) -> Result<Response, StatusCode> {
161    if limiter.try_acquire() {
162        Ok(next.run(request).await)
163    } else {
164        Err(StatusCode::TOO_MANY_REQUESTS)
165    }
166}
167
168const DEFAULT_RATE_LIMIT: u64 = 1000;
169
170/// Create a rate limiter with the default capacity of 1000 requests per second.
171#[must_use]
172pub fn default_rate_limiter() -> Arc<RateLimiterState> {
173    Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
174}
175
176// ── Security Middlewares ──────────────────────────────────────────────────
177
178/// Axum middleware that blocks DNS rebinding attacks.
179///
180/// Rejects any request where the Host header is not a localhost address.
181///
182/// # Errors
183///
184/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`, `127.0.0.1`, or `::1`.
185pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
186    let host = request
187        .headers()
188        .get("host")
189        .and_then(|v| v.to_str().ok())
190        .unwrap_or("");
191    let host_name = if host.starts_with('[') {
192        // Bracketed IPv6: [::1] or [::1]:7373
193        host.split(']').next().map_or(host, |s| &s[1..])
194    } else if host.contains("::") {
195        // Bare IPv6 (no brackets): ::1
196        host
197    } else {
198        // IPv4 or hostname, strip port: 127.0.0.1:7373 → 127.0.0.1
199        host.split(':').next().unwrap_or(host)
200    };
201    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1");
202    if !is_allowed {
203        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
204        return Err(StatusCode::FORBIDDEN);
205    }
206    Ok(next.run(request).await)
207}
208
209/// Axum middleware that blocks cross-origin requests from browsers.
210///
211/// # Errors
212///
213/// Returns [`StatusCode::FORBIDDEN`] if the `Origin` header is present and does not match a
214/// localhost or `tauri://` origin.
215pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
216    if let Some(origin) = request
217        .headers()
218        .get("origin")
219        .and_then(|v| v.to_str().ok())
220        && !is_allowed_origin(origin)
221    {
222        tracing::warn!("Cross-origin request blocked: Origin={origin}");
223        return Err(StatusCode::FORBIDDEN);
224    }
225    Ok(next.run(request).await)
226}
227
228fn is_allowed_origin(origin: &str) -> bool {
229    if origin.starts_with("tauri://") {
230        return true;
231    }
232    let Ok(parsed) = Url::parse(origin) else {
233        return false;
234    };
235    matches!(parsed.scheme(), "http" | "https")
236        && matches!(
237            parsed.host_str(),
238            Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
239        )
240}
241
242/// Axum middleware that sets security-hardening response headers on every response.
243pub async fn security_headers(request: Request, next: Next) -> Response {
244    let mut response = next.run(request).await;
245    let headers = response.headers_mut();
246    headers.insert(
247        axum::http::header::X_CONTENT_TYPE_OPTIONS,
248        axum::http::HeaderValue::from_static("nosniff"),
249    );
250    headers.insert(
251        axum::http::header::CACHE_CONTROL,
252        axum::http::HeaderValue::from_static("no-store"),
253    );
254    headers.insert(
255        axum::http::header::HeaderName::from_static("x-frame-options"),
256        axum::http::HeaderValue::from_static("DENY"),
257    );
258    headers.insert(
259        axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
260        axum::http::HeaderValue::from_static("null"),
261    );
262    headers.insert(
263        axum::http::header::HeaderName::from_static("content-security-policy"),
264        axum::http::HeaderValue::from_static("default-src 'none'"),
265    );
266    response
267}
268
#[cfg(test)]
mod tests {
    //! Unit and integration tests for the auth, rate-limit and security middlewares.

    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt; // for oneshot

    /// Trivial handler used behind every middleware under test.
    async fn ok_handler() -> &'static str {
        "ok"
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36); // UUID v4 format
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // Refill is disabled (rate 0) so exactly the initial 1000 tokens can ever be handed
        // out, however long the threads take to run. A limiter built with `new(1000)` earns
        // one token per millisecond, which would make the total depend on wall-clock timing.
        let limiter = Arc::new(RateLimiterState {
            tokens: AtomicU64::new(1000),
            max_tokens: 1000,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: 0,
        });
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // 2000 attempts compete for 1000 tokens: every token is dispensed exactly once.
        assert_eq!(total, 1000);
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Security Headers tests ────────────────────────────────────────────

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    // ── Auth middleware integration tests ─────────────────────────────────

    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let state = Arc::new(AuthState {
            token: Some("my-token".into()),
        });
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth));

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "BEARER my-token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Basic c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Rate limiter middleware integration test ──────────────────────────

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(
            app3.oneshot(req).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    // ── Combined security layer test ─────────────────────────────────────

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Good request: all guards pass
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        // Bad host: DNS rebinding guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Bad origin: origin guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Missing auth: auth middleware blocks
        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn origin_guard_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("https://localhost:443"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("https://127.0.0.1"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_guard_rejects_prefix_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://[::1].evil.com"));
    }

    #[test]
    fn origin_guard_rejects_userinfo_trick() {
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://127.0.0.1@evil.com"));
    }

    #[test]
    fn origin_guard_rejects_foreign_and_malformed() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("ftp://localhost"));
    }

    // ── Constant-time comparison tests ───────────────────────────────────

    #[test]
    fn constant_time_eq_equal_strings() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }

    #[test]
    fn constant_time_eq_different_strings() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }

    #[test]
    fn constant_time_eq_single_bit_difference() {
        // 'A' = 0x41, 'B' = 0x42 — the two bytes differ, so the comparison must fail
        assert!(!constant_time_eq(b"A", b"B"));
    }

    // ── Security headers: CORS + CSP tests ───────────────────────────────

    #[tokio::test]
    async fn security_headers_cors_deny() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.headers().get("access-control-allow-origin").unwrap(),
            "null"
        );
    }

    #[tokio::test]
    async fn security_headers_csp() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.headers().get("content-security-policy").unwrap(),
            "default-src 'none'"
        );
    }
}