victauri_core/
security.rs1use std::sync::atomic::{AtomicU64, Ordering};
8
/// Compares two byte slices without short-circuiting on the first mismatch.
///
/// Slices of different lengths are rejected immediately. For equal-length
/// slices every byte pair is XOR-ed and the results are OR-ed into a single
/// accumulator, so the amount of work does not depend on where (or whether)
/// the inputs differ.
///
/// Returns `true` only when both slices have the same length and contents.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Any differing bit in any position leaves a set bit in `diff`.
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
28
29#[must_use]
33pub fn generate_token() -> String {
34 uuid::Uuid::new_v4().to_string()
35}
36
/// Token-bucket rate limiter built on atomics.
///
/// The bucket starts full, each successful [`RateLimiter::try_acquire`] takes
/// one token, and tokens are added back in proportion to elapsed time, never
/// exceeding the bucket capacity.
pub struct RateLimiter {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; a refill never raises `tokens` above this value.
    max_tokens: u64,
    /// Milliseconds since `epoch` at which the last successful refill happened.
    last_refill_ms: AtomicU64,
    /// Tokens earned per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Reference instant against which `last_refill_ms` is measured.
    epoch: std::time::Instant,
}

impl RateLimiter {
    /// Creates a limiter whose capacity and per-second refill rate are both
    /// `max_requests_per_sec`. The bucket starts full.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed so far, then tries to take one
    /// token.
    ///
    /// Returns `true` if a token was taken and `false` if the bucket is empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` on an empty bucket, which aborts the
        // update and surfaces as `Err`.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Returns the bucket capacity.
    #[must_use]
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens
    }

    /// Returns the number of tokens currently in the bucket. This does not
    /// trigger a refill.
    #[must_use]
    pub fn current_tokens(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed)
    }

    /// Milliseconds elapsed since the limiter was created, clamped to `u64`.
    fn elapsed_ms(&self) -> u64 {
        let millis = self.epoch.elapsed().as_millis();
        // Clamped first, so the narrowing cast cannot truncate.
        millis.min(u128::from(u64::MAX)) as u64
    }

    /// Adds the tokens earned since the last refill, capped at `max_tokens`.
    ///
    /// The earned amount is rounded down to whole tokens. While it rounds to
    /// zero the refill timestamp is left alone so the time keeps accumulating;
    /// once at least one token is granted the timestamp moves to "now" and any
    /// fractional remainder is dropped.
    fn refill(&self) {
        let now = self.elapsed_ms();
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let elapsed_ms = now.saturating_sub(last);

        // Computed in u128: the product of two u64 values always fits, so very
        // large rates (e.g. `u64::MAX` used as "unlimited") cannot overflow.
        let earned = u128::from(elapsed_ms) * u128::from(self.refill_rate_per_sec) / 1000;
        // Clamped first, so the narrowing cast cannot truncate.
        let add = earned.min(u128::from(u64::MAX)) as u64;
        if add == 0 {
            return;
        }

        // Only the caller that wins the timestamp swap credits this interval;
        // a loser simply skips the refill.
        let won_interval = self
            .last_refill_ms
            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
            .is_ok();
        if won_interval {
            let capacity = self.max_tokens;
            // The closure always returns `Some`, so the update cannot fail.
            let _ = self
                .tokens
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                    Some(available.saturating_add(add).min(capacity))
                });
        }
    }
}
132
/// Default rate limit value.
///
/// NOTE(review): presumably the requests-per-second budget handed to
/// `RateLimiter::new` when nothing else is configured — confirm against callers.
pub const DEFAULT_RATE_LIMIT: u64 = 1_000;
135
/// Returns `true` when `host` (a host name with an optional `:port`) refers
/// to the local machine.
///
/// Accepted forms:
/// - `localhost` (compared ASCII case-insensitively) or `127.0.0.1`,
///   optionally followed by `:` and a run of ASCII digits;
/// - `[::1]`, optionally followed by `:` and a run of ASCII digits;
/// - the bare IPv6 loopback literal `::1`.
///
/// Anything else is rejected, including input with trailing data after the
/// closing bracket (`[::1]evil.com`), an unclosed bracket (`[::1`), or a
/// non-numeric port (`localhost:80@evil.com`).
#[must_use]
pub fn is_localhost_host(host: &str) -> bool {
    /// A port suffix is either absent or `:` followed by zero or more ASCII
    /// digits (an empty port after the colon is permitted).
    fn is_port_suffix(suffix: &str) -> bool {
        match suffix.strip_prefix(':') {
            Some(port) => port.bytes().all(|b| b.is_ascii_digit()),
            None => suffix.is_empty(),
        }
    }

    // Bracketed form: only an IPv6 literal may appear inside the brackets,
    // and whatever follows the closing bracket must be a valid port suffix.
    if let Some(rest) = host.strip_prefix('[') {
        return match rest.split_once(']') {
            Some((addr, suffix)) => addr == "::1" && is_port_suffix(suffix),
            None => false,
        };
    }

    // An unbracketed value containing `::` is treated as a bare IPv6 literal:
    // its colons are part of the address, so no port can be split off.
    if host.contains("::") {
        return host == "::1";
    }

    // Name or IPv4 literal: everything from the first `:` onwards is the port.
    let (name, suffix) = match host.find(':') {
        Some(idx) => host.split_at(idx),
        None => (host, ""),
    };
    (name.eq_ignore_ascii_case("localhost") || name == "127.0.0.1") && is_port_suffix(suffix)
}
157
158#[must_use]
167pub fn is_allowed_origin(origin: &str) -> bool {
168 if origin.starts_with("tauri://") {
169 return true;
170 }
171 let Ok(parsed) = url::Url::parse(origin) else {
172 return false;
173 };
174 parsed.username().is_empty()
175 && parsed.password().is_none()
176 && matches!(parsed.scheme(), "http" | "https")
177 && matches!(
178 parsed.host_str(),
179 Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
180 )
181}
182
/// Unit tests for the token comparison, token generation, rate limiter, and
/// host/origin allow-list helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // ── constant_time_eq ────────────────────────────────────────────────

    #[test]
    fn ct_eq_equal() {
        let token = b"secret-token-123";
        assert!(constant_time_eq(token, token));
    }

    #[test]
    fn ct_eq_different() {
        let matched = constant_time_eq(b"secret-token-123", b"wrong-token-9999");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_different_lengths() {
        let matched = constant_time_eq(b"short", b"longer-string");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_empty() {
        let nothing: &[u8] = &[];
        assert!(constant_time_eq(nothing, nothing));
    }

    #[test]
    fn ct_eq_one_empty() {
        let empty: &[u8] = b"";
        let filled: &[u8] = b"notempty";
        assert!(!constant_time_eq(empty, filled));
        assert!(!constant_time_eq(filled, empty));
    }

    #[test]
    fn ct_eq_single_bit_difference() {
        let matched = constant_time_eq(b"A", b"B");
        assert!(!matched);
    }

    #[test]
    fn ct_eq_long_strings() {
        let left = vec![b'a'; 10_000];
        let right = left.clone();
        assert!(constant_time_eq(&left, &right));
    }

    #[test]
    fn ct_eq_all_byte_values() {
        for byte in u8::MIN..=u8::MAX {
            let single = [byte];
            assert!(constant_time_eq(&single, &single));
            // Every value except 255 has a successor to compare against.
            if let Some(next) = byte.checked_add(1) {
                assert!(!constant_time_eq(&single, &[next]));
            }
        }
    }

    // ── generate_token ──────────────────────────────────────────────────

    #[test]
    fn tokens_are_unique() {
        let first = generate_token();
        let second = generate_token();
        assert_ne!(first, second);
        assert_eq!(first.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let parsed = uuid::Uuid::parse_str(&generate_token());
        assert!(parsed.is_ok());
    }

    #[test]
    fn token_uniqueness_over_1000() {
        let mut seen = std::collections::HashSet::new();
        for _ in 0..1000 {
            let token = generate_token();
            assert!(seen.insert(token), "duplicate token");
        }
    }

    // ── RateLimiter ─────────────────────────────────────────────────────

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiter::new(10);
        assert!((0..10).all(|_| limiter.try_acquire()));
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiter::new(5);
        assert!((0..5).all(|_| limiter.try_acquire()));
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiter::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiter::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_concurrent() {
        let limiter = std::sync::Arc::new(RateLimiter::new(1000));
        // Collect eagerly so every worker is running before any is joined.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = std::sync::Arc::clone(&limiter);
                std::thread::spawn(move || {
                    (0..200)
                        .map(|_| u64::from(shared.try_acquire()))
                        .sum::<u64>()
                })
            })
            .collect();

        let mut total: u64 = 0;
        for worker in workers {
            total += worker.join().unwrap();
        }

        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    // ── is_localhost_host ───────────────────────────────────────────────

    #[test]
    fn host_allows_localhost() {
        for host in ["localhost", "localhost:7373"] {
            assert!(is_localhost_host(host), "{host}");
        }
    }

    #[test]
    fn host_allows_ipv4() {
        for host in ["127.0.0.1", "127.0.0.1:7373"] {
            assert!(is_localhost_host(host), "{host}");
        }
    }

    #[test]
    fn host_allows_ipv6() {
        for host in ["[::1]", "[::1]:7373", "::1"] {
            assert!(is_localhost_host(host), "{host}");
        }
    }

    #[test]
    fn host_blocks_evil() {
        for host in ["evil.com", "localhost.evil.com", "127.0.0.1.evil.com", ""] {
            assert!(!is_localhost_host(host), "{host}");
        }
    }

    // ── is_allowed_origin ───────────────────────────────────────────────

    #[test]
    fn origin_allows_localhost_variants() {
        let allowed = [
            "http://localhost",
            "http://localhost:7373",
            "https://localhost",
            "http://127.0.0.1",
            "http://127.0.0.1:8080",
            "http://[::1]",
            "http://[::1]:7373",
            "tauri://localhost",
            "tauri://some-app",
        ];
        for origin in allowed {
            assert!(is_allowed_origin(origin), "{origin}");
        }
    }

    #[test]
    fn origin_blocks_smuggling() {
        let smuggled = [
            "http://localhost.evil.com",
            "https://127.0.0.1.evil.com",
            "http://localhost@evil.com",
            "http://user:pass@localhost:7373",
        ];
        for origin in smuggled {
            assert!(!is_allowed_origin(origin), "{origin}");
        }
    }

    #[test]
    fn origin_blocks_external() {
        let external = [
            "http://evil.com",
            "https://attacker.io",
            "not-a-url",
            "",
            "null",
            "ftp://localhost",
        ];
        for origin in external {
            assert!(!is_allowed_origin(origin), "{origin}");
        }
    }
}