Skip to main content

victauri_plugin/
auth.rs

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7use url::Url;
8
9const BEARER_PREFIX_LEN: usize = "Bearer ".len();
10
11/// Generate a random `UUID` v4 token suitable for Bearer authentication.
12#[must_use]
13pub fn generate_token() -> String {
14    uuid::Uuid::new_v4().to_string()
15}
16
/// Authentication settings shared by the MCP server's request pipeline.
///
/// Storing `None` turns Bearer authentication off entirely.
#[derive(Clone)]
pub struct AuthState {
    /// Token callers must present as `Authorization: Bearer <token>`; `None` disables the check.
    pub token: Option<String>,
}
23
24/// Axum middleware that validates the `Authorization: Bearer <token>` header against [`AuthState`].
25///
26/// # Errors
27///
28/// Returns [`StatusCode::UNAUTHORIZED`] if the token is missing or invalid.
29pub async fn require_auth(
30    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
31    request: Request,
32    next: Next,
33) -> Result<Response, StatusCode> {
34    let Some(expected) = &auth.token else {
35        return Ok(next.run(request).await);
36    };
37
38    let provided = request
39        .headers()
40        .get("authorization")
41        .and_then(|v| v.to_str().ok())
42        .and_then(|v| {
43            let lower = v.to_lowercase();
44            if lower.starts_with("bearer ") {
45                Some(v[BEARER_PREFIX_LEN..].to_string())
46            } else {
47                None
48            }
49        });
50
51    match provided {
52        Some(ref token) if token == expected => Ok(next.run(request).await),
53        _ => Err(StatusCode::UNAUTHORIZED),
54    }
55}
56
57// ── Rate Limiter ───────────────────────────────────────────────────────────
58
/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full with `max_tokens` tokens. It refills at `refill_rate_per_sec` tokens
/// per second and never exceeds `max_tokens`. Elapsed time comes from a monotonic clock
/// ([`std::time::Instant`]), so wall-clock adjustments cannot stall refills. Time is tracked in
/// nanoseconds, and the fractional remainder carries over between refills so that low rates
/// are not rounded down.
pub struct RateLimiterState {
    /// Tokens currently available; one is consumed per admitted request.
    tokens: AtomicU64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Offset from `epoch`, in nanoseconds, up to which elapsed time has been turned into tokens.
    last_refill_ns: AtomicU64,
    /// Tokens added per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Monotonic reference instant captured at construction.
    epoch: std::time::Instant,
}

/// Nanoseconds per second, as `u128` so refill arithmetic cannot overflow.
const NANOS_PER_SEC: u128 = 1_000_000_000;

