1pub use victauri_core::middleware::{
2 AuthState, default_rate_limiter, dns_rebinding_guard, origin_guard, rate_limit, require_auth,
3 security_headers,
4};
5pub use victauri_core::security::{
6 self, RateLimiter as RateLimiterState, constant_time_eq, generate_token, is_allowed_origin,
7 is_localhost_host,
8};
9
10#[cfg(test)]
11mod tests {
12 use std::sync::Arc;
13
14 use super::*;
15 use axum::Router;
16 use axum::body::Body;
17 use axum::http::StatusCode;
18 use axum::middleware;
19 use axum::routing::get;
20 use tower::ServiceExt;
21
/// Trivial route handler shared by the test routers in this module.
///
/// Always resolves to the static body `"ok"`.
async fn ok_handler() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
25
/// Asserts that two consecutively generated tokens differ and that the
/// first one has a length of 36.
#[test]
fn token_generation_is_unique() {
    let (first, second) = (generate_token(), generate_token());
    assert_ne!(first, second);

    let first_len = first.len();
    assert_eq!(first_len, 36);
}
33
/// Asserts that a generated token parses successfully as a UUID.
#[test]
fn token_is_valid_uuid() {
    let parsed = uuid::Uuid::parse_str(&generate_token());
    assert!(parsed.is_ok());
}
39
/// Asserts that a limiter created with 10 tokens grants 10 acquisitions in a row.
#[test]
fn rate_limiter_allows_within_budget() {
    let limiter = RateLimiterState::new(10);
    // `all` stops at the first refusal, which then fails the assertion.
    let every_attempt_granted = (0..10).all(|_| limiter.try_acquire());
    assert!(every_attempt_granted);
}
47
/// Asserts that after 5 successful acquisitions on a 5-token limiter, the
/// next acquisition is refused.
#[test]
fn rate_limiter_denies_when_exhausted() {
    let limiter = RateLimiterState::new(5);

    let budget_granted = (0..5).all(|_| limiter.try_acquire());
    assert!(budget_granted);

    let extra_granted = limiter.try_acquire();
    assert!(!extra_granted);
}
56
/// Asserts that a freshly built limiter reports both its current and its
/// maximum token count as the value passed to `new` (42).
#[test]
fn rate_limiter_initial_tokens_match_max() {
    let limiter = RateLimiterState::new(42);

    let current = limiter.current_tokens();
    assert_eq!(current, 42);

    let maximum = limiter.max_tokens();
    assert_eq!(maximum, 42);
}
63
/// Hammers one shared 1000-token limiter from 10 threads, each making 200
/// acquisition attempts, and checks the total number of grants.
///
/// The total must be at least 1000 and at most 1200. The upper bound's
/// assertion message attributes the slack to refill; the refill logic itself
/// is defined outside this file.
#[test]
fn rate_limiter_concurrent_acquire() {
    let limiter = Arc::new(RateLimiterState::new(1000));

    // Spawn every worker before joining any, so the attempts overlap.
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let shared = Arc::clone(&limiter);
            std::thread::spawn(move || {
                (0..200)
                    .map(|_| u64::from(shared.try_acquire()))
                    .sum::<u64>()
            })
        })
        .collect();

    let total: u64 = workers
        .into_iter()
        .map(|worker| worker.join().unwrap())
        .sum();

    assert!(
        total >= 1000,
        "should dispense at least the initial budget, got {total}"
    );
    assert!(total <= 1200, "refill overshoot too high, got {total}");
}
87
/// Asserts that the limiter returned by `default_rate_limiter` reports a
/// maximum of 1000 tokens.
#[test]
fn default_rate_limiter_has_expected_tokens() {
    let maximum = default_rate_limiter().max_tokens();
    assert_eq!(maximum, 1000);
}
93
/// Asserts that a limiter built with zero tokens refuses the very first
/// acquisition.
#[test]
fn rate_limiter_zero_capacity() {
    let granted = RateLimiterState::new(0).try_acquire();
    assert!(!granted);
}
99
100 fn dns_rebinding_router() -> Router {
103 Router::new()
104 .route("/test", get(ok_handler))
105 .layer(middleware::from_fn(dns_rebinding_guard))
106 }
107
108 fn dns_request(host: Option<&str>) -> axum::extract::Request<Body> {
109 let mut builder = axum::extract::Request::builder().uri("/test");
110 if let Some(h) = host {
111 builder = builder.header("host", h);
112 }
113 builder.body(Body::empty()).unwrap()
114 }
115
116 #[tokio::test]
117 async fn dns_rebinding_allows_localhost() {
118 let app = dns_rebinding_router();
119 let resp = app.oneshot(dns_request(Some("localhost"))).await.unwrap();
120 assert_eq!(resp.status(), StatusCode::OK);
121 }
122
123 #[tokio::test]
124 async fn dns_rebinding_allows_127_0_0_1() {
125 let app = dns_rebinding_router();
126 let resp = app.oneshot(dns_request(Some("127.0.0.1"))).await.unwrap();
127 assert_eq!(resp.status(), StatusCode::OK);
128 }
129
130 #[tokio::test]
131 async fn dns_rebinding_allows_ipv6_bracketed() {
132 let app = dns_rebinding_router();
133 let resp = app.oneshot(dns_request(Some("[::1]"))).await.unwrap();
134 assert_eq!(resp.status(), StatusCode::OK);
135 }
136
137 #[tokio::test]
138 async fn dns_rebinding_allows_ipv6_bracketed_with_port() {
139 let app = dns_rebinding_router();
140 let resp = app.oneshot(dns_request(Some("[::1]:7373"))).await.unwrap();
141 assert_eq!(resp.status(), StatusCode::OK);
142 }
143
144 #[tokio::test]
145 async fn dns_rebinding_allows_ipv6_bare() {
146 let app = dns_rebinding_router();
147 let resp = app.oneshot(dns_request(Some("::1"))).await.unwrap();
148 assert_eq!(resp.status(), StatusCode::OK);
149 }
150
151 #[tokio::test]
152 async fn dns_rebinding_blocks_empty_host() {
153 let app = dns_rebinding_router();
154 let resp = app.oneshot(dns_request(None)).await.unwrap();
155 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
156 }
157
158 #[tokio::test]
159 async fn dns_rebinding_blocks_evil_com() {
160 let app = dns_rebinding_router();
161 let resp = app.oneshot(dns_request(Some("evil.com"))).await.unwrap();
162 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
163 }
164
165 #[tokio::test]
166 async fn dns_rebinding_blocks_localhost_subdomain() {
167 let app = dns_rebinding_router();
168 let resp = app
169 .oneshot(dns_request(Some("localhost.evil.com")))
170 .await
171 .unwrap();
172 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
173 }
174
175 #[tokio::test]
176 async fn dns_rebinding_blocks_ip_subdomain() {
177 let app = dns_rebinding_router();
178 let resp = app
179 .oneshot(dns_request(Some("127.0.0.1.evil.com")))
180 .await
181 .unwrap();
182 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
183 }
184
185 fn origin_router() -> Router {
188 Router::new()
189 .route("/test", get(ok_handler))
190 .layer(middleware::from_fn(origin_guard))
191 }
192
193 fn origin_request(origin: Option<&str>) -> axum::extract::Request<Body> {
194 let mut builder = axum::extract::Request::builder().uri("/test");
195 if let Some(o) = origin {
196 builder = builder.header("origin", o);
197 }
198 builder.body(Body::empty()).unwrap()
199 }
200
201 #[tokio::test]
202 async fn origin_allows_no_origin() {
203 let app = origin_router();
204 let resp = app.oneshot(origin_request(None)).await.unwrap();
205 assert_eq!(resp.status(), StatusCode::OK);
206 }
207
208 #[tokio::test]
209 async fn origin_allows_localhost_http() {
210 let app = origin_router();
211 let resp = app
212 .oneshot(origin_request(Some("http://localhost:3000")))
213 .await
214 .unwrap();
215 assert_eq!(resp.status(), StatusCode::OK);
216 }
217
218 #[tokio::test]
219 async fn origin_allows_127_0_0_1_https() {
220 let app = origin_router();
221 let resp = app
222 .oneshot(origin_request(Some("https://127.0.0.1:8080")))
223 .await
224 .unwrap();
225 assert_eq!(resp.status(), StatusCode::OK);
226 }
227
228 #[tokio::test]
229 async fn origin_allows_tauri_scheme() {
230 let app = origin_router();
231 let resp = app
232 .oneshot(origin_request(Some("tauri://localhost")))
233 .await
234 .unwrap();
235 assert_eq!(resp.status(), StatusCode::OK);
236 }
237
238 #[tokio::test]
239 async fn origin_blocks_null() {
240 let app = origin_router();
241 let resp = app.oneshot(origin_request(Some("null"))).await.unwrap();
242 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
243 }
244
245 #[tokio::test]
246 async fn origin_blocks_evil_com() {
247 let app = origin_router();
248 let resp = app
249 .oneshot(origin_request(Some("http://evil.com")))
250 .await
251 .unwrap();
252 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
253 }
254
255 fn security_headers_router() -> Router {
258 Router::new()
259 .route("/test", get(ok_handler))
260 .layer(middleware::from_fn(security_headers))
261 }
262
263 #[tokio::test]
264 async fn security_headers_x_content_type_options() {
265 let app = security_headers_router();
266 let req = axum::extract::Request::builder()
267 .uri("/test")
268 .body(Body::empty())
269 .unwrap();
270 let resp = app.oneshot(req).await.unwrap();
271 assert_eq!(resp.status(), StatusCode::OK);
272 assert_eq!(
273 resp.headers().get("x-content-type-options").unwrap(),
274 "nosniff"
275 );
276 }
277
278 #[tokio::test]
279 async fn security_headers_cache_control() {
280 let app = security_headers_router();
281 let req = axum::extract::Request::builder()
282 .uri("/test")
283 .body(Body::empty())
284 .unwrap();
285 let resp = app.oneshot(req).await.unwrap();
286 assert_eq!(resp.status(), StatusCode::OK);
287 assert_eq!(resp.headers().get("cache-control").unwrap(), "no-store");
288 }
289
290 #[tokio::test]
291 async fn security_headers_x_frame_options() {
292 let app = security_headers_router();
293 let req = axum::extract::Request::builder()
294 .uri("/test")
295 .body(Body::empty())
296 .unwrap();
297 let resp = app.oneshot(req).await.unwrap();
298 assert_eq!(resp.status(), StatusCode::OK);
299 assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
300 }
301
302 fn auth_router(token: Option<&str>) -> Router {
305 let state = Arc::new(AuthState {
306 token: token.map(String::from),
307 });
308 Router::new()
309 .route("/test", get(ok_handler))
310 .layer(middleware::from_fn_with_state(state, require_auth))
311 }
312
313 fn auth_request(token: Option<&str>) -> axum::extract::Request<Body> {
314 let mut builder = axum::extract::Request::builder().uri("/test");
315 if let Some(t) = token {
316 builder = builder.header("authorization", format!("Bearer {t}"));
317 }
318 builder.body(Body::empty()).unwrap()
319 }
320
321 #[tokio::test]
322 async fn auth_allows_correct_token() {
323 let app = auth_router(Some("secret-123"));
324 let resp = app.oneshot(auth_request(Some("secret-123"))).await.unwrap();
325 assert_eq!(resp.status(), StatusCode::OK);
326 }
327
328 #[tokio::test]
329 async fn auth_rejects_wrong_token() {
330 let app = auth_router(Some("secret-123"));
331 let resp = app
332 .oneshot(auth_request(Some("wrong-token")))
333 .await
334 .unwrap();
335 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
336 }
337
338 #[tokio::test]
339 async fn auth_rejects_missing_token() {
340 let app = auth_router(Some("secret-123"));
341 let resp = app.oneshot(auth_request(None)).await.unwrap();
342 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
343 }
344
345 #[tokio::test]
346 async fn auth_allows_any_when_disabled() {
347 let app = auth_router(None);
348 let resp = app.oneshot(auth_request(None)).await.unwrap();
349 assert_eq!(resp.status(), StatusCode::OK);
350 }
351
352 #[tokio::test]
353 async fn auth_case_insensitive_bearer_prefix() {
354 let state = Arc::new(AuthState {
355 token: Some("my-token".into()),
356 });
357 let app = Router::new()
358 .route("/test", get(ok_handler))
359 .layer(middleware::from_fn_with_state(state, require_auth));
360
361 let req = axum::extract::Request::builder()
362 .uri("/test")
363 .header("authorization", "BEARER my-token")
364 .body(Body::empty())
365 .unwrap();
366 let resp = app.oneshot(req).await.unwrap();
367 assert_eq!(resp.status(), StatusCode::OK);
368 }
369
370 #[tokio::test]
371 async fn auth_rejects_non_bearer_scheme() {
372 let app = auth_router(Some("secret"));
373 let req = axum::extract::Request::builder()
374 .uri("/test")
375 .header("authorization", "Basic c2VjcmV0")
376 .body(Body::empty())
377 .unwrap();
378 let resp = app.oneshot(req).await.unwrap();
379 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
380 }
381
382 #[tokio::test]
385 async fn rate_limiter_returns_429_when_exhausted() {
386 let limiter = Arc::new(RateLimiterState::new(2));
387 let app = Router::new()
388 .route("/test", get(ok_handler))
389 .layer(middleware::from_fn_with_state(limiter, rate_limit));
390
391 let app2 = app.clone();
392 let app3 = app2.clone();
393
394 let req = axum::extract::Request::builder()
395 .uri("/test")
396 .body(Body::empty())
397 .unwrap();
398 assert_eq!(app.oneshot(req).await.unwrap().status(), StatusCode::OK);
399
400 let req = axum::extract::Request::builder()
401 .uri("/test")
402 .body(Body::empty())
403 .unwrap();
404 assert_eq!(app2.oneshot(req).await.unwrap().status(), StatusCode::OK);
405
406 let req = axum::extract::Request::builder()
407 .uri("/test")
408 .body(Body::empty())
409 .unwrap();
410 assert_eq!(
411 app3.oneshot(req).await.unwrap().status(),
412 StatusCode::TOO_MANY_REQUESTS
413 );
414 }
415
416 #[tokio::test]
419 async fn combined_layers_enforce_all_guards() {
420 let auth_state = Arc::new(AuthState {
421 token: Some("tok-123".into()),
422 });
423 let limiter = Arc::new(RateLimiterState::new(100));
424
425 let app = Router::new()
426 .route("/test", get(ok_handler))
427 .layer(middleware::from_fn_with_state(auth_state, require_auth))
428 .layer(middleware::from_fn_with_state(limiter, rate_limit))
429 .layer(middleware::from_fn(security_headers))
430 .layer(middleware::from_fn(origin_guard))
431 .layer(middleware::from_fn(dns_rebinding_guard));
432
433 let req = axum::extract::Request::builder()
435 .uri("/test")
436 .header("authorization", "Bearer tok-123")
437 .header("host", "127.0.0.1:7373")
438 .body(Body::empty())
439 .unwrap();
440 let resp = app.clone().oneshot(req).await.unwrap();
441 assert_eq!(resp.status(), StatusCode::OK);
442 assert_eq!(resp.headers().get("x-frame-options").unwrap(), "DENY");
443
444 let req = axum::extract::Request::builder()
446 .uri("/test")
447 .header("authorization", "Bearer tok-123")
448 .header("host", "evil.com")
449 .body(Body::empty())
450 .unwrap();
451 let resp = app.clone().oneshot(req).await.unwrap();
452 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
453
454 let req = axum::extract::Request::builder()
456 .uri("/test")
457 .header("authorization", "Bearer tok-123")
458 .header("host", "localhost")
459 .header("origin", "https://evil.com")
460 .body(Body::empty())
461 .unwrap();
462 let resp = app.clone().oneshot(req).await.unwrap();
463 assert_eq!(resp.status(), StatusCode::FORBIDDEN);
464
465 let req = axum::extract::Request::builder()
467 .uri("/test")
468 .header("host", "localhost")
469 .body(Body::empty())
470 .unwrap();
471 let resp = app.oneshot(req).await.unwrap();
472 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
473 }
474
/// Asserts that `is_allowed_origin` accepts http/https origins on
/// `localhost`, `127.0.0.1` and `[::1]` (with and without ports), plus
/// `tauri://` origins.
#[test]
fn origin_guard_allows_localhost_variants() {
    let allowed = [
        "http://localhost",
        "http://localhost:7373",
        "https://localhost",
        "https://localhost:443",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "https://127.0.0.1",
        "http://[::1]",
        "http://[::1]:7373",
        "tauri://localhost",
        "tauri://some-app",
    ];
    for origin in allowed {
        assert!(is_allowed_origin(origin), "origin should be allowed: {origin}");
    }
}
489
/// Asserts that `is_allowed_origin` rejects foreign hosts that merely start
/// with a loopback name or address.
#[test]
fn origin_guard_rejects_prefix_smuggling() {
    let smuggled = [
        "http://localhost.evil.com",
        "https://localhost.evil.com",
        "https://127.0.0.1.evil.com",
        "http://[::1].evil.com",
    ];
    for origin in smuggled {
        assert!(!is_allowed_origin(origin), "origin should be rejected: {origin}");
    }
}
497
/// Asserts that `is_allowed_origin` rejects origins where the loopback name
/// appears before an `@` and the text after it is `evil.com`.
#[test]
fn origin_guard_rejects_userinfo_trick() {
    let disguised = ["http://localhost@evil.com", "http://127.0.0.1@evil.com"];
    for origin in disguised {
        assert!(!is_allowed_origin(origin), "origin should be rejected: {origin}");
    }
}
503
/// Asserts that `is_allowed_origin` rejects foreign hosts, a non-URL string,
/// the empty string, and `ftp://localhost`.
#[test]
fn origin_guard_rejects_foreign_and_malformed() {
    let rejected = [
        "http://evil.com",
        "https://attacker.io",
        "not-a-url",
        "",
        "ftp://localhost",
    ];
    for origin in rejected {
        assert!(!is_allowed_origin(origin), "origin should be rejected: {origin:?}");
    }
}
512
/// Asserts that `constant_time_eq` reports two identical byte strings as equal.
#[test]
fn constant_time_eq_equal_strings() {
    let token = b"secret-token-123";
    let same = constant_time_eq(token, b"secret-token-123");
    assert!(same);
}
519
/// Asserts that `constant_time_eq` reports two different byte strings of the
/// same length as unequal.
#[test]
fn constant_time_eq_different_strings() {
    let same = constant_time_eq(b"secret-token-123", b"wrong-token-9999");
    assert!(!same);
}
524
/// Asserts that `constant_time_eq` reports byte strings of different lengths
/// as unequal.
#[test]
fn constant_time_eq_different_lengths() {
    let same = constant_time_eq(b"short", b"longer-string");
    assert!(!same);
}
529
/// Asserts that `constant_time_eq` reports two empty byte strings as equal.
#[test]
fn constant_time_eq_empty_strings() {
    let same = constant_time_eq(b"", b"");
    assert!(same);
}
534
/// Asserts that `constant_time_eq` reports an empty and a non-empty byte
/// string as unequal, whichever argument is the empty one.
#[test]
fn constant_time_eq_one_empty() {
    let empty_first = constant_time_eq(b"", b"notempty");
    assert!(!empty_first);

    let empty_second = constant_time_eq(b"notempty", b"");
    assert!(!empty_second);
}
540
/// Asserts that `constant_time_eq` reports the one-byte inputs `A` and `B`
/// as unequal.
#[test]
fn constant_time_eq_single_bit_difference() {
    let same = constant_time_eq(b"A", b"B");
    assert!(!same);
}
545
546 #[tokio::test]
549 async fn security_headers_cors_deny() {
550 let app = security_headers_router();
551 let req = axum::extract::Request::builder()
552 .uri("/test")
553 .body(Body::empty())
554 .unwrap();
555 let resp = app.oneshot(req).await.unwrap();
556 assert_eq!(
557 resp.headers().get("access-control-allow-origin").unwrap(),
558 "null"
559 );
560 }
561
562 #[tokio::test]
563 async fn security_headers_csp() {
564 let app = security_headers_router();
565 let req = axum::extract::Request::builder()
566 .uri("/test")
567 .body(Body::empty())
568 .unwrap();
569 let resp = app.oneshot(req).await.unwrap();
570 assert_eq!(
571 resp.headers().get("content-security-policy").unwrap(),
572 "default-src 'none'"
573 );
574 }
575}