Skip to main content

velesdb_server/
auth.rs

//! API key authentication middleware.
//!
//! When `api_keys` is non-empty, every request must include a valid
//! `Authorization: Bearer <key>` header, except requests to the public paths
//! `/health`, `/ready` and `/metrics` (matched exactly, for any HTTP method).
//! When `api_keys` is empty, authentication is disabled (local dev mode).
6
7use axum::{
8    body::Body,
9    extract::Request,
10    http::{header, StatusCode},
11    middleware::Next,
12    response::{IntoResponse, Response},
13    Json,
14};
15use std::sync::Arc;
16
/// Compares two byte slices without short-circuiting on the first mismatch.
///
/// Returns `true` only when `a` and `b` have the same length and the same bytes.
/// For equal-length inputs, every byte pair is visited. The differences are
/// OR-accumulated, so the running time does not depend on *where* the slices
/// differ. This guards the key contents against timing side channels.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // XOR is zero only for identical bytes. OR-ing the XORs together leaves
        // a non-zero value if any pair differed, without an early exit.
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |bits, (lhs, rhs)| bits | (lhs ^ rhs));
        diff == 0
    } else {
        // Returning early reveals only that the lengths differ, never which
        // bytes differ. That is acceptable: the caller controls the submitted
        // token and can learn the expected length by other means (e.g.
        // documentation). The property that matters is that key *content* is
        // never leaked through timing.
        false
    }
}
39
40/// Checks whether `token` matches any configured API key in constant time.
41///
42/// Iterates over **all** keys regardless of early matches to avoid leaking
43/// which key (if any) was correct through timing differences.
44fn any_key_matches(keys: &[String], token: &str) -> bool {
45    let token_bytes = token.as_bytes();
46    let mut matched = false;
47    for key in keys {
48        if constant_time_eq(key.as_bytes(), token_bytes) {
49            matched = true;
50        }
51        // Do NOT early-return — iterate all keys unconditionally.
52    }
53    matched
54}
55
/// Shared authentication state injected into the middleware.
///
/// `Debug` is implemented by hand (not derived) so that formatting the state,
/// e.g. in logs or panic messages, never prints the configured API keys.
#[derive(Clone)]
pub struct AuthState {
    /// Allowed API keys. Empty means auth is disabled.
    pub api_keys: Arc<Vec<String>>,
}

impl AuthState {
    /// Create a new `AuthState` from a list of API keys.
    ///
    /// Passing an empty list disables authentication.
    pub fn new(api_keys: Vec<String>) -> Self {
        Self {
            api_keys: Arc::new(api_keys),
        }
    }

    /// Returns `true` when authentication is enabled, i.e. at least one key is configured.
    pub fn auth_enabled(&self) -> bool {
        !self.api_keys.is_empty()
    }
}

