Skip to main content

victauri_core/
security.rs

1//! Shared security primitives for Victauri's localhost HTTP server.
2//!
3//! This module provides the pure-logic building blocks that `victauri-plugin`
4//! uses in its axum middleware stack. Keeping them here (rather than inline in the
5//! plugin) keeps the security logic unit-testable without a Tauri runtime.
6
7use std::sync::atomic::{AtomicU64, Ordering};
8
9// ── Constant-time comparison ─────────────────────────────────────────────
10
/// Compares two byte strings for equality without an early exit on the first
/// differing byte, to avoid a timing side channel when checking tokens.
///
/// Slices of different lengths are rejected immediately, so only the *length*
/// can influence timing. For equal-length inputs every byte pair is XOR-ed into
/// a single accumulator, and the result is checked only after the whole input
/// has been consumed.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff: u8 = 0;
    for (&left, &right) in a.iter().zip(b) {
        diff |= left ^ right;
    }
    diff == 0
}
28
29// ── Token generation ─────────────────────────────────────────────────────
30
31/// Generate a random UUID v4 token suitable for Bearer authentication.
32#[must_use]
33pub fn generate_token() -> String {
34    uuid::Uuid::new_v4().to_string()
35}
36
37// ── Rate limiter ─────────────────────────────────────────────────────────
38
/// Lock-free token-bucket rate limiter using monotonic timestamps for smooth
/// refill.
///
/// Uses [`std::time::Instant`] instead of `SystemTime` so the refill clock is
/// immune to NTP adjustments and pre-epoch system clocks.
///
/// Refill is tracked with nanosecond resolution. Only the time actually "paid"
/// for whole tokens is consumed, so fractional progress toward the next token
/// carries over between calls. As a result the sustained rate matches the
/// configured rate even when callers poll at intervals that are not a multiple
/// of one token's period, and even for rates above 1 000/s. All refill
/// arithmetic saturates instead of overflowing.
///
/// Thread-safe via `AtomicU64` — no mutexes, no allocations on the hot path.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens currently available; refill never pushes this above `max_tokens`.
    tokens: AtomicU64,
    /// Bucket capacity (maximum burst).
    max_tokens: u64,
    /// Nanoseconds since `epoch` up to which refill credit has been granted.
    last_refill_ns: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Monotonic reference point for all timestamps above.
    epoch: std::time::Instant,
}

impl RateLimiter {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Create a rate limiter with the given maximum requests per second.
    ///
    /// The bucket starts full, so up to `max_requests_per_sec` requests are
    /// admitted immediately; tokens then refill at that same per-second rate.
    /// A value of `0` yields a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ns: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Atomically consume one token, returning `true` if the request is
    /// allowed.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update` give up
        // without touching the counter, so the count can never underflow.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| t.checked_sub(1))
            .is_ok()
    }

    /// Maximum token capacity.
    #[must_use]
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens
    }

    /// Current token count (snapshot — may change immediately after reading).
    #[must_use]
    pub fn current_tokens(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed)
    }

    /// Nanoseconds elapsed since `epoch`, saturating at `u64::MAX` (~584 years).
    fn elapsed_ns(&self) -> u64 {
        u64::try_from(self.epoch.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Credit the bucket with every whole token earned since the last refill.
    fn refill(&self) {
        let now = self.elapsed_ns();
        let last = self.last_refill_ns.load(Ordering::Relaxed);
        let elapsed = u128::from(now.saturating_sub(last));
        let rate = u128::from(self.refill_rate_per_sec);
        // u128 math: `elapsed * rate` is at most (2^64 - 1)^2, which cannot overflow.
        let earned = elapsed * rate / Self::NANOS_PER_SEC;
        // Also covers `rate == 0`, so the division below never sees a zero divisor.
        if earned == 0 {
            return;
        }
        // Advance the refill clock only by what the earned tokens cost, rounded up.
        // Because `earned * 1e9 / rate <= elapsed` and `elapsed` is an integer, the
        // rounded-up value is still `<= elapsed`, so `new_last <= now`. The leftover
        // fraction of a token period is carried into the next refill instead of
        // being discarded.
        let spent = (earned * Self::NANOS_PER_SEC + rate - 1) / rate;
        let new_last = u64::try_from(spent).map_or(now, |spent| last + spent);
        if self
            .last_refill_ns
            .compare_exchange(last, new_last, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another thread already claimed this interval and performs its refill.
            return;
        }
        let earned = u64::try_from(earned).unwrap_or(u64::MAX);
        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |t| {
                Some(t.saturating_add(earned).min(self.max_tokens))
            });
    }
}
132
/// Default rate limit, in requests per second (1,000).
pub const DEFAULT_RATE_LIMIT: u64 = 1_000;
135
136// ── Host validation (DNS rebinding guard) ────────────────────────────────
137
/// Returns `true` if `host` (from the HTTP `Host` header) resolves to a
/// localhost address.
///
/// Accepted names are `localhost` (ASCII case-insensitive), `127.0.0.1` and
/// `::1`. Each may carry an optional `:port` suffix whose port parses as a
/// `u16` (e.g. `localhost:7373`, `[::1]:7373`). IPv6 may be bracketed or bare;
/// a bare IPv6 value is never split on `:`, so it cannot carry a port. Anything
/// else is rejected, including stray bytes after a closing `]`.
#[must_use]
pub fn is_localhost_host(host: &str) -> bool {
    let Some(name) = host_without_port(host) else {
        return false;
    };
    name.eq_ignore_ascii_case("localhost") || name == "127.0.0.1" || name == "::1"
}

