Skip to main content

victauri_core/
security.rs

1//! Shared security primitives for Victauri's localhost HTTP servers.
2//!
3//! This module provides the pure-logic building blocks that both `victauri-plugin`
4//! and `victauri-browser` use in their axum middleware stacks.  Keeping them here
5//! eliminates copy-paste drift between the two crates.
6
7use std::sync::atomic::{AtomicU64, Ordering};
8
9// ── Constant-time comparison ─────────────────────────────────────────────
10
/// Compares two byte slices without short-circuiting on the first mismatch.
///
/// Intended for checking secret tokens, where an early-exit comparison would
/// let an attacker learn how many leading bytes of a guess are correct by
/// timing the response.
///
/// Slices of different lengths are rejected up front, so the length of the
/// expected value is not hidden. For equal-length inputs, every byte pair is
/// XOR-ed into a single accumulator before the result is inspected.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    } else {
        false
    }
}
28
29// ── Token generation ─────────────────────────────────────────────────────
30
31/// Generate a random UUID v4 token suitable for Bearer authentication.
32#[must_use]
33pub fn generate_token() -> String {
34    uuid::Uuid::new_v4().to_string()
35}
36
37// ── Rate limiter ─────────────────────────────────────────────────────────
38
39/// Lock-free token-bucket rate limiter using monotonic timestamps for smooth
40/// refill.
41///
42/// Uses [`std::time::Instant`] instead of `SystemTime` so the refill clock is
43/// immune to NTP adjustments and pre-epoch system clocks.
44///
45/// Thread-safe via `AtomicU64` — no mutexes, no allocations on the hot path.
46pub struct RateLimiter {
47    tokens: AtomicU64,
48    max_tokens: u64,
49    last_refill_ms: AtomicU64,
50    refill_rate_per_sec: u64,
51    epoch: std::time::Instant,
52}
53
54impl RateLimiter {
55    /// Create a rate limiter with the given maximum requests per second.
56    #[must_use]
57    pub fn new(max_requests_per_sec: u64) -> Self {
58        Self {
59            tokens: AtomicU64::new(max_requests_per_sec),
60            max_tokens: max_requests_per_sec,
61            last_refill_ms: AtomicU64::new(0),
62            refill_rate_per_sec: max_requests_per_sec,
63            epoch: std::time::Instant::now(),
64        }
65    }
66
67    /// Atomically consume one token, returning `true` if the request is
68    /// allowed.
69    pub fn try_acquire(&self) -> bool {
70        self.refill();
71        loop {
72            let current = self.tokens.load(Ordering::Relaxed);
73            if current == 0 {
74                return false;
75            }
76            if self
77                .tokens
78                .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
79                .is_ok()
80            {
81                return true;
82            }
83        }
84    }
85
86    /// Maximum token capacity.
87    #[must_use]
88    pub fn max_tokens(&self) -> u64 {
89        self.max_tokens
90    }
91
92    /// Current token count (snapshot — may change immediately after reading).
93    #[must_use]
94    pub fn current_tokens(&self) -> u64 {
95        self.tokens.load(Ordering::Relaxed)
96    }
97
98    fn elapsed_ms(&self) -> u64 {
99        self.epoch.elapsed().as_millis() as u64
100    }
101
102    fn refill(&self) {
103        let now = self.elapsed_ms();
104        let last = self.last_refill_ms.load(Ordering::Relaxed);
105        let elapsed_ms = now.saturating_sub(last);
106        if elapsed_ms == 0 {
107            return;
108        }
109        let add = elapsed_ms * self.refill_rate_per_sec / 1000;
110        if add == 0 {
111            return;
112        }
113        if self
114            .last_refill_ms
115            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
116            .is_ok()
117        {
118            loop {
119                let current = self.tokens.load(Ordering::Relaxed);
120                let new_val = (current + add).min(self.max_tokens);
121                if self
122                    .tokens
123                    .compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed)
124                    .is_ok()
125                {
126                    break;
127                }
128            }
129        }
130    }
131}
132
/// Default rate limit, in requests per second (1,000), suitable for passing
/// to [`RateLimiter::new`].
pub const DEFAULT_RATE_LIMIT: u64 = 1_000;
135
136// ── Host validation (DNS rebinding guard) ────────────────────────────────
137
/// Returns `true` if `host` (from the HTTP `Host` header) names a localhost
/// address.
///
/// Accepted forms:
/// - `localhost` (ASCII case-insensitive) or `127.0.0.1`, optionally followed
///   by a numeric `:port` suffix (e.g. `localhost:7373`);
/// - the bracketed IPv6 literal `[::1]`, optionally followed by `:port`
///   (e.g. `[::1]:7373`);
/// - a bare `::1`.
///
/// Everything else is rejected. This includes trailing garbage after an IPv6
/// literal (`[::1]evil`), bracketed hostnames (`[localhost]`), unterminated
/// brackets (`[::1`), and non-numeric ports (`localhost:evil`).
#[must_use]
pub fn is_localhost_host(host: &str) -> bool {
    if let Some(bracketed) = host.strip_prefix('[') {
        // Bracketed IPv6 literal: "[::1]" or "[::1]:7373".
        return match bracketed.split_once(']') {
            Some((addr, rest)) => addr == "::1" && has_valid_port_suffix(rest),
            None => false,
        };
    }
    if host.contains("::") {
        // Bare IPv6 (no brackets) cannot carry a port unambiguously.
        return host == "::1";
    }
    // IPv4 or hostname, optionally followed by ":port".
    let (name, rest) = host.find(':').map_or((host, ""), |idx| host.split_at(idx));
    (name.eq_ignore_ascii_case("localhost") || name == "127.0.0.1") && has_valid_port_suffix(rest)
}

