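//! `victauri_browser/auth.rs` — HTTP middleware: Bearer-token authentication,
//! token-bucket rate limiting, response security headers, and a localhost
//! `Origin` guard.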

1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Byte length of the `"Bearer "` scheme prefix (scheme name plus one space)
/// at the start of an `Authorization` header value.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
9
/// Compare two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths are rejected up front, which reveals only the
/// length. For equal lengths, every byte pair is XOR-ed and the results are
/// OR-accumulated, so the loop always walks the whole input.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (x, y) in a.iter().zip(b) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
19
20/// Generate a random UUID v4 token for Bearer authentication.
21#[must_use]
22pub fn generate_token() -> String {
23    uuid::Uuid::new_v4().to_string()
24}
25
/// Shared state for [`require_auth`].
///
/// When `token` is `None`, authentication is disabled and requests are
/// forwarded unchanged. When it is `Some`, requests must present the same
/// value as a Bearer token.
#[derive(Clone)]
pub struct AuthState {
    /// Expected Bearer token, or `None` to disable authentication.
    pub token: Option<String>,
}

// Written by hand instead of derived so that the secret token never ends up
// in logs or panic messages that format this state with `{:?}`.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
30
31/// Axum middleware that validates Bearer token authentication.
32///
33/// # Errors
34///
35/// Returns `401 Unauthorized` if the token is missing or invalid.
36pub async fn require_auth(
37    axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
38    request: Request,
39    next: Next,
40) -> Result<Response, StatusCode> {
41    let Some(expected) = &auth.token else {
42        return Ok(next.run(request).await);
43    };
44
45    let provided = request
46        .headers()
47        .get("authorization")
48        .and_then(|v| v.to_str().ok())
49        .and_then(|v| {
50            let lower = v.to_lowercase();
51            if lower.starts_with("bearer ") {
52                Some(v[BEARER_PREFIX_LEN..].to_string())
53            } else {
54                None
55            }
56        });
57
58    match provided {
59        Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
60            Ok(next.run(request).await)
61        }
62        _ => {
63            tracing::warn!("victauri-browser: rejected request — invalid or missing auth token");
64            Err(StatusCode::UNAUTHORIZED)
65        }
66    }
67}
68
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. Saturates at
/// `u64::MAX` instead of silently truncating the `u128` millisecond count.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}

/// Token-bucket rate limiter built on atomics (no mutex).
///
/// The bucket holds at most `max_tokens` tokens. It is refilled lazily, on
/// each [`RateLimiterState::try_acquire`] call, at `refill_rate_per_sec`
/// tokens per second of elapsed wall-clock time. Each allowed request
/// consumes one token.
#[derive(Debug)]
pub struct RateLimiterState {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Wall-clock time (ms since the Unix epoch) up to which elapsed time has
    /// already been converted into tokens.
    last_refill_ms: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
}

impl RateLimiterState {
    /// Create a limiter that starts full and allows `max_requests_per_sec`
    /// requests per second, with bursts of up to the same number.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Try to consume one token. Returns `true` if allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // Atomically decrement unless the bucket is already empty.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Convert the time elapsed since `last_refill_ms` into tokens.
    ///
    /// Only the caller that wins the compare-exchange on `last_refill_ms`
    /// credits the tokens for a given interval, so concurrent callers do not
    /// count the same interval twice.
    fn refill(&self) {
        // Read the anchor before the clock. Anchors published by other threads
        // come from earlier clock readings, so `now < last` normally means
        // the wall clock itself moved backwards.
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let now = now_ms();

        if now < last {
            // Re-anchor to the new clock. Otherwise the elapsed time would
            // stay at zero, and no tokens would be added, until the clock
            // caught up again.
            let _ = self.last_refill_ms.compare_exchange(
                last,
                now,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            return;
        }

        let elapsed_ms = now - last;
        // Intervals shorter than 10 ms are left to accumulate.
        if elapsed_ms < 10 {
            return;
        }
        let new_tokens = elapsed_ms.saturating_mul(self.refill_rate_per_sec) / 1000;
        if new_tokens == 0 {
            // Not enough time for a whole token yet; keep accumulating.
            return;
        }

        // Advance the anchor only by the time that was converted into whole
        // tokens. The fractional remainder then carries over to the next
        // refill instead of being discarded. `new_tokens > 0` implies a
        // non-zero rate, so the division is safe, and `credited_ms` never
        // exceeds `elapsed_ms`.
        let credited_ms = new_tokens.saturating_mul(1000) / self.refill_rate_per_sec;
        if self
            .last_refill_ms
            .compare_exchange(
                last,
                last + credited_ms,
                Ordering::Relaxed,
                Ordering::Relaxed,
            )
            .is_ok()
        {
            // Use `fetch_update` rather than load + store so that a concurrent
            // `try_acquire` decrement cannot be overwritten and lost.
            let _ = self
                .tokens
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                    Some(current.saturating_add(new_tokens).min(self.max_tokens))
                });
        }
    }
}
134
135/// Default rate limiter: 1000 requests per second.
136#[must_use]
137pub fn default_rate_limiter() -> Arc<RateLimiterState> {
138    Arc::new(RateLimiterState::new(1000))
139}
140
141/// Axum middleware for rate limiting.
142///
143/// # Errors
144///
145/// Returns `429 Too Many Requests` when the rate limit is exceeded.
146pub async fn rate_limit(
147    axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
148    request: Request,
149    next: Next,
150) -> Result<Response, StatusCode> {
151    if limiter.try_acquire() {
152        Ok(next.run(request).await)
153    } else {
154        Err(StatusCode::TOO_MANY_REQUESTS)
155    }
156}
157
158/// Security headers middleware: X-Content-Type-Options, Cache-Control.
159///
160/// # Panics
161/// Panics if header values cannot be parsed (hardcoded valid values).
162pub async fn security_headers(request: Request, next: Next) -> Response {
163    let mut response = next.run(request).await;
164    let headers = response.headers_mut();
165    headers.insert("x-content-type-options", "nosniff".parse().unwrap());
166    headers.insert("cache-control", "no-store".parse().unwrap());
167    response
168}
169
170/// Localhost origin guard: rejects requests with non-localhost Origin header.
171///
172/// Parses the origin as a URL and checks the host component directly,
173/// preventing bypass via subdomains like "localhost.evil.com".
174///
175/// # Errors
176/// Returns `403 Forbidden` if the Origin header contains a non-localhost host.
177pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
178    if let Some(origin) = request
179        .headers()
180        .get("origin")
181        .and_then(|v| v.to_str().ok())
182    {
183        let is_local = is_localhost_origin(origin);
184        if !is_local {
185            tracing::warn!("rejected non-local origin: {origin}");
186            return Err(StatusCode::FORBIDDEN);
187        }
188    }
189    Ok(next.run(request).await)
190}
191
/// Return `true` if `origin` (an `Origin` header value such as
/// `http://localhost:3000`) names a loopback host.
///
/// The accepted hosts are `localhost`, `127.0.0.1` and `[::1]`, compared
/// ASCII case-insensitively. The scheme is optional and ignored, and any
/// path is ignored. The host must be followed by nothing, by `/...`, or by a
/// numeric `:port`. Anything else, such as `[::1]evil.com` or
/// `localhost:evil`, is rejected.
fn is_localhost_origin(origin: &str) -> bool {
    const LOOPBACK_HOSTS: [&str; 3] = ["localhost", "127.0.0.1", "[::1]"];

    // "scheme://authority/path" -> "authority". A value without a scheme is
    // treated as a bare authority.
    let after_scheme = origin.split_once("://").map_or(origin, |(_, rest)| rest);
    let authority = after_scheme
        .split_once('/')
        .map_or(after_scheme, |(head, _)| head);

    // Split the authority into the host and whatever follows it. A bracketed
    // IPv6 literal keeps its brackets and may itself contain ':'.
    let (host, rest) = if authority.starts_with('[') {
        match authority.find(']') {
            Some(end) => authority.split_at(end + 1),
            None => return false,
        }
    } else {
        authority.split_at(authority.find(':').unwrap_or(authority.len()))
    };

    // Only an optional numeric port may follow the host. Any other trailing
    // text means the extracted host was not the real host.
    let port_ok = match rest.strip_prefix(':') {
        Some(port) => port.bytes().all(|b| b.is_ascii_digit()),
        None => rest.is_empty(),
    };

    port_ok && LOOPBACK_HOSTS.iter().any(|h| host.eq_ignore_ascii_case(h))
}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::thread;

    // --- Token generation ---

    #[test]
    fn token_generation() {
        let first = generate_token();
        let second = generate_token();
        assert_eq!(first.len(), 36);
        assert_ne!(first, second);
    }

    #[test]
    fn token_format_is_uuid() {
        let token = generate_token();
        let hyphens = token.matches('-').count();
        assert_eq!(token.len(), 36);
        assert_eq!(hyphens, 4);
    }

    #[test]
    fn token_uniqueness_over_1000_generations() {
        let mut seen = HashSet::with_capacity(1000);
        for _ in 0..1000 {
            assert!(seen.insert(generate_token()), "duplicate token generated");
        }
    }

    // --- Constant-time comparison ---

    #[test]
    fn constant_time_eq_works() {
        let reference: &[u8] = b"hello";
        assert!(constant_time_eq(reference, b"hello"));
        assert!(!constant_time_eq(reference, b"world"));
        assert!(!constant_time_eq(reference, b"hell"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        let empty: &[u8] = b"";
        assert!(constant_time_eq(empty, empty));
        assert!(!constant_time_eq(empty, b"x"));
    }

    #[test]
    fn constant_time_eq_single_bit_diff() {
        let pairs: [(&[u8], &[u8]); 2] = [(b"\x00", b"\x01"), (b"\xff", b"\xfe")];
        for (left, right) in pairs {
            assert!(!constant_time_eq(left, right));
        }
    }

    #[test]
    fn constant_time_eq_long_strings() {
        let base = "a".repeat(10_000);
        let same = "a".repeat(10_000);
        assert!(constant_time_eq(base.as_bytes(), same.as_bytes()));

        let longer = format!("{base}b");
        assert!(!constant_time_eq(base.as_bytes(), longer.as_bytes()));
    }

    #[test]
    fn constant_time_eq_timing_consistency() {
        let token = "8f14e45f-ceea-367f-a27f-c790e5a0fdc4";
        // Mismatch at the start and at the end must both be rejected.
        let mismatches = [
            "0000000f-ceea-367f-a27f-c790e5a0fdc4",
            "8f14e45f-ceea-367f-a27f-c790e5a0fd00",
        ];
        for wrong in mismatches {
            assert!(!constant_time_eq(token.as_bytes(), wrong.as_bytes()));
        }
    }

    #[test]
    fn constant_time_eq_all_byte_values() {
        for byte in 0..=255u8 {
            assert!(constant_time_eq(&[byte], &[byte]));
            if let Some(next) = byte.checked_add(1) {
                assert!(!constant_time_eq(&[byte], &[next]));
            }
        }
    }

    // --- Rate limiter ---

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        assert!((0..10).all(|_| limiter.try_acquire()));
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_single_token() {
        let limiter = RateLimiterState::new(1);
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn default_rate_limiter_has_budget() {
        assert!(default_rate_limiter().try_acquire());
    }

    #[test]
    fn rate_limiter_zero_budget() {
        assert!(!RateLimiterState::new(0).try_acquire());
    }

    #[test]
    fn rate_limiter_exact_boundary() {
        let limiter = RateLimiterState::new(100);
        for i in 0..100 {
            assert!(limiter.try_acquire(), "failed at iteration {i}");
        }
        for _ in 0..3 {
            assert!(!limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_concurrent_contention() {
        let limiter = RateLimiterState::new(50);
        let shared = &limiter;

        // 10 workers x 20 attempts against a budget of 50 tokens.
        let total: u32 = thread::scope(|scope| {
            let workers: Vec<_> = (0..10)
                .map(move |_| {
                    scope.spawn(move || {
                        let mut acquired = 0u32;
                        for _ in 0..20 {
                            acquired += u32::from(shared.try_acquire());
                        }
                        acquired
                    })
                })
                .collect();
            workers.into_iter().map(|w| w.join().unwrap()).sum()
        });

        assert!(total <= 50, "acquired {total} but budget was 50");
        assert!(total >= 45, "should acquire most tokens, got {total}");
    }

    // --- Origin guard ---

    #[test]
    fn localhost_origin_accepted() {
        for origin in ["http://localhost:3000", "http://localhost", "https://localhost:7474"] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn ipv4_loopback_accepted() {
        for origin in ["http://127.0.0.1:7474", "http://127.0.0.1", "https://127.0.0.1:443"] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn ipv6_loopback_accepted() {
        for origin in ["http://[::1]:7474", "http://[::1]"] {
            assert!(is_localhost_origin(origin));
        }
    }

    #[test]
    fn subdomain_bypass_rejected() {
        for origin in [
            "https://localhost.evil.com",
            "https://127.0.0.1.evil.com",
            "https://evil-localhost.com",
        ] {
            assert!(!is_localhost_origin(origin));
        }
    }

    #[test]
    fn path_bypass_rejected() {
        for origin in ["https://evil.com/localhost", "https://evil.com/127.0.0.1"] {
            assert!(!is_localhost_origin(origin));
        }
    }

    #[test]
    fn external_origins_rejected() {
        for origin in ["https://google.com", "https://example.com:443", "http://attacker.com"] {
            assert!(!is_localhost_origin(origin));
        }
    }
}