/// Strips a validated `:port` suffix from a `Host` value and returns the bare
/// host name, or `None` when the value is malformed.
fn host_without_port(host: &str) -> Option<&str> {
    if let Some(bracketed) = host.strip_prefix('[') {
        // Bracketed IPv6: only an empty tail or `:port` may follow `]`, so values such as
        // `[::1].evil.com` or `[::1]@x` cannot smuggle another authority past the guard.
        let (inner, tail) = bracketed.split_once(']')?;
        let tail_ok = tail.is_empty() || tail.strip_prefix(':').is_some_and(valid_port);
        return tail_ok.then_some(inner);
    }
    if host.contains("::") {
        // Bare IPv6 such as `::1`: the colons belong to the address, not to a port.
        return Some(host);
    }
    // IPv4 or host name: `127.0.0.1:7373` → `127.0.0.1`; a malformed port rejects the value.
    match host.split_once(':') {
        None => Some(host),
        Some((name, port)) => valid_port(port).then_some(name),
    }
}

/// `true` when `port` is non-empty and parses as a `u16`.
fn valid_port(port: &str) -> bool {
    if port.is_empty() {
        return false;
    }
    port.parse::<u16>().is_ok()
}
171
172// ── Origin validation (cross-origin guard) ───────────────────────────────
173
174/// Returns `true` if `origin` (from the HTTP `Origin` header) is a
175/// localhost origin, a `tauri://` origin, or absent.
176///
177/// Uses [`url::Url::parse`] internally so that subdomain-smuggling attacks
178/// like `localhost.evil.com` are caught by comparing the **parsed host**
179/// rather than doing prefix matching.
180#[must_use]
181pub fn is_allowed_origin(origin: &str) -> bool {
182    if origin.starts_with("tauri://") {
183        return true;
184    }
185    let Ok(parsed) = url::Url::parse(origin) else {
186        return false;
187    };
188    parsed.username().is_empty()
189        && parsed.password().is_none()
190        && matches!(parsed.scheme(), "http" | "https")
191        && matches!(
192            parsed.host_str(),
193            Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
194        )
195}
196
197// ── Tests ────────────────────────────────────────────────────────────────
198
#[cfg(test)]
mod tests {
    use super::*;

    // constant_time_eq