/// Returns `true` if `rest` is empty, or is `:` followed only by ASCII digits.
///
/// This is the optional `[ ":" port ]` tail of a `Host` header; per RFC 3986
/// the port may be empty.
fn has_valid_port_suffix(rest: &str) -> bool {
    match rest.strip_prefix(':') {
        Some(port) => port.bytes().all(|b| b.is_ascii_digit()),
        None => rest.is_empty(),
    }
}
157
158// ── Origin validation (cross-origin guard) ───────────────────────────────
159
160/// Returns `true` if `origin` (from the HTTP `Origin` header) is a
161/// localhost origin, a `tauri://` origin, or absent.
162///
163/// Uses [`url::Url::parse`] internally so that subdomain-smuggling attacks
164/// like `localhost.evil.com` are caught by comparing the **parsed host**
165/// rather than doing prefix matching.
166#[must_use]
167pub fn is_allowed_origin(origin: &str) -> bool {
168    if origin.starts_with("tauri://") {
169        return true;
170    }
171    let Ok(parsed) = url::Url::parse(origin) else {
172        return false;
173    };
174    parsed.username().is_empty()
175        && parsed.password().is_none()
176        && matches!(parsed.scheme(), "http" | "https")
177        && matches!(
178            parsed.host_str(),
179            Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
180        )
181}
182
183// ── Tests ────────────────────────────────────────────────────────────────
184
#[cfg(test)]
mod tests {
    //! Unit tests for the shared security primitives.

    use super::*;

    // ── constant_time_eq ──

    #[test]
    fn ct_eq_equal() {
        assert!(constant_time_eq(b"secret-token-123", b"secret-token-123"));
    }

    #[test]
    fn ct_eq_different() {
        assert!(!constant_time_eq(b"secret-token-123", b"wrong-token-9999"));
    }

    #[test]
    fn ct_eq_different_lengths() {
        assert!(!constant_time_eq(b"short", b"longer-string"));
    }

    #[test]
    fn ct_eq_empty() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn ct_eq_one_empty() {
        assert!(!constant_time_eq(b"", b"notempty"));
        assert!(!constant_time_eq(b"notempty", b""));
    }

