1use axum::extract::Request;
2use axum::http::StatusCode;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
8pub fn generate_token() -> String {
10 uuid::Uuid::new_v4().to_string()
11}
12
/// Shared configuration for the `require_auth` middleware.
///
/// When `token` is `Some`, requests must present it as a bearer token; when it
/// is `None`, authentication is disabled and every request is let through.
#[derive(Clone)]
pub struct AuthState {
    pub token: Option<String>,
}

/// Debug output redacts the token so the secret cannot leak into logs or panics.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
19
20pub async fn require_auth(
22 axum::extract::State(auth): axum::extract::State<Arc<AuthState>>,
23 request: Request,
24 next: Next,
25) -> Result<Response, StatusCode> {
26 let expected = match &auth.token {
27 Some(t) => t,
28 None => return Ok(next.run(request).await),
29 };
30
31 let provided = request
32 .headers()
33 .get("authorization")
34 .and_then(|v| v.to_str().ok())
35 .and_then(|v| {
36 let lower = v.to_lowercase();
37 if lower.starts_with("bearer ") {
38 Some(v[7..].to_string())
39 } else {
40 None
41 }
42 });
43
44 match provided {
45 Some(ref token) if token == expected => Ok(next.run(request).await),
46 _ => Err(StatusCode::UNAUTHORIZED),
47 }
48}
49
/// Token-bucket state shared by every request that passes through `rate_limit`.
///
/// `tokens` is the number of requests that may proceed right now, capped at
/// `max_tokens`. Tokens are replenished at `refill_rate_per_sec`, measured from
/// `last_refill_ms` (milliseconds as returned by `now_ms`).
#[derive(Debug)]
pub struct RateLimiterState {
    tokens: AtomicU64,
    max_tokens: u64,
    last_refill_ms: AtomicU64,
    refill_rate_per_sec: u64,
}
59
/// Current wall-clock time in whole milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields 0 rather than an error.
fn now_ms() -> u64 {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    // u128 -> u64 narrowing: millisecond counts exceed u64 only after ~584 million years.
    since_epoch.as_millis() as u64
}
66
67impl RateLimiterState {
68 pub fn new(max_requests_per_sec: u64) -> Self {
70 Self {
71 tokens: AtomicU64::new(max_requests_per_sec),
72 max_tokens: max_requests_per_sec,
73 last_refill_ms: AtomicU64::new(now_ms()),
74 refill_rate_per_sec: max_requests_per_sec,
75 }
76 }
77
78 pub fn try_acquire(&self) -> bool {
80 self.refill();
81 loop {
82 let current = self.tokens.load(Ordering::Relaxed);
83 if current == 0 {
84 return false;
85 }
86 if self
87 .tokens
88 .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
89 .is_ok()
90 {
91 return true;
92 }
93 }
94 }
95
96 fn refill(&self) {
97 let now = now_ms();
98 let last = self.last_refill_ms.load(Ordering::Relaxed);
99 let elapsed_ms = now.saturating_sub(last);
100 if elapsed_ms == 0 {
101 return;
102 }
103 let add = elapsed_ms * self.refill_rate_per_sec / 1000;
104 if add == 0 {
105 return;
106 }
107 if self
108 .last_refill_ms
109 .compare_exchange(last, now, Ordering::Relaxed, Ordering::Relaxed)
110 .is_ok()
111 {
112 loop {
113 let current = self.tokens.load(Ordering::Relaxed);
114 let new_val = (current + add).min(self.max_tokens);
115 if self
116 .tokens
117 .compare_exchange_weak(current, new_val, Ordering::Relaxed, Ordering::Relaxed)
118 .is_ok()
119 {
120 break;
121 }
122 }
123 }
124 }
125}
126
127pub async fn rate_limit(
129 axum::extract::State(limiter): axum::extract::State<Arc<RateLimiterState>>,
130 request: Request,
131 next: Next,
132) -> Result<Response, StatusCode> {
133 if limiter.try_acquire() {
134 Ok(next.run(request).await)
135 } else {
136 Err(StatusCode::TOO_MANY_REQUESTS)
137 }
138}
139
140const DEFAULT_RATE_LIMIT: u64 = 100;
141
142pub fn default_rate_limiter() -> Arc<RateLimiterState> {
144 Arc::new(RateLimiterState::new(DEFAULT_RATE_LIMIT))
145}
146
147pub async fn dns_rebinding_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
153 let host = request
154 .headers()
155 .get("host")
156 .and_then(|v| v.to_str().ok())
157 .unwrap_or("");
158 let host_name = if host.starts_with('[') {
159 host.split(']').next().map(|s| &s[1..]).unwrap_or(host)
160 } else {
161 host.split(':').next().unwrap_or(host)
162 };
163 let is_allowed = matches!(host_name, "localhost" | "127.0.0.1" | "::1" | "");
164 if !is_allowed {
165 tracing::warn!("DNS rebinding attempt blocked: Host={host}");
166 return Err(StatusCode::FORBIDDEN);
167 }
168 Ok(next.run(request).await)
169}
170
171pub async fn origin_guard(request: Request, next: Next) -> Result<Response, StatusCode> {
173 if let Some(origin) = request
174 .headers()
175 .get("origin")
176 .and_then(|v| v.to_str().ok())
177 {
178 let allowed = origin.starts_with("http://localhost")
179 || origin.starts_with("https://localhost")
180 || origin.starts_with("http://127.0.0.1")
181 || origin.starts_with("https://127.0.0.1")
182 || origin.starts_with("http://[::1]")
183 || origin.starts_with("https://[::1]")
184 || origin.starts_with("tauri://")
185 || origin == "null";
186 if !allowed {
187 tracing::warn!("Cross-origin request blocked: Origin={origin}");
188 return Err(StatusCode::FORBIDDEN);
189 }
190 }
191 Ok(next.run(request).await)
192}
193
194pub async fn security_headers(request: Request, next: Next) -> Response {
196 let mut response = next.run(request).await;
197 let headers = response.headers_mut();
198 headers.insert(
199 axum::http::header::X_CONTENT_TYPE_OPTIONS,
200 axum::http::HeaderValue::from_static("nosniff"),
201 );
202 headers.insert(
203 axum::http::header::CACHE_CONTROL,
204 axum::http::HeaderValue::from_static("no-store"),
205 );
206 headers.insert(
207 axum::http::header::HeaderName::from_static("x-frame-options"),
208 axum::http::HeaderValue::from_static("DENY"),
209 );
210 response
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::middleware;
    use axum::routing::get;
    use tower::ServiceExt;

    async fn ok_handler() -> &'static str {
        "ok"
    }

    /// Builds a limiter with a zero refill rate, so the available tokens do not
    /// depend on how long the test takes to run.
    fn fixed_limiter(capacity: u64) -> RateLimiterState {
        RateLimiterState {
            tokens: AtomicU64::new(capacity),
            max_tokens: capacity,
            last_refill_ms: AtomicU64::new(now_ms()),
            refill_rate_per_sec: 0,
        }
    }

    #[test]
    fn token_generation_is_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert_eq!(t1.len(), 36);
    }

    #[test]
    fn token_is_valid_uuid() {
        let token = generate_token();
        assert!(uuid::Uuid::parse_str(&token).is_ok());
    }

    fn auth_router(token: Option<&str>) -> Router {
        let state = Arc::new(AuthState {
            token: token.map(str::to_owned),
        });
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(state, require_auth))
    }

    fn auth_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn auth_disabled_allows_unauthenticated() {
        let app = auth_router(None);
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_missing_header() {
        let app = auth_router(Some("secret"));
        let resp = app.oneshot(auth_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_accepts_matching_bearer() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(auth_request(Some("Bearer secret")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_scheme_is_case_insensitive() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(auth_request(Some("bEaReR secret")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_rejects_wrong_token() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(auth_request(Some("Bearer wrong!")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_token_prefix() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(auth_request(Some("Bearer secre")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn auth_rejects_other_scheme() {
        let app = auth_router(Some("secret"));
        let resp = app
            .oneshot(auth_request(Some("Basic secret")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn rate_limiter_allows_within_budget() {
        let limiter = RateLimiterState::new(10);
        for _ in 0..10 {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn rate_limiter_denies_when_exhausted() {
        let limiter = RateLimiterState::new(5);
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_initial_tokens_match_max() {
        let limiter = RateLimiterState::new(42);
        assert_eq!(limiter.tokens.load(Ordering::Relaxed), 42);
        assert_eq!(limiter.max_tokens, 42);
    }

    #[test]
    fn rate_limiter_concurrent_acquire() {
        // Refills are disabled, so exactly 1000 of the 2000 attempts must succeed
        // however the threads interleave; a time-based bound would be flaky on
        // slow or heavily loaded runners.
        let limiter = Arc::new(fixed_limiter(1000));
        let mut handles = vec![];
        for _ in 0..10 {
            let l = Arc::clone(&limiter);
            handles.push(std::thread::spawn(move || {
                let mut acquired = 0u64;
                for _ in 0..200 {
                    if l.try_acquire() {
                        acquired += 1;
                    }
                }
                acquired
            }));
        }
        let total: u64 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, 1000);
    }

    #[test]
    fn default_rate_limiter_has_100_tokens() {
        let limiter = default_rate_limiter();
        assert_eq!(limiter.max_tokens, 100);
    }

    #[test]
    fn rate_limiter_zero_capacity() {
        let limiter = RateLimiterState::new(0);
        assert!(!limiter.try_acquire());
    }

    fn rate_limit_router(limiter: RateLimiterState) -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn_with_state(Arc::new(limiter), rate_limit))
    }

    #[tokio::test]
    async fn rate_limit_rejects_once_exhausted() {
        let app = rate_limit_router(fixed_limiter(1));
        let first = app
            .clone()
            .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);
        let second = app
            .oneshot(Request::builder().uri("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    fn dns_rebinding_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(dns_rebinding_guard))
    }

    fn dns_request(host: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn dns_rebinding_allows_localhost() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_127_0_0_1() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_ipv6_bare() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_allows_empty_host() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_evil_com() {
        let app = dns_rebinding_router();
        let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_localhost_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("localhost.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn dns_rebinding_blocks_ip_subdomain() {
        let app = dns_rebinding_router();
        let resp = app
            .oneshot(dns_request(Some("127.0.0.1.evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    fn origin_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(origin_guard))
    }

    fn origin_request(origin: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().uri("/test");
        if let Some(o) = origin {
            builder = builder.header("origin", o);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn origin_allows_no_origin() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_localhost_http() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://localhost:3000")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_127_0_0_1_https() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("https://127.0.0.1:8080")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_tauri_scheme() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("tauri://localhost")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_allows_null() {
        let app = origin_router();
        let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn origin_blocks_evil_com() {
        let app = origin_router();
        let resp = app
            .oneshot(origin_request(Some("http://evil.com")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    fn security_headers_router() -> Router {
        Router::new()
            .route("/test", get(ok_handler))
            .layer(middleware::from_fn(security_headers))
    }

    #[tokio::test]
    async fn security_headers_x_content_type_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("x-content-type-options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn security_headers_cache_control() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
    }

    #[tokio::test]
    async fn security_headers_x_frame_options() {
        let app = security_headers_router();
        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
    }
}