victauri_core/
security.rs1use std::sync::atomic::{AtomicU64, Ordering};
8
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths are rejected immediately, so the length of the
/// inputs is not hidden. For equal-length inputs every byte pair is XOR-ed and
/// OR-ed into a single accumulator, and the result depends only on whether
/// that accumulator is still zero at the end.
///
/// Returns `true` when both slices have the same length and contents.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // Any differing bit in any position sticks in `diff`.
    let mut diff = 0u8;
    for (lhs, rhs) in a.iter().zip(b) {
        diff |= lhs ^ rhs;
    }
    diff == 0
}
28
29#[must_use]
33pub fn generate_token() -> String {
34 uuid::Uuid::new_v4().to_string()
35}
36
/// Token-bucket rate limiter built on atomics (no mutex).
///
/// The bucket starts full. Every successful [`RateLimiter::try_acquire`]
/// removes one token, and tokens are added back in proportion to the time
/// elapsed since the previous refill, never exceeding the bucket capacity.
pub struct RateLimiter {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity; `tokens` is never refilled above this value.
    max_tokens: u64,
    /// Time of the last successful refill, in milliseconds since `epoch`.
    last_refill_ms: AtomicU64,
    /// Number of tokens added per second of elapsed time.
    refill_rate_per_sec: u64,
    /// Reference instant against which all timestamps are measured.
    epoch: std::time::Instant,
}