    #[test]
    fn ct_eq_single_bit_difference() {
        // 'A' (0x41) and 'C' (0x43) differ only in bit 1.
        // 'A' vs 'B' (0x42) would differ in two bits.
        assert!(!constant_time_eq(b"A", b"C"));
    }

    #[test]
    fn ct_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));
    }

    #[test]
    fn ct_eq_all_byte_values() {
        for b in 0..=255u8 {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            if b < 255 {
                assert!(!constant_time_eq(&a, &[b + 1]));
            }
        }
    }

    // ── generate_token ──

    #[test]
    fn tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    #[test]
    fn token_uniqueness_over_1000() {
        let mut set = std::collections::HashSet::new();
        for _ in 0..1000 {
            assert!(set.insert(generate_token()), "duplicate token");
        }
    }

    // ── RateLimiter ──

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiter::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiter::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiter::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiter::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_concurrent() {
        // Start timing before the limiter captures its own epoch, so
        // `elapsed_ms` below is an upper bound on the refill time the limiter
        // can have observed.
        let start = std::time::Instant::now();
        let limiter = std::sync::Arc::new(RateLimiter::new(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = std::sync::Arc::clone(&limiter);
            handles.push(std::thread::spawn(move || {
                let mut acquired: u64 = 0;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        let elapsed_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        // At 1000 tokens/s the limiter adds at most one token per elapsed
        // millisecond. This bound therefore holds however slowly the threads
        // ran; a fixed cap such as 1200 fails on a loaded machine.
        assert!(
            total <= 1000u64.saturating_add(elapsed_ms).saturating_add(1),
            "refill overshoot too high, got {total} after {elapsed_ms} ms"
        );
    }

    // ── is_localhost_host ──

    #[test]
    fn host_allows_localhost() {
        assert!(is_localhost_host("localhost"));
        assert!(is_localhost_host("localhost:7373"));
    }

    #[test]
    fn host_allows_ipv4() {
        assert!(is_localhost_host("127.0.0.1"));
        assert!(is_localhost_host("127.0.0.1:7373"));
    }

    #[test]
    fn host_allows_ipv6() {
        assert!(is_localhost_host("[::1]"));
        assert!(is_localhost_host("[::1]:7373"));
        assert!(is_localhost_host("::1"));
    }

    #[test]
    fn host_blocks_evil() {
        assert!(!is_localhost_host("evil.com"));
        assert!(!is_localhost_host("localhost.evil.com"));
        assert!(!is_localhost_host("127.0.0.1.evil.com"));
        assert!(!is_localhost_host(""));
    }

    // ── is_allowed_origin ──

    #[test]
    fn origin_allows_localhost_variants() {
        assert!(is_allowed_origin("http://localhost"));
        assert!(is_allowed_origin("http://localhost:7373"));
        assert!(is_allowed_origin("https://localhost"));
        assert!(is_allowed_origin("http://127.0.0.1"));
        assert!(is_allowed_origin("http://127.0.0.1:8080"));
        assert!(is_allowed_origin("http://[::1]"));
        assert!(is_allowed_origin("http://[::1]:7373"));
        assert!(is_allowed_origin("tauri://localhost"));
        assert!(is_allowed_origin("tauri://some-app"));
    }

    #[test]
    fn origin_blocks_smuggling() {
        assert!(!is_allowed_origin("http://localhost.evil.com"));
        assert!(!is_allowed_origin("https://127.0.0.1.evil.com"));
        assert!(!is_allowed_origin("http://localhost@evil.com"));
        assert!(!is_allowed_origin("http://user:pass@localhost:7373"));
    }

    #[test]
    fn origin_blocks_external() {
        assert!(!is_allowed_origin("http://evil.com"));
        assert!(!is_allowed_origin("https://attacker.io"));
        assert!(!is_allowed_origin("not-a-url"));
        assert!(!is_allowed_origin(""));
        assert!(!is_allowed_origin("null"));
        assert!(!is_allowed_origin("ftp://localhost"));
    }
}