impl RateLimiterState {
    /// Create a rate limiter with the given maximum requests per second.
    ///
    /// The bucket starts full with `max_requests_per_sec` tokens and refills at the same number
    /// of tokens per second. A value of `0` yields a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ns: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Atomically consume one token, returning `true` if the request is allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` on an empty bucket, which makes `fetch_update` fail.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| current.checked_sub(1))
            .is_ok()
    }

    /// Nanoseconds elapsed since `epoch`, saturating at `u64::MAX` (about 584 years).
    fn elapsed_ns(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Credit the bucket with the whole tokens earned since the last refill.
    fn refill(&self) {
        let now = self.elapsed_ns();
        let last = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed = now.saturating_sub(last);
        let rate = u128::from(self.refill_rate_per_sec);
        // (2^64 - 1)^2 < 2^128, so this product cannot overflow.
        let earned = u128::from(elapsed) * rate / NANOS_PER_SEC;
        if earned == 0 {
            // Also covers a zero rate, which keeps `rate` non-zero for the division below.
            return;
        }
        // Advance the refill mark only by the time the whole tokens account for (rounded up),
        // so the fractional remainder keeps accruing instead of being discarded.
        // `consumed <= elapsed` because `earned <= elapsed * rate / NANOS_PER_SEC`.
        let consumed = u64::try_from((earned * NANOS_PER_SEC).div_ceil(rate)).unwrap_or(elapsed);
        if self
            .last_refill_ns
            .compare_exchange(last, last + consumed, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another thread advanced the mark first and is crediting this interval.
            return;
        }
        let add = u64::try_from(earned).unwrap_or(u64::MAX);
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(add).min(self.max_tokens))
            });
    }
}
134
135/// Axum middleware that rejects requests with 429 when the token bucket is exhausted.
136///
137/// # Errors
138///
139/// Returns [`StatusCode::TOO_MANY_REQUESTS`] if the token bucket has no remaining capacity.
140pub async fn rate_limit(
141    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
142    request: Request,
143    next: Next,
144) -> Result<Response, StatusCode> {
145    if limiter.try_acquire() {
146        Ok(next.run(request).await)
147    } else {
148        Err(StatusCode::TOO_MANY_REQUESTS)
149    }
150}
151
152const DEFAULT_RATE_LIMIT: u64 = 1000;
153
154/// Create a rate limiter with the default capacity of 1000 requests per second.
155#[must_use]
156pub fn default_rate_limiter() -> Arc<RateLimiterState> {
157    Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
158}
159
160// ── Security Middlewares ──────────────────────────────────────────────────
161
162/// Axum middleware that blocks DNS rebinding attacks.
163///
164/// Rejects any request where the Host header is not a localhost address.
165///
166/// # Errors
167///
168/// Returns [`StatusCode::FORBIDDEN`] if the `Host` header is not `localhost`, `127.0.0.1`, or `::1`.
169pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
170    let host = request
171        .headers()
172        .get("host")
173        .and_then(|v| v.to_str().ok())
174        .unwrap_or("");
175    let host_name = if host.starts_with('[') {
176        host.split(']').next().map_or(host, |s| &s[1..])
177    } else {
178        host.split(':').next().unwrap_or(host)
179    };
180    let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
181    if !is_allowed {
182        tracing::warn!("DNS rebinding attempt blocked: Host={host}");
183        return Err(StatusCode::FORBIDDEN);
184    }
185    Ok(next.run(request).await)
186}
187
188/// Axum middleware that blocks cross-origin requests from browsers.
189///
190/// # Errors
191///
192/// Returns [`StatusCode::FORBIDDEN`] if the `Origin` header is present and does not match a
193/// localhost or `tauri://` origin.
194pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
195    if let Some(origin) = request
196        .headers()
197        .get("origin")
198        .and_then(|v| v.to_str().ok())
199        && !is_allowed_origin(origin)
200    {
201        tracing::warn!("Cross-origin request blocked: Origin={origin}");
202        return Err(StatusCode::FORBIDDEN);
203    }
204    Ok(next.run(request).await)
205}
206
207fn is_allowed_origin(origin: &str) -> bool {
208    if origin.starts_with("tauri://") {
209        return true;
210    }
211    let Ok(parsed) = Url::parse(origin) else {
212        return false;
213    };
214    matches!(parsed.scheme(), "http" | "https")
215        && matches!(
216            parsed.host_str(),
217            Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
218        )
219}
220
221/// Axum middleware that sets security-hardening response headers on every response.
222pub async fn security_headers(request: Request, next: Next) -> Response {
223    let mut response = next.run(request).await;
224    let headers = response.headers_mut();
225    headers.insert(
226        axum::http::header::X_CONTENT_TYPE_OPTIONS,
227        axum::http::HeaderValue::from_static("nosniff"),
228    );
229    headers.insert(
230        axum::http::header::CACHE_CONTROL,
231        axum::http::HeaderValue::from_static("no-store"),
232    );
233    headers.insert(
234        axum::http::header::HeaderName::from_static("x-frame-options"),
235        axum::http::HeaderValue::from_static("DENY"),
236    );
237    response
238}
239
#[cfg(test)]
mod tests {
    // Unit tests for token generation and the rate limiter, plus router-level tests that
    // drive each middleware through `tower::ServiceExt::oneshot`.
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt; // for oneshot