impl std::fmt::Debug for AuthState {
    /// Renders only the number of configured keys, never their values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field(
                "api_keys",
                &format_args!("[{} redacted]", self.api_keys.len()),
            )
            .finish()
    }
}
76
/// Returns `true` for request paths that are served without authentication.
///
/// Matching is exact and case-sensitive: `/health/` or `/health/extra` still
/// require a key.
fn is_public_path(path: &str) -> bool {
    matches!(path, "/health" | "/ready" | "/metrics")
}
81
/// Extract the Bearer token from an `Authorization` header value.
///
/// The scheme name `Bearer` is matched case-insensitively and must be followed
/// by a space. Whitespace around the whole value and around the token is
/// ignored.
///
/// Returns `None` when the value does not start with the `Bearer ` scheme or
/// when the token is empty. Never panics, including on non-ASCII input.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    // Scheme prefix including its separating space, compared ignoring ASCII case.
    const SCHEME: &str = "bearer ";

    let trimmed = header_value.trim();
    // `str::get` yields `None` instead of panicking when the split index falls
    // inside a multi-byte UTF-8 character (e.g. "Beareré …").
    let scheme = trimmed.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    let token = trimmed.get(SCHEME.len()..)?.trim();
    (!token.is_empty()).then_some(token)
}
96
97/// Axum middleware function for API key authentication.
98///
99/// Use with `axum::middleware::from_fn_with_state`.
100pub async fn auth_middleware(
101    axum::extract::State(state): axum::extract::State<AuthState>,
102    request: Request<Body>,
103    next: Next,
104) -> Response {
105    // Skip auth if disabled (no keys configured)
106    if !state.auth_enabled() {
107        return next.run(request).await;
108    }
109
110    // Skip auth for public paths
111    if is_public_path(request.uri().path()) {
112        return next.run(request).await;
113    }
114
115    // Extract and validate Bearer token
116    let auth_header = request
117        .headers()
118        .get(header::AUTHORIZATION)
119        .and_then(|v| v.to_str().ok());
120
121    match auth_header {
122        Some(value) => match extract_bearer_token(value) {
123            Some(token) if any_key_matches(&state.api_keys, token) => next.run(request).await,
124            Some(_) => unauthorized_response("invalid API key"),
125            None => {
126                unauthorized_response("invalid Authorization header format, expected: Bearer <key>")
127            }
128        },
129        None => unauthorized_response("missing Authorization header"),
130    }
131}
132
133/// Build a 401 Unauthorized JSON response.
134fn unauthorized_response(message: &str) -> Response {
135    (
136        StatusCode::UNAUTHORIZED,
137        Json(serde_json::json!({
138            "error": "Unauthorized",
139            "message": message
140        })),
141    )
142        .into_response()
143}
144
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------------
    // AuthState
    // ------------------------------------------------------------------------

    #[test]
    fn test_auth_state_disabled_when_empty() {
        let disabled = AuthState::new(Vec::new());
        assert!(!disabled.auth_enabled());
    }

    #[test]
    fn test_auth_state_enabled_with_keys() {
        let enabled = AuthState::new(vec![String::from("key1")]);
        assert!(enabled.auth_enabled());
    }

    // ------------------------------------------------------------------------
    // Public path detection
    // ------------------------------------------------------------------------

    #[test]
    fn test_is_public_path_health() {
        assert!(is_public_path("/health"));
    }

    #[test]
    fn test_is_public_path_ready() {
        assert!(is_public_path("/ready"));
    }

    #[test]
    fn test_is_public_path_metrics() {
        assert!(is_public_path("/metrics"));
    }

    #[test]
    fn test_is_public_path_other() {
        for path in ["/collections", "/query", "/health/extra"] {
            assert!(!is_public_path(path), "{path} must require auth");
        }
    }

    // ------------------------------------------------------------------------
    // Bearer token extraction
    // ------------------------------------------------------------------------

    #[test]
    fn test_extract_bearer_token_valid() {
        let accepted = [
            "Bearer my-key",
            "bearer my-key",
            "BEARER my-key",
            "  Bearer  my-key  ",
        ];
        for header_value in accepted {
            assert_eq!(
                extract_bearer_token(header_value),
                Some("my-key"),
                "input: {header_value:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_invalid() {
        for header_value in ["Basic abc123", "my-key", "Bearer", ""] {
            assert_eq!(
                extract_bearer_token(header_value),
                None,
                "input: {header_value:?}"
            );
        }
    }

    #[test]
    fn test_extract_bearer_token_whitespace_only() {
        assert_eq!(extract_bearer_token("Bearer   "), None);
    }

    // ------------------------------------------------------------------------
    // Constant-time comparison
    // ------------------------------------------------------------------------

    #[test]
    fn test_constant_time_eq_identical() {
        assert!(constant_time_eq(b"secret-key-42", b"secret-key-42"));
    }

    #[test]
    fn test_constant_time_eq_different_content() {
        assert!(!constant_time_eq(b"secret-key-42", b"secret-key-43"));
    }

    #[test]
    fn test_constant_time_eq_different_length() {
        assert!(!constant_time_eq(b"short", b"longer-key"));
    }

    #[test]
    fn test_constant_time_eq_empty() {
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn test_any_key_matches_found() {
        let keys = ["key-a", "key-b"].map(String::from);
        assert!(any_key_matches(&keys, "key-b"));
    }

    #[test]
    fn test_any_key_matches_not_found() {
        let keys = ["key-a", "key-b"].map(String::from);
        assert!(!any_key_matches(&keys, "key-c"));
    }

    #[test]
    fn test_any_key_matches_empty_keys() {
        let keys: [String; 0] = [];
        assert!(!any_key_matches(&keys, "anything"));
    }
}