Skip to main content

shodh_memory/
auth.rs

1use axum::{
2    extract::Request,
3    http::StatusCode,
4    middleware::Next,
5    response::{IntoResponse, Response},
6    Json,
7};
8use std::env;
9
10use crate::errors::ErrorResponse;
11
/// Built-in API key accepted outside production when neither `SHODH_API_KEYS` nor
/// `SHODH_DEV_API_KEY` holds a non-blank value.
///
/// Declared `pub(crate)` so it stays out of the crate's public API surface.
pub(crate) const DEFAULT_DEV_API_KEY: &str = "sk-shodh-dev-default";
15
/// Reports whether the server is configured for production.
///
/// Reads `SHODH_ENV` and returns `true` when its value, lowercased, is exactly
/// `production` or `prod`. An unset (or non-Unicode) variable means "not production".
pub fn is_production_mode() -> bool {
    match env::var("SHODH_ENV") {
        Ok(raw) => {
            let mode = raw.to_lowercase();
            mode == "production" || mode == "prod"
        }
        Err(_) => false,
    }
}
22
23/// Check if dev key should be hidden from error messages.
24///
25/// Returns true when SHODH_HIDE_DEV_KEY=true (opt-in).
26/// In production mode, always returns true regardless of the env var.
27fn should_hide_dev_key() -> bool {
28    if is_production_mode() {
29        return true;
30    }
31    env::var("SHODH_HIDE_DEV_KEY")
32        .map(|v| v.to_lowercase() == "true" || v == "1")
33        .unwrap_or(false)
34}
35
36/// Log security warnings at startup based on environment configuration
37pub fn log_security_status() {
38    let has_api_keys = env::var("SHODH_API_KEYS")
39        .map(|k| !k.trim().is_empty())
40        .unwrap_or(false);
41    let has_dev_key = env::var("SHODH_DEV_API_KEY")
42        .map(|k| !k.trim().is_empty())
43        .unwrap_or(false);
44    let is_prod = is_production_mode();
45
46    if is_prod {
47        if has_api_keys {
48            tracing::info!("Running in PRODUCTION mode with API key authentication");
49        } else {
50            tracing::error!(
51                "PRODUCTION mode but SHODH_API_KEYS not set! Server will reject all authenticated requests."
52            );
53        }
54    } else {
55        tracing::warn!("╔════════════════════════════════════════════════════════════════╗");
56        tracing::warn!("║  SECURITY WARNING: Running in DEVELOPMENT mode                 ║");
57        tracing::warn!("║                                                                ║");
58        if has_dev_key {
59            tracing::warn!("║  Using SHODH_DEV_API_KEY for authentication.                  ║");
60            tracing::warn!("║  DO NOT use this configuration in production!                 ║");
61        } else if !has_api_keys {
62            tracing::warn!("║  No API keys configured. Using default dev key.              ║");
63            tracing::warn!("║  DEPRECATION: Default dev key will be removed in v0.2.0.     ║");
64            tracing::warn!("║  Set SHODH_DEV_API_KEY or SHODH_API_KEYS to override.        ║");
65            tracing::warn!("║  Set SHODH_HIDE_DEV_KEY=true to hide key from error msgs.    ║");
66        }
67        tracing::warn!("║                                                                ║");
68        tracing::warn!("║  For production, set:                                          ║");
69        tracing::warn!("║    SHODH_ENV=production                                        ║");
70        tracing::warn!("║    SHODH_API_KEYS=your-secure-key-1,your-secure-key-2          ║");
71        tracing::warn!("╚════════════════════════════════════════════════════════════════╝");
72    }
73}
74
75/// API Key authentication errors
76#[derive(Debug)]
77pub enum AuthError {
78    MissingApiKey,
79    InvalidApiKey,
80    NotConfigured,
81}
82
83impl AuthError {
84    fn code(&self) -> &'static str {
85        match self {
86            Self::MissingApiKey => "MISSING_API_KEY",
87            Self::InvalidApiKey => "INVALID_API_KEY",
88            Self::NotConfigured => "AUTH_NOT_CONFIGURED",
89        }
90    }
91
92    fn status_code(&self) -> StatusCode {
93        match self {
94            Self::MissingApiKey | Self::InvalidApiKey => StatusCode::UNAUTHORIZED,
95            Self::NotConfigured => StatusCode::SERVICE_UNAVAILABLE,
96        }
97    }
98}
99
100impl IntoResponse for AuthError {
101    fn into_response(self) -> Response {
102        let is_prod = is_production_mode();
103        let status = self.status_code();
104
105        let message = match &self {
106            AuthError::MissingApiKey => {
107                if is_prod {
108                    "Missing X-API-Key header".to_string()
109                } else if should_hide_dev_key() {
110                    "Missing X-API-Key header. Set SHODH_DEV_API_KEY or SHODH_API_KEYS. \
111                     See docs for setup."
112                        .to_string()
113                } else {
114                    format!(
115                        "Missing X-API-Key header. Set the header in your request. \
116                         The server accepts keys from SHODH_API_KEYS (comma-separated) \
117                         or SHODH_DEV_API_KEY. Default dev key: '{}'",
118                        DEFAULT_DEV_API_KEY
119                    )
120                }
121            }
122            AuthError::InvalidApiKey => {
123                if is_prod {
124                    "Invalid API key".to_string()
125                } else if should_hide_dev_key() {
126                    "Invalid API key. Check SHODH_DEV_API_KEY or SHODH_API_KEYS.".to_string()
127                } else {
128                    format!(
129                        "Invalid API key. Expected a key from SHODH_API_KEYS or \
130                         SHODH_DEV_API_KEY. Default dev key: '{}'",
131                        DEFAULT_DEV_API_KEY
132                    )
133                }
134            }
135            AuthError::NotConfigured => {
136                "API keys not configured. Set SHODH_API_KEYS environment variable.".to_string()
137            }
138        };
139
140        let body = ErrorResponse {
141            code: self.code().to_string(),
142            message,
143            details: None,
144            request_id: None,
145        };
146
147        (status, Json(body)).into_response()
148    }
149}
150
/// Compares two strings for equality without stopping at the first mismatch.
///
/// The loop visits every byte position up to the longer of the two lengths, reading
/// missing bytes as `0`. The length difference is folded into the same accumulator.
/// The work done therefore depends on the inputs' lengths, not on where they first
/// differ. This keeps API key checks from revealing, via timing, how much of a guessed
/// key was correct.
///
/// The accumulator is a `usize`, so a length difference can never be truncated away.
/// A `u8` accumulator wraps for lengths differing by a multiple of 256, and a `u32`
/// one for multiples of 2^32 on 64-bit targets.
fn constant_time_compare(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());

    // Non-zero iff the lengths differ; kept as usize so no bits are discarded.
    let length_diff = a.len() ^ b.len();

    let byte_diff = (0..a.len().max(b.len())).fold(0usize, |acc, i| {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        acc | usize::from(x ^ y)
    });

    (length_diff | byte_diff) == 0
}
177
178/// Validate API key against configured keys using constant-time comparison
179pub fn validate_api_key(provided_key: &str) -> Result<(), AuthError> {
180    // Get API keys from environment (comma-separated for multiple keys)
181    let valid_keys = match env::var("SHODH_API_KEYS") {
182        Ok(keys) if !keys.trim().is_empty() => keys,
183        _ => {
184            // In production, refuse to start without API keys
185            let is_production = env::var("SHODH_ENV")
186                .map(|v| v.to_lowercase() == "production" || v.to_lowercase() == "prod")
187                .unwrap_or(false);
188
189            if is_production {
190                tracing::error!("SHODH_API_KEYS not set in production mode");
191                return Err(AuthError::NotConfigured);
192            }
193
194            // Development mode: use SHODH_DEV_API_KEY, or fall back to built-in default
195            match env::var("SHODH_DEV_API_KEY") {
196                Ok(key) if !key.trim().is_empty() => {
197                    tracing::warn!("Using SHODH_DEV_API_KEY for development (not for production!)");
198                    key
199                }
200                _ => {
201                    tracing::warn!(
202                        "No API key configured. Falling back to default dev key. \
203                         Set SHODH_DEV_API_KEY to override."
204                    );
205                    DEFAULT_DEV_API_KEY.to_string()
206                }
207            }
208        }
209    };
210
211    let keys: Vec<&str> = valid_keys.split(',').map(|k| k.trim()).collect();
212
213    // Use constant-time comparison to prevent timing attacks
214    let mut found = false;
215    for key in &keys {
216        if constant_time_compare(key, provided_key) {
217            found = true;
218            // Don't break early - continue checking to maintain constant time
219        }
220    }
221
222    if found {
223        Ok(())
224    } else {
225        Err(AuthError::InvalidApiKey)
226    }
227}
228
229/// Authentication middleware
230pub async fn auth_middleware(request: Request, next: Next) -> Response {
231    let path = request.uri().path();
232
233    // Skip auth for health endpoint
234    if path == "/health" {
235        return next.run(request).await;
236    }
237
238    // Skip API key auth for webhook endpoints (they use HMAC signature verification)
239    if path.starts_with("/webhook/") {
240        return next.run(request).await;
241    }
242
243    // Extract API key: try X-API-Key header first, then Authorization: Bearer,
244    // then query parameter (for WebSocket connections where headers aren't supported)
245    let api_key_value = match request
246        .headers()
247        .get("X-API-Key")
248        .and_then(|v| v.to_str().ok())
249        .map(|s| s.to_string())
250        .or_else(|| {
251            request
252                .headers()
253                .get("Authorization")
254                .and_then(|v| v.to_str().ok())
255                .and_then(|s| s.strip_prefix("Bearer "))
256                .map(|s| s.to_string())
257        })
258        .or_else(|| {
259            // WebSocket fallback: check query parameter for api_key
260            // Browser WebSocket API doesn't support custom headers, so
261            // clients can pass ?api_key=... in the URL instead.
262            // ONLY allow this for WebSocket upgrades to prevent API key
263            // leakage via URLs in server logs, browser history, and referrer headers.
264            let is_websocket = request
265                .headers()
266                .get("upgrade")
267                .and_then(|v| v.to_str().ok())
268                .map(|v| v.eq_ignore_ascii_case("websocket"))
269                .unwrap_or(false);
270            if !is_websocket {
271                return None;
272            }
273            request.uri().query().and_then(|q| {
274                q.split('&')
275                    .find_map(|pair| pair.strip_prefix("api_key=").map(|v| v.to_string()))
276            })
277        }) {
278        Some(key) => key,
279        None => return AuthError::MissingApiKey.into_response(),
280    };
281
282    // Validate the cloned key
283    if let Err(e) = validate_api_key(&api_key_value) {
284        return e.into_response();
285    }
286
287    // Now we can move request to next layer
288    next.run(request).await
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::to_bytes;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Process-global lock for tests that manipulate environment variables.
    /// `env::set_var` / `env::remove_var` are not thread-safe. Every test that touches
    /// auth env vars — or calls code that reads them, such as
    /// `AuthError::into_response` — must hold this lock for the duration of the test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering the guard if the mutex is poisoned.
    ///
    /// A test that panics while holding the lock poisons it. With `.lock().unwrap()`,
    /// every later test would then fail with `PoisonError`, burying the one real
    /// failure. The protected data is `()`, so there is no state to be left
    /// inconsistent, and each test re-clears the env vars right after locking.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Clear all auth-related env vars to isolate tests.
    /// Caller MUST hold `ENV_LOCK` (via `lock_env`) — this is not enforced at compile time.
    fn clear_auth_env() {
        env::remove_var("SHODH_API_KEYS");
        env::remove_var("SHODH_DEV_API_KEY");
        env::remove_var("SHODH_ENV");
        env::remove_var("SHODH_HIDE_DEV_KEY");
    }

    // ── constant_time_compare ──

    #[test]
    fn constant_time_equal_strings() {
        assert!(constant_time_compare("hello", "hello"));
    }

    #[test]
    fn constant_time_different_strings() {
        assert!(!constant_time_compare("hello", "world"));
    }

    #[test]
    fn constant_time_different_lengths() {
        assert!(!constant_time_compare("short", "a-longer-string"));
    }

    #[test]
    fn constant_time_empty_strings() {
        assert!(constant_time_compare("", ""));
    }

    #[test]
    fn constant_time_one_empty() {
        assert!(!constant_time_compare("", "notempty"));
        assert!(!constant_time_compare("notempty", ""));
    }

    #[test]
    fn constant_time_length_multiple_of_256() {
        // Regression: (256 ^ 0) as u8 == 0, so the old u8 accumulator
        // would falsely treat a 256-byte string as equal to an empty string.
        let long = "a".repeat(256);
        assert!(!constant_time_compare(&long, ""));
        assert!(!constant_time_compare("", &long));

        // Also test 512 vs 256 (difference = 256, wraps to 0 in u8)
        let medium = "b".repeat(256);
        let longer = "b".repeat(512);
        assert!(!constant_time_compare(&medium, &longer));
    }

    // ── is_production_mode ──

    #[test]
    fn production_mode_detection() {
        let _guard = lock_env();
        clear_auth_env();

        assert!(!is_production_mode());

        env::set_var("SHODH_ENV", "production");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "prod");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "PRODUCTION");
        assert!(is_production_mode());

        env::set_var("SHODH_ENV", "development");
        assert!(!is_production_mode());

        env::set_var("SHODH_ENV", "test");
        assert!(!is_production_mode());

        clear_auth_env();
    }

    // ── validate_api_key: SHODH_API_KEYS ──

    #[test]
    fn validate_with_single_api_key() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "my-key");
        assert!(validate_api_key("my-key").is_ok());
        assert!(validate_api_key("wrong").is_err());
        clear_auth_env();
    }

    #[test]
    fn validate_with_multiple_api_keys() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "key1,key2,key3");
        assert!(validate_api_key("key1").is_ok());
        assert!(validate_api_key("key2").is_ok());
        assert!(validate_api_key("key3").is_ok());
        assert!(validate_api_key("key4").is_err());
        clear_auth_env();
    }

    #[test]
    fn validate_api_keys_trims_whitespace() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", " key1 , key2 ");
        assert!(validate_api_key("key1").is_ok());
        assert!(validate_api_key("key2").is_ok());
        clear_auth_env();
    }

    // ── validate_api_key: dev key ──

    #[test]
    fn validate_with_dev_key() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_DEV_API_KEY", "dev-key-123");
        assert!(validate_api_key("dev-key-123").is_ok());
        assert!(validate_api_key("wrong").is_err());
        clear_auth_env();
    }

    // ── validate_api_key: default dev key ──

    #[test]
    fn validate_with_default_dev_key_when_no_env_set() {
        let _guard = lock_env();
        clear_auth_env();
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
        assert!(validate_api_key("wrong-key").is_err());
        clear_auth_env();
    }

    // ── validate_api_key: production mode ──

    #[test]
    fn validate_production_rejects_when_no_keys() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_ENV", "production");
        let result = validate_api_key("any-key");
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::NotConfigured => {}
            other => panic!("Expected NotConfigured, got {:?}", other),
        }
        clear_auth_env();
    }

    #[test]
    fn validate_production_works_with_api_keys_set() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_ENV", "production");
        env::set_var("SHODH_API_KEYS", "prod-key");
        assert!(validate_api_key("prod-key").is_ok());
        assert!(validate_api_key("wrong").is_err());
        clear_auth_env();
    }

    // ── validate_api_key: edge cases ──

    #[test]
    fn validate_empty_api_keys_falls_through() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "  ");
        // Empty SHODH_API_KEYS falls through to dev key / default
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
        clear_auth_env();
    }

    #[test]
    fn validate_empty_dev_key_uses_default() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_DEV_API_KEY", "  ");
        assert!(validate_api_key(DEFAULT_DEV_API_KEY).is_ok());
        clear_auth_env();
    }

    #[test]
    fn api_keys_takes_priority_over_dev_key() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "prod-key");
        env::set_var("SHODH_DEV_API_KEY", "dev-key");
        assert!(validate_api_key("prod-key").is_ok());
        assert!(validate_api_key("dev-key").is_err()); // dev key ignored
        clear_auth_env();
    }

    // ── AuthError response codes ──

    #[test]
    fn auth_error_status_codes() {
        assert_eq!(
            AuthError::MissingApiKey.status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::InvalidApiKey.status_code(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::NotConfigured.status_code(),
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[test]
    fn auth_error_codes() {
        assert_eq!(AuthError::MissingApiKey.code(), "MISSING_API_KEY");
        assert_eq!(AuthError::InvalidApiKey.code(), "INVALID_API_KEY");
        assert_eq!(AuthError::NotConfigured.code(), "AUTH_NOT_CONFIGURED");
    }

    // ── AuthError JSON response shape ──

    #[tokio::test]
    async fn auth_error_response_is_valid_json() {
        let _guard = lock_env();
        clear_auth_env();
        let resp = AuthError::MissingApiKey.into_response();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body)
            .expect("Response body should be valid JSON matching ErrorResponse");
        assert_eq!(parsed.code, "MISSING_API_KEY");
        assert!(parsed.message.contains("X-API-Key"));
        clear_auth_env();
    }

    #[tokio::test]
    async fn missing_key_dev_message_includes_help() {
        let _guard = lock_env();
        clear_auth_env();
        // Not production → should include env var names in message
        let resp = AuthError::MissingApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();
        assert!(
            parsed.message.contains("SHODH_API_KEYS"),
            "Should mention SHODH_API_KEYS"
        );
        assert!(
            parsed.message.contains("SHODH_DEV_API_KEY"),
            "Should mention SHODH_DEV_API_KEY"
        );
        assert!(
            parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Should show the default dev key"
        );
        clear_auth_env();
    }

    #[tokio::test]
    async fn invalid_key_dev_message_includes_help() {
        let _guard = lock_env();
        clear_auth_env();
        let resp = AuthError::InvalidApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();
        assert!(
            parsed.message.contains("SHODH_API_KEYS"),
            "Should mention SHODH_API_KEYS"
        );
        assert!(
            parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Should show the default dev key"
        );
        clear_auth_env();
    }

    #[tokio::test]
    async fn missing_key_prod_message_is_terse() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_ENV", "production");
        let resp = AuthError::MissingApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(parsed.message, "Missing X-API-Key header");
        assert!(
            !parsed.message.contains("SHODH_DEV_API_KEY"),
            "Prod must not leak env var names"
        );
        clear_auth_env();
    }

    #[tokio::test]
    async fn invalid_key_prod_message_is_terse() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_ENV", "production");
        let resp = AuthError::InvalidApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(parsed.message, "Invalid API key");
        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "Prod must not leak default key"
        );
        clear_auth_env();
    }

    #[tokio::test]
    async fn not_configured_response_shape() {
        // into_response() reads SHODH_ENV, so this test must hold the env lock too;
        // otherwise it races with tests that call env::set_var.
        let _guard = lock_env();
        clear_auth_env();
        let resp = AuthError::NotConfigured.into_response();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(parsed.code, "AUTH_NOT_CONFIGURED");
        assert!(parsed.message.contains("SHODH_API_KEYS"));
        clear_auth_env();
    }

    // ── SHODH_HIDE_DEV_KEY ──

    #[tokio::test]
    async fn hide_dev_key_suppresses_key_in_missing_key_error() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_HIDE_DEV_KEY", "true");

        let resp = AuthError::MissingApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();

        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "SHODH_HIDE_DEV_KEY=true should suppress key in error: {}",
            parsed.message
        );
        assert!(
            parsed.message.contains("SHODH_DEV_API_KEY"),
            "Should still mention env var name: {}",
            parsed.message
        );
        clear_auth_env();
    }

    #[tokio::test]
    async fn hide_dev_key_suppresses_key_in_invalid_key_error() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_HIDE_DEV_KEY", "true");

        let resp = AuthError::InvalidApiKey.into_response();
        let body = to_bytes(resp.into_body(), 2048).await.unwrap();
        let parsed: ErrorResponse = serde_json::from_slice(&body).unwrap();

        assert!(
            !parsed.message.contains(DEFAULT_DEV_API_KEY),
            "SHODH_HIDE_DEV_KEY=true should suppress key in error: {}",
            parsed.message
        );
        clear_auth_env();
    }

    #[test]
    fn should_hide_dev_key_defaults_to_false() {
        let _guard = lock_env();
        clear_auth_env();
        assert!(!should_hide_dev_key());
        clear_auth_env();
    }

    #[test]
    fn should_hide_dev_key_respects_env_var() {
        let _guard = lock_env();
        clear_auth_env();

        env::set_var("SHODH_HIDE_DEV_KEY", "true");
        assert!(should_hide_dev_key());

        env::set_var("SHODH_HIDE_DEV_KEY", "1");
        assert!(should_hide_dev_key());

        env::set_var("SHODH_HIDE_DEV_KEY", "false");
        assert!(!should_hide_dev_key());

        clear_auth_env();
    }

    #[test]
    fn should_hide_dev_key_always_true_in_production() {
        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_ENV", "production");
        // Even without SHODH_HIDE_DEV_KEY, production always hides
        assert!(should_hide_dev_key());
        clear_auth_env();
    }

    // ── Query parameter auth (WebSocket fallback) ──

    #[tokio::test]
    async fn auth_middleware_accepts_query_param_for_websocket() {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "test-ws-key");

        let app = Router::new()
            .route("/api/stream", get(|| async { "ok" }))
            .layer(from_fn(auth_middleware));

        // WebSocket upgrade with API key in query parameter
        let req = HttpRequest::builder()
            .uri("/api/stream?api_key=test-ws-key")
            .header("upgrade", "websocket")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::OK,
            "Should accept API key from query parameter on WebSocket upgrade"
        );

        clear_auth_env();
    }

    #[tokio::test]
    async fn auth_middleware_ignores_query_param_without_websocket_upgrade() {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "test-ws-key");

        let app = Router::new()
            .route("/api/remember", get(|| async { "ok" }))
            .layer(from_fn(auth_middleware));

        // Non-WebSocket request with API key in query parameter — should be ignored
        let req = HttpRequest::builder()
            .uri("/api/remember?api_key=test-ws-key")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::UNAUTHORIZED,
            "Query param auth should be ignored for non-WebSocket requests"
        );

        clear_auth_env();
    }

    #[tokio::test]
    async fn auth_middleware_rejects_invalid_websocket_query_param() {
        use axum::body::Body;
        use axum::http::Request as HttpRequest;
        use axum::middleware::from_fn;
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let _guard = lock_env();
        clear_auth_env();
        env::set_var("SHODH_API_KEYS", "correct-key");

        let app = Router::new()
            .route("/api/stream", get(|| async { "ok" }))
            .layer(from_fn(auth_middleware));

        let req = HttpRequest::builder()
            .uri("/api/stream?api_key=wrong-key")
            .header("upgrade", "websocket")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(
            resp.status(),
            StatusCode::UNAUTHORIZED,
            "Should reject invalid query parameter API key on WebSocket"
        );

        clear_auth_env();
    }
}