1use axum::{
2 extract::Request,
3 http::StatusCode,
4 middleware::Next,
5 response::{IntoResponse, Response},
6 Json,
7};
8use std::env;
9
10use crate::errors::ErrorResponse;
11
/// Built-in fallback API key.
///
/// `validate_api_key` accepts it only outside production mode, and only when
/// neither `SHODH_API_KEYS` nor `SHODH_DEV_API_KEY` holds a non-blank value.
/// Development error responses may echo this value, so it must never be
/// treated as a secret.
pub(crate) const DEFAULT_DEV_API_KEY: &str = "sk-shodh-dev-default";
15
/// Reports whether the server is configured to run in production.
///
/// Production mode is selected by setting the `SHODH_ENV` environment variable
/// to `production` or `prod`; the comparison ignores letter case. An unset or
/// unreadable variable, or any other value, yields `false`.
pub fn is_production_mode() -> bool {
    match env::var("SHODH_ENV") {
        Ok(mode) => matches!(mode.to_lowercase().as_str(), "production" | "prod"),
        Err(_) => false,
    }
}
22
23fn should_hide_dev_key() -> bool {
28 if is_production_mode() {
29 return true;
30 }
31 env::var("SHODH_HIDE_DEV_KEY")
32 .map(|v| v.to_lowercase() == "true" || v == "1")
33 .unwrap_or(false)
34}
35
36pub fn log_security_status() {
38 let has_api_keys = env::var("SHODH_API_KEYS")
39 .map(|k| !k.trim().is_empty())
40 .unwrap_or(false);
41 let has_dev_key = env::var("SHODH_DEV_API_KEY")
42 .map(|k| !k.trim().is_empty())
43 .unwrap_or(false);
44 let is_prod = is_production_mode();
45
46 if is_prod {
47 if has_api_keys {
48 tracing::info!("Running in PRODUCTION mode with API key authentication");
49 } else {
50 tracing::error!(
51 "PRODUCTION mode but SHODH_API_KEYS not set! Server will reject all authenticated requests."
52 );
53 }
54 } else {
55 tracing::warn!("╔════════════════════════════════════════════════════════════════╗");
56 tracing::warn!("║ SECURITY WARNING: Running in DEVELOPMENT mode ║");
57 tracing::warn!("║ ║");
58 if has_dev_key {
59 tracing::warn!("║ Using SHODH_DEV_API_KEY for authentication. ║");
60 tracing::warn!("║ DO NOT use this configuration in production! ║");
61 } else if !has_api_keys {
62 tracing::warn!("║ No API keys configured. Using default dev key. ║");
63 tracing::warn!("║ DEPRECATION: Default dev key will be removed in v0.2.0. ║");
64 tracing::warn!("║ Set SHODH_DEV_API_KEY or SHODH_API_KEYS to override. ║");
65 tracing::warn!("║ Set SHODH_HIDE_DEV_KEY=true to hide key from error msgs. ║");
66 }
67 tracing::warn!("║ ║");
68 tracing::warn!("║ For production, set: ║");
69 tracing::warn!("║ SHODH_ENV=production ║");
70 tracing::warn!("║ SHODH_API_KEYS=your-secure-key-1,your-secure-key-2 ║");
71 tracing::warn!("╚════════════════════════════════════════════════════════════════╝");
72 }
73}
74
/// Reasons a request can fail API-key authentication.
///
/// The type is a plain tag with no payload, so it is cheap to copy and can be
/// compared directly (for example in tests via `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no API key.
    MissingApiKey,
    /// The request carried a key that matches none of the accepted keys.
    InvalidApiKey,
    /// The server has no usable key configuration to validate against.
    NotConfigured,
}
82
83impl AuthError {
84 fn code(&self) -> &'static str {
85 match self {
86 Self::MissingApiKey => "MISSING_API_KEY",
87 Self::InvalidApiKey => "INVALID_API_KEY",
88 Self::NotConfigured => "AUTH_NOT_CONFIGURED",
89 }
90 }
91
92 fn status_code(&self) -> StatusCode {
93 match self {
94 Self::MissingApiKey | Self::InvalidApiKey => StatusCode::UNAUTHORIZED,
95 Self::NotConfigured => StatusCode::SERVICE_UNAVAILABLE,
96 }
97 }
98}
99
100impl IntoResponse for AuthError {
101 fn into_response(self) -> Response {
102 let is_prod = is_production_mode();
103 let status = self.status_code();
104
105 let message = match &self {
106 AuthError::MissingApiKey => {
107 if is_prod {
108 "Missing X-API-Key header".to_string()
109 } else if should_hide_dev_key() {
110 "Missing X-API-Key header. Set SHODH_DEV_API_KEY or SHODH_API_KEYS. \
111 See docs for setup."
112 .to_string()
113 } else {
114 format!(
115 "Missing X-API-Key header. Set the header in your request. \
116 The server accepts keys from SHODH_API_KEYS (comma-separated) \
117 or SHODH_DEV_API_KEY. Default dev key: '{}'",
118 DEFAULT_DEV_API_KEY
119 )
120 }
121 }
122 AuthError::InvalidApiKey => {
123 if is_prod {
124 "Invalid API key".to_string()
125 } else if should_hide_dev_key() {
126 "Invalid API key. Check SHODH_DEV_API_KEY or SHODH_API_KEYS.".to_string()
127 } else {
128 format!(
129 "Invalid API key. Expected a key from SHODH_API_KEYS or \
130 SHODH_DEV_API_KEY. Default dev key: '{}'",
131 DEFAULT_DEV_API_KEY
132 )
133 }
134 }
135 AuthError::NotConfigured => {
136 "API keys not configured. Set SHODH_API_KEYS environment variable.".to_string()
137 }
138 };
139
140 let body = ErrorResponse {
141 code: self.code().to_string(),
142 message,
143 details: None,
144 request_id: None,
145 };
146
147 (status, Json(body)).into_response()
148 }
149}
150
/// Compares two strings for byte-wise equality without stopping at the first
/// mismatch.
///
/// Every position up to the longer input's length is visited, with the shorter
/// input treated as zero-padded, and all differences are OR-ed into a single
/// accumulator. The amount of work therefore depends on the input lengths but
/// not on where the inputs differ.
///
/// Returns `true` only when both the lengths and all bytes match.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let lhs = a.as_bytes();
    let rhs = b.as_bytes();

    // Keep the accumulator as wide as the lengths themselves: narrowing the
    // length XOR (e.g. to `u32`) would discard high bits and could let two
    // different lengths look identical.
    let mut diff: usize = lhs.len() ^ rhs.len();

    for i in 0..lhs.len().max(rhs.len()) {
        // Past the end of the shorter input, compare against a zero byte.
        let x = lhs.get(i).copied().unwrap_or(0);
        let y = rhs.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }

    diff == 0
}
177
178pub fn validate_api_key(provided_key: &str) -> Result<(), AuthError> {
180 let valid_keys = match env::var("SHODH_API_KEYS") {
182 Ok(keys) if !keys.trim().is_empty() => keys,
183 _ => {
184 let is_production = env::var("SHODH_ENV")
186 .map(|v| v.to_lowercase() == "production" || v.to_lowercase() == "prod")
187 .unwrap_or(false);
188
189 if is_production {
190 tracing::error!("SHODH_API_KEYS not set in production mode");
191 return Err(AuthError::NotConfigured);
192 }
193
194 match env::var("SHODH_DEV_API_KEY") {
196 Ok(key) if !key.trim().is_empty() => {
197 tracing::warn!("Using SHODH_DEV_API_KEY for development (not for production!)");
198 key
199 }
200 _ => {
201 tracing::warn!(
202 "No API key configured. Falling back to default dev key. \
203 Set SHODH_DEV_API_KEY to override."
204 );
205 DEFAULT_DEV_API_KEY.to_string()
206 }
207 }
208 }
209 };
210
211 let keys: Vec<&str> = valid_keys.split(',').map(|k| k.trim()).collect();
212
213 let mut found = false;
215 for key in &keys {
216 if constant_time_compare(key, provided_key) {
217 found = true;
218 }
220 }
221
222 if found {
223 Ok(())
224 } else {
225 Err(AuthError::InvalidApiKey)
226 }
227}
228
229pub async fn auth_middleware(request: Request, next: Next) -> Response {
231 let path = request.uri().path();
232
233 if path == "/health" {
235 return next.run(request).await;
236 }
237
238 if path.starts_with("/webhook/") {
240 return next.run(request).await;
241 }
242
243 let api_key_value = match request
246 .headers()
247 .get("X-API-Key")
248 .and_then(|v| v.to_str().ok())
249 .map(|s| s.to_string())
250 .or_else(|| {
251 request
252 .headers()
253 .get("Authorization")
254 .and_then(|v| v.to_str().ok())
255 .and_then(|s| s.strip_prefix("Bearer "))
256 .map(|s| s.to_string())
257 })
258 .or_else(|| {
259 let is_websocket = request
265 .headers()
266 .get("upgrade")
267 .and_then(|v| v.to_str().ok())
268 .map(|v| v.eq_ignore_ascii_case("websocket"))
269 .unwrap_or(false);
270 if !is_websocket {
271 return None;
272 }
273 request.uri().query().and_then(|q| {
274 q.split('&')
275 .find_map(|pair| pair.strip_prefix("api_key=").map(|v| v.to_string()))
276 })
277 }) {
278 Some(key) => key,
279 None => return AuthError::MissingApiKey.into_response(),
280 };
281
282 if let Err(e) = validate_api_key(&api_key_value) {
284 return e.into_response();
285 }
286
287 next.run(request).await
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that read or write process-wide environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Removes every environment variable the auth module consults.
    fn clear_auth_env() {
        env::remove_var("SHODH_API_KEYS");
        env::remove_var("SHODH_DEV_API_KEY");
        env::remove_var("SHODH_ENV");
        env::remove_var("SHODH_HIDE_DEV_KEY");
    }

    /// Exclusive access to the auth environment variables for one test.
    ///
    /// Acquiring the guard takes `ENV_LOCK` and starts from a clean
    /// environment; dropping it cleans up again, even when the test panics.
    struct EnvGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A failed assertion in another test poisons the mutex. The lock
            // protects no data, so recover instead of failing every later test.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            clear_auth_env();
            Self { _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so no other test sees stale values.
            clear_auth_env();
        }
    }

    /// Reads the response body and parses it as an `ErrorResponse`.
    async fn error_body(resp: Response) -> ErrorResponse {
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        serde_json::from_slice(&body)
            .expect("Response body should be valid JSON matching ErrorResponse")
    }

    /// Sends a GET for `uri` through `auth_middleware` in front of a trivial
    /// handler mounted at `route`, optionally marked as a WebSocket upgrade,
    /// and returns the resulting status.
    async fn middleware_status(route: &str, uri: &str, websocket_upgrade: bool) -> StatusCode {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let app = Router::new()
            .route(route, get(|| async { "ok" }))
            .layer(from_fn(auth_middleware));

        let mut builder = HttpRequest::builder().uri(uri);
        if websocket_upgrade {
            builder = builder.header("upgrade", "websocket");
        }
        let req = builder.body(Body::empty()).unwrap();

        app.oneshot(req).await.unwrap().status()
    }

    // ---- constant_time_compare -------------------------------------------

    #[test]
    fn constant_time_equal_strings() {
        assert!(constant_time_compare("hello", "hello"));
    }

    #[test]
    fn constant_time_different_strings() {
        assert!(!constant_time_compare("hello", "world"));
    }

    #[test]
    fn constant_time_different_lengths() {
        assert!(!constant_time_compare("short", "a-longer-string"));
    }

    #[test]
    fn constant_time_empty_strings() {
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn constant_time_one_empty() {
        assert!(!constant_time_compare("", "notempty"));
        assert!(!constant_time_compare("notempty", ""));
    }

    #[test]
    fn constant_time_length_multiple_of_256() {
        let long = "a".repeat(256);
        assert!(!constant_time_compare(&long, ""));
        assert!(!constant_time_compare("", &long));

        let medium = "b".repeat(256);
        let longer = "b".repeat(512);
        assert!(!constant_time_compare(&medium, &longer));
    }

    // ---- is_production_mode ----------------------------------------------

    #[test]
    fn production_mode_detection() {
        let _env = EnvGuard::acquire();

        assert!(!is_production_mode());

        env::set_var("SHODH_ENV", "production");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "prod");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "PRODUCTION");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "development");
        assert!(!is_production_mode());

        env::set_var("SHODH_ENV", "test");
        assert!(!is_production_mode());
    }

    // ---- validate_api_key ------------------------------------------------

    #[test]
    fn validate_with_single_api_key() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "my-key");
        assert!(validate_api_key("my-key").is_ok());
        assert!(validate_api_key("wrong").is_err());
    }

    #[test]
    fn validate_with_multiple_api_keys() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "key1,key2,key3");
        assert!(validate_api_key("key1").is_ok());
        assert!(validate_api_key("key2").is_ok());
        assert!(validate_api_key("key3").is_ok());
        assert!(validate_api_key("key4").is_err());
    }

    #[test]
    fn validate_api_keys_trims_whitespace() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", " key1 , key2 ");
        assert!(validate_api_key("key1").is_ok());
        assert!(validate_api_key("key2").is_ok());
    }

    #[test]
    fn validate_with_dev_key() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_DEV_API_KEY", "dev-key-123");
        assert!(validate_api_key("dev-key-123").is_ok());
        assert!(validate_api_key("wrong").is_err());
    }

    #[test]
    fn validate_with_default_dev_key_when_no_env_set() {
        let _env = EnvGuard::acquire();
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
        assert!(validate_api_key("wrong-key").is_err());
    }

    #[test]
    fn validate_production_rejects_when_no_keys() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_ENV", "production");
        let result = validate_api_key("any-key");
        assert!(
            matches!(result, Err(AuthError::NotConfigured)),
            "Expected NotConfigured, got {:?}",
            result
        );
    }

    #[test]
    fn validate_production_works_with_api_keys_set() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_ENV", "production");
        env::set_var("SHODH_API_KEYS", "prod-key");
        assert!(validate_api_key("prod-key").is_ok());
        assert!(validate_api_key("wrong").is_err());
    }

    #[test]
    fn validate_empty_api_keys_falls_through() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", " ");
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
    }

    #[test]
    fn validate_empty_dev_key_uses_default() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_DEV_API_KEY", " ");
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
    }

    #[test]
    fn api_keys_takes_priority_over_dev_key() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "prod-key");
        env::set_var("SHODH_DEV_API_KEY", "dev-key");
        assert!(validate_api_key("prod-key").is_ok());
        assert!(validate_api_key("dev-key").is_err());
    }

    // ---- AuthError -------------------------------------------------------

    #[test]
    fn auth_error_status_codes() {
        assert_eq!(
            AuthError::MissingApiKey.status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::InvalidApiKey.status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::NotConfigured.status_code(),
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[test]
    fn auth_error_codes() {
        assert_eq!(AuthError::MissingApiKey.code(), "MISSING_API_KEY");
        assert_eq!(AuthError::InvalidApiKey.code(), "INVALID_API_KEY");
        assert_eq!(AuthError::NotConfigured.code(), "AUTH_NOT_CONFIGURED");
    }

    // The async tests below hold the guard across `.await`. Each test drives
    // its future on its own thread, so the lock cannot be re-entered.

    #[tokio::test]
    async fn auth_error_response_is_valid_json() {
        let _env = EnvGuard::acquire();
        let resp = AuthError::MissingApiKey.into_response();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let parsed = error_body(resp).await;
        assert_eq!(parsed.code, "MISSING_API_KEY");
        assert!(parsed.message.contains("X-API-Key"));
    }

    #[tokio::test]
    async fn missing_key_dev_message_includes_help() {
        let _env = EnvGuard::acquire();
        let parsed = error_body(AuthError::MissingApiKey.into_response()).await;
        assert!(
            parsed.message.contains("SHODH_API_KEYS"),
            "Should mention SHODH_API_KEYS"
        );
        assert!(
            parsed.message.contains("SHODH_DEV_API_KEY"),
            "Should mention SHODH_DEV_API_KEY"
        );
        assert!(
            parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Should show the default dev key"
        );
    }

    #[tokio::test]
    async fn invalid_key_dev_message_includes_help() {
        let _env = EnvGuard::acquire();
        let parsed = error_body(AuthError::InvalidApiKey.into_response()).await;
        assert!(
            parsed.message.contains("SHODH_API_KEYS"),
            "Should mention SHODH_API_KEYS"
        );
        assert!(
            parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Should show the default dev key"
        );
    }

    #[tokio::test]
    async fn missing_key_prod_message_is_terse() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_ENV", "production");
        let parsed = error_body(AuthError::MissingApiKey.into_response()).await;
        assert_eq!(parsed.message, "Missing X-API-Key header");
        assert!(
            !parsed.message.contains("SHODH_DEV_API_KEY"),
            "Prod must not leak env var names"
        );
    }

    #[tokio::test]
    async fn invalid_key_prod_message_is_terse() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_ENV", "production");
        let parsed = error_body(AuthError::InvalidApiKey.into_response()).await;
        assert_eq!(parsed.message, "Invalid API key");
        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Prod must not leak default key"
        );
    }

    #[tokio::test]
    async fn not_configured_response_shape() {
        // `into_response` reads the environment, so this test needs the lock too.
        let _env = EnvGuard::acquire();
        let resp = AuthError::NotConfigured.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);

        let parsed = error_body(resp).await;
        assert_eq!(parsed.code, "AUTH_NOT_CONFIGURED");
        assert!(parsed.message.contains("SHODH_API_KEYS"));
    }

    // ---- SHODH_HIDE_DEV_KEY ----------------------------------------------

    #[tokio::test]
    async fn hide_dev_key_suppresses_key_in_missing_key_error() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_HIDE_DEV_KEY", "true");

        let parsed = error_body(AuthError::MissingApiKey.into_response()).await;

        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "SHODH_HIDE_DEV_KEY=true should suppress key in error: {}",
            parsed.message
        );
        assert!(
            parsed.message.contains("SHODH_DEV_API_KEY"),
            "Should still mention env var name: {}",
            parsed.message
        );
    }

    #[tokio::test]
    async fn hide_dev_key_suppresses_key_in_invalid_key_error() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_HIDE_DEV_KEY", "true");

        let parsed = error_body(AuthError::InvalidApiKey.into_response()).await;

        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "SHODH_HIDE_DEV_KEY=true should suppress key in error: {}",
            parsed.message
        );
    }

    #[test]
    fn should_hide_dev_key_defaults_to_false() {
        let _env = EnvGuard::acquire();
        assert!(!should_hide_dev_key());
    }

    #[test]
    fn should_hide_dev_key_respects_env_var() {
        let _env = EnvGuard::acquire();

        env::set_var("SHODH_HIDE_DEV_KEY", "true");
        assert!(should_hide_dev_key());

        env::set_var("SHODH_HIDE_DEV_KEY", "1");
        assert!(should_hide_dev_key());

        env::set_var("SHODH_HIDE_DEV_KEY", "false");
        assert!(!should_hide_dev_key());
    }

    #[test]
    fn should_hide_dev_key_always_true_in_production() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_ENV", "production");
        assert!(should_hide_dev_key());
    }

    // ---- auth_middleware -------------------------------------------------

    #[tokio::test]
    async fn auth_middleware_accepts_query_param_for_websocket() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "test-ws-key");

        let status =
            middleware_status("/api/stream", "/api/stream?api_key=test-ws-key", true).await;
        assert_eq!(
            status,
            StatusCode::OK,
            "Should accept API key from query parameter on WebSocket upgrade"
        );
    }

    #[tokio::test]
    async fn auth_middleware_ignores_query_param_without_websocket_upgrade() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "test-ws-key");

        let status =
            middleware_status("/api/remember", "/api/remember?api_key=test-ws-key", false).await;
        assert_eq!(
            status,
            StatusCode::UNAUTHORIZED,
            "Query param auth should be ignored for non-WebSocket requests"
        );
    }

    #[tokio::test]
    async fn auth_middleware_rejects_invalid_websocket_query_param() {
        let _env = EnvGuard::acquire();
        env::set_var("SHODH_API_KEYS", "correct-key");

        let status = middleware_status("/api/stream", "/api/stream?api_key=wrong-key", true).await;
        assert_eq!(
            status,
            StatusCode::UNAUTHORIZED,
            "Should reject invalid query parameter API key on WebSocket"
        );
    }
}