impl RateLimiter {
    /// Creates a limiter whose capacity and per-second refill rate are both
    /// `max_requests_per_sec`. The bucket starts full.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(0),
            refill_rate_per_sec: max_requests_per_sec,
            epoch: std::time::Instant::now(),
        }
    }

    /// Attempts to take one token.
    ///
    /// Refills the bucket first, then returns `true` if a token was taken and
    /// `false` if the bucket was empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` on an empty bucket, which aborts the
        // update and surfaces as `Err`.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Returns the bucket capacity.
    #[must_use]
    pub fn max_tokens(&self) -> u64 {
        self.max_tokens
    }

    /// Returns the number of tokens available right now (no refill is
    /// performed by this call).
    #[must_use]
    pub fn current_tokens(&self) -> u64 {
        self.tokens.load(Ordering::Relaxed)
    }

    /// Milliseconds elapsed since `epoch`, clamped to `u64::MAX`.
    fn elapsed_ms(&self) -> u64 {
        let millis = self.epoch.elapsed().as_millis();
        // Clamp before narrowing so the cast below can never truncate.
        millis.min(u128::from(u64::MAX)) as u64
    }

    /// Adds the tokens earned since the last refill, capped at `max_tokens`.
    ///
    /// Only the caller that wins the compare-exchange on `last_refill_ms`
    /// credits tokens, so concurrent callers do not credit the same interval
    /// twice. If fewer than one whole token has been earned the timestamp is
    /// left alone, so the interval keeps accumulating. Once tokens are
    /// credited the timestamp jumps to "now", dropping any fractional token.
    fn refill(&self) {
        let now = self.elapsed_ms();
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let elapsed_ms = now.saturating_sub(last);
        if elapsed_ms == 0 {
            return;
        }

        // Widen to u128: `elapsed_ms * rate` can exceed u64 for large rates.
        let earned = u128::from(elapsed_ms) * u128::from(self.refill_rate_per_sec) / 1000;
        // Clamped first, so the narrowing cast is lossless.
        let add = earned.min(u128::from(u64::MAX)) as u64;
        if add == 0 {
            return;
        }

        if self
            .last_refill_ms
            .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
            .is_err()
        {
            // Another caller already accounted for this interval.
            return;
        }

        // The closure always returns `Some`, so this update cannot fail.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(add).min(self.max_tokens))
            });
    }
}
132
/// Default rate limit value.
///
/// NOTE(review): presumably the `max_requests_per_sec` handed to
/// [`RateLimiter::new`] when no explicit limit is configured — confirm
/// against the call sites.
pub const DEFAULT_RATE_LIMIT: u64 = 1_000;
135
136#[must_use]
144pub fn is_localhost_host(host: &str) -> bool {
145 let host_name = if let Some(rest) = host.strip_prefix('[') {
146 match rest.split_once(']') {
150 Some((inner, "")) => inner,
151 Some((inner, after)) if after.strip_prefix(':').is_some_and(valid_port) => inner,
152 _ => return false,
153 }
154 } else if host.contains("::") {
155 host
157 } else {
158 match host.split_once(':') {
160 Some((name, port)) if valid_port(port) => name,
161 Some(_) => return false,
162 None => host,
163 }
164 };
165 host_name.eq_ignore_ascii_case("localhost") || matches!(host_name, "127.0.0.1" | "::1")
166}
167
/// Returns `true` when `port` is a plain decimal port number in `0..=65535`.
///
/// The digits-only check is required because `str::parse::<u16>()` on its own
/// also accepts a leading `+` (e.g. `"+80"`), which is not a valid port suffix.
fn valid_port(port: &str) -> bool {
    !port.is_empty()
        && port.bytes().all(|b| b.is_ascii_digit())
        && port.parse::<u16>().is_ok()
}
171
172#[must_use]
181pub fn is_allowed_origin(origin: &str) -> bool {
182 if origin.starts_with("tauri://") {
183 return true;
184 }
185 let Ok(parsed) = url::Url::parse(origin) else {
186 return false;
187 };
188 parsed.username().is_empty()
189 && parsed.password().is_none()
190 && matches!(parsed.scheme(), "http" | "https")
191 && matches!(
192 parsed.host_str(),
193 Some("localhost" | "127.0.0.1" | "[::1]" | "::1")
194 )
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    // ── constant_time_eq ────────────────────────────────────────────────────

    #[test]
    fn ct_eq_equal() {
        let left: &[u8] = b"secret-token-123";
        let right: &[u8] = b"secret-token-123";
        assert!(constant_time_eq(left, right));
    }

    #[test]
    fn ct_eq_different() {
        let expected: &[u8] = b"secret-token-123";
        let supplied: &[u8] = b"wrong-token-9999";
        assert!(!constant_time_eq(expected, supplied));
    }

    #[test]
    fn ct_eq_different_lengths() {
        let short: &[u8] = b"short";
        let long: &[u8] = b"longer-string";
        assert!(!constant_time_eq(short, long));
    }

    #[test]
    fn ct_eq_empty() {
        let nothing: &[u8] = b"";
        assert!(constant_time_eq(nothing, nothing));
    }

    #[test]
    fn ct_eq_one_empty() {
        let pairs = [
            (&b""[..], &b"notempty"[..]),
            (&b"notempty"[..], &b""[..]),
        ];
        for (lhs, rhs) in pairs {
            assert!(!constant_time_eq(lhs, rhs));
        }
    }

    #[test]
    fn ct_eq_single_bit_difference() {
        let upper_a: &[u8] = b"A";
        let upper_b: &[u8] = b"B";
        assert!(!constant_time_eq(upper_a, upper_b));
    }

    #[test]
    fn ct_eq_long_strings() {
        let first = "a".repeat(10_000);
        let second = first.clone();
        assert!(constant_time_eq(first.as_bytes(), second.as_bytes()));
    }

    #[test]
    fn ct_eq_all_byte_values() {
        for byte in u8::MIN..=u8::MAX {
            let single = [byte];
            assert!(constant_time_eq(&single, &single));
            // 255 has no successor; every other value must differ from its neighbour.
            if let Some(next) = byte.checked_add(1) {
                assert!(!constant_time_eq(&single, &[next]));
            }
        }
    }

    // ── generate_token ──────────────────────────────────────────────────────

    #[test]
    fn tokens_are_unique() {
        let first = generate_token();
        let second = generate_token();
        assert_ne!(first, second);
        assert_eq!(first.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let parsed = uuid::Uuid::parse_str(&generate_token());
        assert!(parsed.is_ok());
    }

    #[test]
    fn token_uniqueness_over_1000() {
        let mut seen = std::collections::HashSet::with_capacity(1000);
        for _ in 0..1000 {
            let is_fresh = seen.insert(generate_token());
            assert!(is_fresh, "duplicate token");
        }
    }

    // ── RateLimiter ─────────────────────────────────────────────────────────

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiter::new(10);
        assert!((0..10).all(|_| limiter.try_acquire()));
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiter::new(5);
        assert!((0..5).all(|_| limiter.try_acquire()));
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiter::new(42);
        assert_eq!(limiter.current_tokens(), 42);
        assert_eq!(limiter.max_tokens(), 42);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiter::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_concurrent() {
        let limiter = std::sync::Arc::new(RateLimiter::new(1000));

        // Spawn every worker before joining any of them.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = std::sync::Arc::clone(&limiter);
                std::thread::spawn(move || (0..200).filter(|_| shared.try_acquire()).count())
            })
            .collect();

        let total: usize = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .sum();

        assert!(
            total >= 1000,
            "should dispense at least the initial budget, got {total}"
        );
        assert!(total <= 1200, "refill overshoot too high, got {total}");
    }

    // ── is_localhost_host ───────────────────────────────────────────────────

    /// Asserts that every host in `hosts` is accepted.
    fn assert_hosts_accepted(hosts: &[&str]) {
        for host in hosts {
            assert!(is_localhost_host(host), "expected {host:?} to be accepted");
        }
    }

    /// Asserts that every host in `hosts` is rejected.
    fn assert_hosts_rejected(hosts: &[&str]) {
        for host in hosts {
            assert!(!is_localhost_host(host), "expected {host:?} to be rejected");
        }
    }

    #[test]
    fn host_allows_localhost() {
        assert_hosts_accepted(&[
            "localhost",
            "LOCALHOST",
            "localhost:7373",
            "LocalHost:7373",
        ]);
    }

    #[test]
    fn host_allows_ipv4() {
        assert_hosts_accepted(&["127.0.0.1", "127.0.0.1:7373"]);
    }

    #[test]
    fn host_allows_ipv6() {
        assert_hosts_accepted(&["[::1]", "[::1]:7373", "::1"]);
    }

    #[test]
    fn host_blocks_evil() {
        assert_hosts_rejected(&[
            "evil.com",
            "localhost.evil.com",
            "127.0.0.1.evil.com",
            "",
        ]);
    }

    #[test]
    fn host_blocks_bracketed_ipv6_smuggling() {
        assert_hosts_rejected(&[
            "[::1].evil.com",
            "[::1]@evil.com",
            "[::1]evil",
            "[2001:db8::1]",
            "[::1].evil.com:7373",
        ]);
        // Well-formed bracketed forms must keep working.
        assert_hosts_accepted(&["[::1]", "[::1]:7373", "[127.0.0.1]"]);
    }

    #[test]
    fn host_blocks_malformed_port_suffixes() {
        assert_hosts_rejected(&[
            "localhost:notaport",
            "localhost:",
            "localhost:7373:extra",
            "127.0.0.1:notaport",
            "[::1]:notaport",
            "[::1]:",
            "[::1]:7373:extra",
            "[::1] :7373",
        ]);
    }

    // ── is_allowed_origin ───────────────────────────────────────────────────

    /// Asserts that every origin in `origins` is allowed.
    fn assert_origins_allowed(origins: &[&str]) {
        for origin in origins {
            assert!(is_allowed_origin(origin), "expected {origin:?} to be allowed");
        }
    }

    /// Asserts that every origin in `origins` is blocked.
    fn assert_origins_blocked(origins: &[&str]) {
        for origin in origins {
            assert!(!is_allowed_origin(origin), "expected {origin:?} to be blocked");
        }
    }

    #[test]
    fn origin_allows_localhost_variants() {
        assert_origins_allowed(&[
            "http://localhost",
            "http://localhost:7373",
            "https://localhost",
            "http://127.0.0.1",
            "http://127.0.0.1:8080",
            "http://[::1]",
            "http://[::1]:7373",
            "tauri://localhost",
            "tauri://some-app",
        ]);
    }

    #[test]
    fn origin_blocks_smuggling() {
        assert_origins_blocked(&[
            "http://localhost.evil.com",
            "https://127.0.0.1.evil.com",
            "http://localhost@evil.com",
            "http://user:pass@localhost:7373",
        ]);
    }

    #[test]
    fn origin_blocks_external() {
        assert_origins_blocked(&[
            "http://evil.com",
            "https://attacker.io",
            "not-a-url",
            "",
            "null",
            "ftp://localhost",
        ]);
    }
}