1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Byte length of the `Bearer ` authorization-scheme prefix, trailing space included.
const BEARER_PREFIX_LEN: usize = b"Bearer ".len();
9
/// Compares two byte slices without stopping at the first mismatching byte.
///
/// Slices of different length are rejected immediately. Only equal-length
/// contents are compared in a data-independent way: every byte pair is
/// XOR-ed into a single accumulator that is inspected once at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (&x, &y) in a.iter().zip(b) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
19
20#[must_use]
22pub fn generate_token() -> String {
23 uuid::Uuid::new_v4().to_string()
24}
25
/// Shared configuration for the bearer-token authentication middleware.
///
/// `token` is the secret clients must present as `Authorization: Bearer <token>`;
/// `None` disables authentication entirely.
#[derive(Clone)]
pub struct AuthState {
    pub token: Option<String>,
}
30
31pub async fn require_auth(
37 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
38 request: Request,
39 next: Next,
40) -> Result<Response, StatusCode> {
41 let Some(expected) = &auth.token else {
42 return Ok(next.run(request).await);
43 };
44
45 let provided = request
46 .headers()
47 .get("authorization")
48 .and_then(|v| v.to_str().ok())
49 .and_then(|v| {
50 let lower = v.to_lowercase();
51 if lower.starts_with("bearer ") {
52 Some(v[BEARER_PREFIX_LEN..].to_string())
53 } else {
54 None
55 }
56 });
57
58 match provided {
59 Some(ref token) if constant_time_eq(token.as_bytes(), expected.as_bytes()) => {
60 Ok(next.run(request).await)
61 }
62 _ => {
63 tracing::warn!("victauri-browser: rejected request — invalid or missing auth token");
64 Err(StatusCode::UNAUTHORIZED)
65 }
66 }
67}
68
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch. Saturates at `u64::MAX`
/// instead of silently truncating the `u128` millisecond count.
fn now_ms() -> u64 {
    let millis = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    u64::try_from(millis).unwrap_or(u64::MAX)
}

/// Lock-free token-bucket rate limiter.
///
/// The bucket starts full and holds at most `max_tokens` tokens. Each successful
/// [`try_acquire`](Self::try_acquire) removes one token. Tokens are credited back
/// lazily, at `refill_rate_per_sec`, whenever `try_acquire` runs.
pub struct RateLimiterState {
    /// Tokens currently available.
    tokens: AtomicU64,
    /// Bucket capacity.
    max_tokens: u64,
    /// Wall-clock time (ms since the epoch) up to which elapsed time has already
    /// been converted into tokens.
    last_refill_ms: AtomicU64,
    /// Tokens credited per second of elapsed time.
    refill_rate_per_sec: u64,
}

impl RateLimiterState {
    /// Creates a full bucket allowing bursts of `max_requests_per_sec` requests
    /// and refilling at that many tokens per second.
    ///
    /// A value of 0 produces a limiter that rejects every request.
    #[must_use]
    pub fn new(max_requests_per_sec: u64) -> Self {
        Self {
            tokens: AtomicU64::new(max_requests_per_sec),
            max_tokens: max_requests_per_sec,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: max_requests_per_sec,
        }
    }