    async fn ok_handler() -> &'static str {
        "ok"
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36); // UUID v4 format
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // The refill rate equals the capacity (1000 tokens/s, i.e. about one token per
        // millisecond), so only a handful of tokens can be refilled while the threads run;
        // the assertion below tolerates up to 10 extra.
        let limiter = Arc::new(RateLimiterState::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // All 1000 tokens should be dispensed; a time-based refill may add a few
        assert!((1000..=1010).contains(&total));
    }

    #[test]
    fn default_rate_limiter_has_expected_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 1000);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    // ── DNS Rebinding Guard tests ─────────────────────────────────────────

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    /// Build a GET `/test` request, optionally carrying a `Host` header.
    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Origin Guard tests ────────────────────────────────────────────────

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    /// Build a GET `/test` request, optionally carrying an `Origin` header.
    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    // ── Security Headers tests ────────────────────────────────────────────

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }

    // ── Auth middleware integration tests ─────────────────────────────────

    /// Router guarded by `require_auth`; `None` disables authentication.
    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(String::from),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    /// Build a GET `/test` request, optionally with `Authorization: Bearer <token>`.
    fn auth_request(token: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(t) = token {
            builder = builder.header("authorization", format!("Bearer {t}"));
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_allows_correct_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app
            .oneshot(auth_request(Some("wrong-token")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_missing_token() {
        let app = auth_router(Some("secret-123"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_allows_any_when_disabled() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_case_insensitive_bearer_prefix() {
        let state = Arc::new(AuthState {
            token: Some("my-token".into()),
        });
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth));

        let req = Request::builder()
            .uri("/test")
            .header("authorization", "BEARER my-token")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_non_bearer_scheme() {
        let app = auth_router(Some("secret"));
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Basic c2VjcmV0")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    // ── Rate limiter middleware integration test ──────────────────────────

    #[tokio::test]
    async fn rate_limiter_returns_429_when_exhausted() {
        // Clones of the router share the same `Arc<RateLimiterState>`, so the three requests
        // below draw from one two-token bucket.
        let limiter = Arc::new(RateLimiterState::new(2));
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(limiter, rate_limit));

        let app2 = app.clone();
        let app3 = app2.clone();

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);

        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        assert_eq!(
            app3.oneshot(req).await.unwrap().status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    // ── Combined security layer test ─────────────────────────────────────

    #[tokio::test]
    async fn combined_layers_enforce_all_guards() {
        let auth_state = Arc::new(AuthState {
            token: Some("tok-123".into()),
        });
        let limiter = Arc::new(RateLimiterState::new(100));

        // Layers added later wrap earlier ones, so the DNS rebinding guard runs first and
        // the auth check runs last, closest to the handler.
        let app = Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
            .layer(middleware::from_fn_with_state(limiter, rate_limit))
            .layer(middleware::from_fn(security_headers))
            .layer(middleware::from_fn(origin_guard))
            .layer(middleware::from_fn(dns_rebinding_guard));

        // Good request: all guards pass
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "127.0.0.1:7373")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");

        // Bad host: DNS rebinding guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Bad origin: origin guard blocks
        let req = Request::builder()
            .uri("/test")
            .header("authorization", "Bearer tok-123")
            .header("host", "localhost")
            .header("origin", "https://evil.com")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        // Missing auth: auth middleware blocks
        let req = Request::builder()
            .uri("/test")
            .header("host", "localhost")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn origin_guard_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("https://localhost:443"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("https://127.0.0.1"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_guard_rejects_prefix_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://[::1].evil.com"));
    }

    #[test]
    fn origin_guard_rejects_userinfo_trick() {
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://127.0.0.1@evil.com"));
    }

    #[test]
    fn origin_guard_rejects_foreign_and_malformed() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("ftp://localhost"));
    }
}