    #[test]
    fn ct_eq_equal() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }

    #[test]
    fn ct_eq_different() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }

    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }

    #[test]
    fn ct_eq_empty() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }

    #[test]
    fn ct_eq_single_bit_difference() {
        // 'A' (0x41) and 'C' (0x43) differ in exactly one bit (0x02).
        assert_eq!((b'A' ^ b'C').count_ones(), 1);
        assert!(!constant_time_eq(b"A", b"C"));
    }

    #[test]
    fn ct_eq_detects_every_single_bit_flip() {
        let base = *b"secret-token";
        for i in 0..base.len() {
            for bit in 0..8 {
                let mut flipped = base;
                flipped[i] ^= 1 << bit;
                assert!(
                    !constant_time_eq(&base, &flipped),
                    "flip of bit {bit} in byte {i} went undetected"
                );
            }
        }
    }

    #[test]
    fn ct_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));
    }

    #[test]
    fn ct_eq_all_byte_values() {
        for b in 0..=255u8 {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            if b < 255 {
                assert!(!constant_time_eq(&a, &[b + 1]));
            }
        }
    }

    // generate_token

    #[test]
    fn tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn token_uniqueness_over_1000() {
        let mut set = std::collections::HashSet::new();
        for _ in 0..1000 {
            assert!(set.insert(generate_token()), "duplicate token");
        }
    }

    // RateLimiter

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiter::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiter::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiter::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiter::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_concurrent() {
        let limiter = std::sync::Arc::new(RateLimiter::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = limiter.clone();
            handles.push(std::thread::spawn(move || {
                let mut acquired: u64 = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        // The upper bound allows for refill while the threads run (~1 token/ms).
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    // is_localhost_host

    #[test]
    fn host_allows_localhost() {
        assert!(is_localhost_host("localhost"));
        assert!(is_localhost_host("LOCALHOST"));
        assert!(is_localhost_host("localhost:7373"));
        assert!(is_localhost_host("LocalHost:7373"));
    }

    #[test]
    fn host_allows_ipv4() {
        assert!(is_localhost_host("127.0.0.1"));
        assert!(is_localhost_host("127.0.0.1:7373"));
    }

    #[test]
    fn host_allows_ipv6() {
        assert!(is_localhost_host("[::1]"));
        assert!(is_localhost_host("[::1]:7373"));
        assert!(is_localhost_host("::1"));
    }

    #[test]
    fn host_blocks_evil() {
        assert!(!is_localhost_host("evil.com"));
        assert!(!is_localhost_host("localhost.evil.com"));
        assert!(!is_localhost_host("127.0.0.1.evil.com"));
        assert!(!is_localhost_host(""));
    }

    #[test]
    fn host_blocks_bracketed_ipv6_smuggling() {
        // Bytes after `]` must be empty or a :port — a bracket-prefixed host cannot smuggle a
        // non-localhost authority past the guard (regression for the audit-prep A-F1 finding).
        assert!(!is_localhost_host("[::1].evil.com"));
        assert!(!is_localhost_host("[::1]@evil.com"));
        assert!(!is_localhost_host("[::1]evil"));
        assert!(!is_localhost_host("[2001:db8::1]")); // bracketed but not loopback
        assert!(!is_localhost_host("[::1].evil.com:7373")); // trailing-garbage-then-port
        // Valid bracketed loopback forms still pass (incl. a bracketed IPv4 loopback, which
        // still resolves to 127.0.0.1 — harmless, it's localhost either way).
        assert!(is_localhost_host("[::1]"));
        assert!(is_localhost_host("[::1]:7373"));
        assert!(is_localhost_host("[127.0.0.1]"));
    }

    #[test]
    fn host_blocks_malformed_port_suffixes() {
        assert!(!is_localhost_host("localhost:notaport"));
        assert!(!is_localhost_host("localhost:"));
        assert!(!is_localhost_host("localhost:7373:extra"));
        assert!(!is_localhost_host("127.0.0.1:notaport"));
        assert!(!is_localhost_host("[::1]:notaport"));
        assert!(!is_localhost_host("[::1]:"));
        assert!(!is_localhost_host("[::1]:7373:extra"));
        assert!(!is_localhost_host("[::1] :7373"));
    }

    // is_allowed_origin

    #[test]
    fn origin_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_blocks_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://user:pass@localhost:7373"));
    }

    #[test]
    fn origin_blocks_external() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("null"));
        assert!(!is_allowed_origin("ftp://localhost"));
    }
}