    /// Refills the bucket, then tries to take one token.
    ///
    /// Returns `true` if a token was taken and `false` if the bucket is empty.
    #[must_use]
    pub fn try_acquire(&self) -> bool {
        self.refill();
        // `checked_sub` yields `None` at zero, which makes `fetch_update` give up
        // without modifying the counter.
        self.tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |available| {
                available.checked_sub(1)
            })
            .is_ok()
    }

    /// Converts the time elapsed since `last_refill_ms` into tokens.
    fn refill(&self) {
        let now = now_ms();
        let last = self.last_refill_ms.load(Ordering::Relaxed);
        let elapsed_ms = now.saturating_sub(last);
        if elapsed_ms < 10 {
            return;
        }
        // Widen to u128 so `elapsed_ms * rate` cannot overflow for any inputs.
        let rate = u128::from(self.refill_rate_per_sec);
        let new_tokens = u128::from(elapsed_ms) * rate / 1000;
        if new_tokens == 0 {
            return;
        }
        // Advance the clock only by the time that paid for whole tokens (rounded
        // up, so the rate is never exceeded). The fractional remainder carries over
        // to the next refill instead of being lost. `new_tokens > 0` implies
        // `rate > 0`, and this value never exceeds `elapsed_ms`.
        let consumed_ms = (new_tokens * 1000 + rate - 1) / rate;
        let consumed_ms = u64::try_from(consumed_ms).unwrap_or(elapsed_ms);
        if self
            .last_refill_ms
            .compare_exchange(
                last,
                last + consumed_ms,
                Ordering::Relaxed,
                Ordering::Relaxed,
            )
            .is_err()
        {
            // Another thread already credited this interval.
            return;
        }
        let grant = u64::try_from(new_tokens).unwrap_or(u64::MAX);
        // Atomic read-modify-write: a separate load + store could overwrite
        // concurrent `try_acquire` decrements and resurrect spent tokens.
        let _ = self
            .tokens
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_add(grant).min(self.max_tokens))
            });
    }
}
134
135#[must_use]
137pub fn default_rate_limiter() -> Arc<RateLimiterState> {
138 Arc::new(RateLimiterState::new(1000))
139}
140
141pub async fn rate_limit(
147 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
148 request: Request,
149 next: Next,
150) -> Result<Response, StatusCode> {
151 if limiter.try_acquire() {
152 Ok(next.run(request).await)
153 } else {
154 Err(StatusCode::TOO_MANY_REQUESTS)
155 }
156}
157
158pub async fn security_headers(request: Request, next: Next) -> Response {
163 let mut response = next.run(request).await;
164 let headers = response.headers_mut();
165 headers.insert("x-content-type-options", "nosniff".parse().unwrap());
166 headers.insert("cache-control", "no-store".parse().unwrap());
167 response
168}
169
170pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
178 if let Some(origin) = request
179 .headers()
180 .get("origin")
181 .and_then(|v| v.to_str().ok())
182 {
183 let is_local = is_localhost_origin(origin);
184 if !is_local {
185 tracing::warn!("rejected non-local origin: {origin}");
186 return Err(StatusCode::FORBIDDEN);
187 }
188 }
189 Ok(next.run(request).await)
190}
191
/// Reports whether an `Origin` header value names a loopback host.
///
/// Accepted hosts are exactly `localhost`, `127.0.0.1` and `[::1]`, optionally
/// followed by a numeric port (e.g. `http://localhost:3000`). A leading
/// `scheme://` is skipped if present, and any `/...` suffix is ignored.
/// Anything else after the host is rejected, such as an empty or non-numeric port,
/// or trailing text after an IPv6 literal like `[::1]evil.com`.
fn is_localhost_origin(origin: &str) -> bool {
    let authority = match origin.find("://") {
        Some(i) => &origin[i + 3..],
        None => origin,
    };
    // A serialized Origin never carries a path; drop it so it cannot be mistaken
    // for part of the host.
    let authority = authority.split('/').next().unwrap_or(authority);

    // Split the host from the optional `:port` suffix. IPv6 literals contain ':'
    // themselves, so for them the host ends at the closing bracket.
    let (host, port_suffix) = if authority.starts_with('[') {
        match authority.find(']') {
            Some(i) => authority.split_at(i + 1),
            None => return false,
        }
    } else {
        match authority.find(':') {
            Some(i) => authority.split_at(i),
            None => (authority, ""),
        }
    };

    let port_ok = match port_suffix.strip_prefix(':') {
        Some(port) => !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit()),
        None => port_suffix.is_empty(),
    };

    port_ok && matches!(host, "127.0.0.1" | "localhost" | "[::1]")
}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_generation() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_refills_after_elapsed_time() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());

        // Rewind the refill clock by one second instead of sleeping.
        limiter
            .last_refill_ms
            .store(now_ms() - 1_000, Ordering::Relaxed);
        for i in 0..10 {
            assert!(limiter.try_acquire(), "refill missing token at iteration {i}");
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_refill_is_capped_at_max() {
        let limiter = RateLimiterState::new(5);
        // A minute of idle time must not bank more than `max_tokens`.
        limiter
            .last_refill_ms
            .store(now_ms() - 60_000, Ordering::Relaxed);
        for i in 0..5 {
            assert!(limiter.try_acquire(), "failed at iteration {i}");
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
    }

    #[test]
    fn constant_time_eq_empty_strings() {
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
    }

    #[test]
    fn constant_time_eq_single_bit_diff() {
        assert!(!constant_time_eq(b"\x00", b"\x01"));
        assert!(!constant_time_eq(b"\xff", b"\xfe"));
    }

    #[test]
    fn rate_limiter_single_token() {
        let limiter = RateLimiterState::new(1);
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn token_format_is_uuid() {
        let token = generate_token();
        assert_eq!(token.len(), 36);
        assert_eq!(token.chars().filter(|c| *c == '-').count(), 4);
    }

    #[test]
    fn default_rate_limiter_has_budget() {
        let limiter = default_rate_limiter();
        assert!(limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_exact_boundary() {
        let limiter = RateLimiterState::new(100);
        for i in 0..100 {
            assert!(limiter.try_acquire(), "failed at iteration {i}");
        }
        assert!(!limiter.try_acquire());
        assert!(!limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_concurrent_contention() {
        use std::thread;

        let limiter = Arc::new(RateLimiterState::new(50));
        let mut handles = vec![];

        for _ in 0..10 {
            let l = Arc::clone(&limiter);
            handles.push(thread::spawn(move || {
                let mut acquired = 0u32;
                for _ in 0..20 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert!(total <= 50, "acquired {total} but budget was 50");
        assert!(total >= 45, "should acquire most tokens, got {total}");
    }

    #[test]
    fn constant_time_eq_long_strings() {
        let a = "a".repeat(10_000);
        let b = "a".repeat(10_000);
        assert!(constant_time_eq(a.as_bytes(), b.as_bytes()));

        let mut c = "a".repeat(10_000);
        c.push('b');
        assert!(!constant_time_eq(a.as_bytes(), c.as_bytes()));
    }

    /// Mismatches at either end of an equal-length token must both be rejected.
    /// This checks correctness only; it does not measure timing.
    #[test]
    fn constant_time_eq_rejects_mismatch_at_either_end() {
        let token = "8f14e45f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong_start = "0000000f-ceea-367f-a27f-c790e5a0fdc4";
        let wrong_end = "8f14e45f-ceea-367f-a27f-c790e5a0fd00";

        assert!(!constant_time_eq(token.as_bytes(), wrong_start.as_bytes()));
        assert!(!constant_time_eq(token.as_bytes(), wrong_end.as_bytes()));
    }

    #[test]
    fn token_uniqueness_over_1000_generations() {
        let mut tokens = std::collections::HashSet::new();
        for _ in 0..1000 {
            let t = generate_token();
            assert!(tokens.insert(t), "duplicate token generated");
        }
    }

    #[test]
    fn rate_limiter_zero_budget() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn constant_time_eq_all_byte_values() {
        for b in 0..=255u8 {
            let a = [b];
            assert!(constant_time_eq(&a, &a));
            if b < 255 {
                let c = [b + 1];
                assert!(!constant_time_eq(&a, &c));
            }
        }
    }

    #[test]
    fn localhost_origin_accepted() {
        assert!(is_localhost_origin("http://localhost:3000"));
        assert!(is_localhost_origin("http://localhost"));
        assert!(is_localhost_origin("https://localhost:7474"));
    }

    #[test]
    fn ipv4_loopback_accepted() {
        assert!(is_localhost_origin("http://127.0.0.1:7474"));
        assert!(is_localhost_origin("http://127.0.0.1"));
        assert!(is_localhost_origin("https://127.0.0.1:443"));
    }

    #[test]
    fn ipv6_loopback_accepted() {
        assert!(is_localhost_origin("http://[::1]:7474"));
        assert!(is_localhost_origin("http://[::1]"));
    }

    #[test]
    fn subdomain_bypass_rejected() {
        assert!(!is_localhost_origin("https://localhost.evil.com"));
        assert!(!is_localhost_origin("https://127.0.0.1.evil.com"));
        assert!(!is_localhost_origin("https://evil-localhost.com"));
    }

    #[test]
    fn path_bypass_rejected() {
        assert!(!is_localhost_origin("https://evil.com/localhost"));
        assert!(!is_localhost_origin("https://evil.com/127.0.0.1"));
    }

    #[test]
    fn external_origins_rejected() {
        assert!(!is_localhost_origin("https://google.com"));
        assert!(!is_localhost_origin("https://example.com:443"));
        assert!(!is_localhost_origin("http://attacker.com"));
    }
}