Skip to main content

spec_ai/spec_ai_api/api/
auth.rs

//! Authentication module for the HTTP API
//!
//! Provides:
//! - User credential management (loaded from a JSON file)
//! - Password verification using PBKDF2-HMAC-SHA256
//! - Bearer token generation and validation using HMAC-SHA256

use std::collections::HashMap;
use std::hint::black_box;
use std::num::NonZeroU32;
use std::path::Path;
use std::sync::Arc;

use anyhow::{Context, Result};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use ring::{hmac, pbkdf2, rand as ring_rand};
use serde::{Deserialize, Serialize};
16
/// PBKDF2 iteration count used both when hashing and when verifying passwords.
///
/// Stored hashes contain only `salt || derived_key` and do not record this count,
/// so changing it makes every existing entry in the credentials file stop verifying.
const PBKDF2_ITERATIONS: u32 = 100_000;

/// Size in bytes of the random salt stored at the front of each password hash.
const SALT_LENGTH: usize = 16;

/// Size in bytes of the PBKDF2 output stored after the salt.
const CREDENTIAL_LENGTH: usize = 32;

/// Token lifetime used when none is configured: one day, in seconds.
const DEFAULT_TOKEN_EXPIRY_SECS: u64 = 24 * 60 * 60;
28
29/// A user credential stored in the credentials file
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct UserCredential {
32    /// Username for authentication
33    pub username: String,
34    /// PBKDF2-hashed password (base64 encoded: salt + derived_key)
35    pub password_hash: String,
36}
37
38/// Token payload that gets signed
39#[derive(Debug, Clone, Serialize, Deserialize)]
40struct TokenPayload {
41    /// Username this token belongs to
42    pub sub: String,
43    /// Token issue timestamp (Unix epoch seconds)
44    pub iat: u64,
45    /// Token expiration timestamp (Unix epoch seconds)
46    pub exp: u64,
47    /// Unique token ID
48    pub jti: String,
49}
50
51/// Authentication service that manages credentials and tokens
52#[derive(Clone)]
53pub struct AuthService {
54    /// Map of username to credential
55    credentials: Arc<HashMap<String, UserCredential>>,
56    /// HMAC key for signing tokens
57    signing_key: Arc<hmac::Key>,
58    /// Token expiry duration in seconds
59    token_expiry_secs: u64,
60    /// Whether auth is enabled
61    enabled: bool,
62}
63
64impl std::fmt::Debug for AuthService {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("AuthService")
67            .field("credentials_count", &self.credentials.len())
68            .field("token_expiry_secs", &self.token_expiry_secs)
69            .field("enabled", &self.enabled)
70            .finish()
71    }
72}
73
74impl AuthService {
75    /// Create a new AuthService
76    ///
77    /// # Arguments
78    /// * `credentials_file` - Optional path to JSON file containing credentials
79    /// * `token_secret` - Optional secret for signing tokens (random if not provided)
80    /// * `token_expiry_secs` - Token expiry duration in seconds
81    /// * `enabled` - Whether authentication is enabled
82    pub fn new(
83        credentials_file: Option<&Path>,
84        token_secret: Option<&str>,
85        token_expiry_secs: Option<u64>,
86        enabled: bool,
87    ) -> Result<Self> {
88        // Load credentials if file is provided
89        let credentials = if let Some(path) = credentials_file {
90            Self::load_credentials(path)?
91        } else {
92            HashMap::new()
93        };
94
95        // Create signing key from provided secret or generate random
96        let signing_key = if let Some(secret) = token_secret {
97            hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes())
98        } else {
99            let rng = ring_rand::SystemRandom::new();
100            hmac::Key::generate(hmac::HMAC_SHA256, &rng)
101                .map_err(|_| anyhow::anyhow!("Failed to generate signing key"))?
102        };
103
104        Ok(Self {
105            credentials: Arc::new(credentials),
106            signing_key: Arc::new(signing_key),
107            token_expiry_secs: token_expiry_secs.unwrap_or(DEFAULT_TOKEN_EXPIRY_SECS),
108            enabled,
109        })
110    }
111
112    /// Create a disabled AuthService (no authentication required)
113    pub fn disabled() -> Self {
114        Self {
115            credentials: Arc::new(HashMap::new()),
116            signing_key: Arc::new(hmac::Key::new(hmac::HMAC_SHA256, b"disabled-auth-not-used")),
117            token_expiry_secs: DEFAULT_TOKEN_EXPIRY_SECS,
118            enabled: false,
119        }
120    }
121
122    /// Check if authentication is enabled
123    pub fn is_enabled(&self) -> bool {
124        self.enabled
125    }
126
127    /// Load credentials from a JSON file
128    fn load_credentials(path: &Path) -> Result<HashMap<String, UserCredential>> {
129        let content = std::fs::read_to_string(path)
130            .with_context(|| format!("Failed to read credentials file: {}", path.display()))?;
131
132        let credentials: Vec<UserCredential> = serde_json::from_str(&content)
133            .with_context(|| format!("Failed to parse credentials file: {}", path.display()))?;
134
135        let mut map = HashMap::new();
136        for cred in credentials {
137            map.insert(cred.username.clone(), cred);
138        }
139
140        tracing::info!("Loaded {} user credentials", map.len());
141        Ok(map)
142    }
143
144    /// Verify a username/password combination
145    pub fn verify_password(&self, username: &str, password: &str) -> bool {
146        let Some(credential) = self.credentials.get(username) else {
147            return false;
148        };
149
150        // Decode the stored hash (base64: salt + derived_key)
151        let Ok(stored_bytes) = URL_SAFE_NO_PAD.decode(&credential.password_hash) else {
152            tracing::warn!("Invalid base64 in password hash for user: {}", username);
153            return false;
154        };
155
156        if stored_bytes.len() != SALT_LENGTH + CREDENTIAL_LENGTH {
157            tracing::warn!("Invalid password hash length for user: {}", username);
158            return false;
159        }
160
161        let (salt, stored_hash) = stored_bytes.split_at(SALT_LENGTH);
162
163        // Verify the password using PBKDF2
164        pbkdf2::verify(
165            pbkdf2::PBKDF2_HMAC_SHA256,
166            NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
167            salt,
168            password.as_bytes(),
169            stored_hash,
170        )
171        .is_ok()
172    }
173
174    /// Generate a bearer token for a user
175    pub fn generate_token(&self, username: &str) -> Result<String> {
176        let now = std::time::SystemTime::now()
177            .duration_since(std::time::UNIX_EPOCH)
178            .context("System time before Unix epoch")?
179            .as_secs();
180
181        let payload = TokenPayload {
182            sub: username.to_string(),
183            iat: now,
184            exp: now + self.token_expiry_secs,
185            jti: uuid::Uuid::new_v4().to_string(),
186        };
187
188        // Serialize payload to JSON
189        let payload_json = serde_json::to_string(&payload)?;
190        let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
191
192        // Sign the payload
193        let signature = hmac::sign(&self.signing_key, payload_b64.as_bytes());
194        let signature_b64 = URL_SAFE_NO_PAD.encode(signature.as_ref());
195
196        // Token format: payload.signature (both base64 encoded)
197        Ok(format!("{}.{}", payload_b64, signature_b64))
198    }
199
200    /// Validate a bearer token and return the username if valid
201    pub fn validate_token(&self, token: &str) -> Option<String> {
202        let parts: Vec<&str> = token.split('.').collect();
203        if parts.len() != 2 {
204            return None;
205        }
206
207        let payload_b64 = parts[0];
208        let signature_b64 = parts[1];
209
210        // Verify signature
211        let Ok(signature_bytes) = URL_SAFE_NO_PAD.decode(signature_b64) else {
212            return None;
213        };
214
215        if hmac::verify(&self.signing_key, payload_b64.as_bytes(), &signature_bytes).is_err() {
216            return None;
217        }
218
219        // Decode and validate payload
220        let Ok(payload_json) = URL_SAFE_NO_PAD.decode(payload_b64) else {
221            return None;
222        };
223
224        let Ok(payload): Result<TokenPayload, _> = serde_json::from_slice(&payload_json) else {
225            return None;
226        };
227
228        // Check expiration
229        let now = std::time::SystemTime::now()
230            .duration_since(std::time::UNIX_EPOCH)
231            .ok()?
232            .as_secs();
233
234        if now > payload.exp {
235            return None;
236        }
237
238        Some(payload.sub)
239    }
240
241    /// Hash a password for storage
242    /// Returns base64-encoded salt + derived_key
243    pub fn hash_password(password: &str) -> Result<String> {
244        let rng = ring_rand::SystemRandom::new();
245
246        // Generate random salt
247        let mut salt = [0u8; SALT_LENGTH];
248        ring_rand::SecureRandom::fill(&rng, &mut salt)
249            .map_err(|_| anyhow::anyhow!("Failed to generate salt"))?;
250
251        // Derive key from password
252        let mut derived_key = [0u8; CREDENTIAL_LENGTH];
253        pbkdf2::derive(
254            pbkdf2::PBKDF2_HMAC_SHA256,
255            NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
256            &salt,
257            password.as_bytes(),
258            &mut derived_key,
259        );
260
261        // Combine salt + derived_key and encode
262        let mut combined = Vec::with_capacity(SALT_LENGTH + CREDENTIAL_LENGTH);
263        combined.extend_from_slice(&salt);
264        combined.extend_from_slice(&derived_key);
265
266        Ok(URL_SAFE_NO_PAD.encode(&combined))
267    }
268}
269
270/// Request body for token generation
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct TokenRequest {
273    /// Username for authentication
274    pub username: String,
275    /// Password for authentication
276    pub password: String,
277}
278
279/// Response body for successful token generation
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct TokenResponse {
282    /// Bearer token
283    pub token: String,
284    /// Token type (always "Bearer")
285    pub token_type: String,
286    /// Seconds until token expires
287    pub expires_in: u64,
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Secret shared by the test services and the hand-built tokens below.
    const TEST_SECRET: &str = "test_secret";

    /// Service with a fixed secret and a one-hour token lifetime.
    fn test_auth() -> AuthService {
        AuthService::new(None, Some(TEST_SECRET), Some(3600), true).unwrap()
    }

    /// Build a token in the same `payload.signature` format that
    /// `AuthService::generate_token` produces, signed with `secret`.
    ///
    /// This lets tests fabricate arbitrary claims (such as an `exp` in the past)
    /// without sleeping or racing the wall clock.
    fn sign_payload(secret: &str, payload: &TokenPayload) -> String {
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes());
        let signature = hmac::sign(&key, payload_b64.as_bytes());
        format!("{}.{}", payload_b64, URL_SAFE_NO_PAD.encode(signature.as_ref()))
    }

    #[test]
    fn test_password_hashing() {
        let password = "my_secret_password";
        let hash = AuthService::hash_password(password).unwrap();

        // Hash should be base64 encoded salt + derived key
        let decoded = URL_SAFE_NO_PAD.decode(&hash).unwrap();
        assert_eq!(decoded.len(), SALT_LENGTH + CREDENTIAL_LENGTH);

        // Each call uses a fresh random salt
        let other = AuthService::hash_password(password).unwrap();
        assert_ne!(hash, other);
    }

    #[test]
    fn test_password_verification() {
        let password = "test_password_123";
        let hash = AuthService::hash_password(password).unwrap();

        let credentials = vec![UserCredential {
            username: "testuser".to_string(),
            password_hash: hash,
        }];

        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{}", serde_json::to_string(&credentials).unwrap()).unwrap();

        let auth =
            AuthService::new(Some(file.path()), Some(TEST_SECRET), Some(3600), true).unwrap();

        // Correct password should verify
        assert!(auth.verify_password("testuser", password));

        // Wrong password should fail
        assert!(!auth.verify_password("testuser", "wrong_password"));

        // Unknown user should fail
        assert!(!auth.verify_password("unknown", password));
    }

    #[test]
    fn test_token_generation_and_validation() {
        let auth = test_auth();

        let token = auth.generate_token("testuser").unwrap();

        // Token should validate and return correct username
        assert_eq!(auth.validate_token(&token), Some("testuser".to_string()));

        // Malformed tokens should fail
        assert!(auth.validate_token("invalid.token").is_none());
        assert!(auth.validate_token("notavalidtoken").is_none());
        assert!(auth.validate_token(&format!("{}.extra", token)).is_none());
        assert!(auth.validate_token("").is_none());
    }

    #[test]
    fn test_expired_token() {
        let auth = test_auth();

        // `exp = 1` is far in the past, so this token is expired regardless of the
        // current time or of how the boundary second is treated.
        let expired = TokenPayload {
            sub: "testuser".to_string(),
            iat: 0,
            exp: 1,
            jti: "expired".to_string(),
        };
        assert!(auth.validate_token(&sign_payload(TEST_SECRET, &expired)).is_none());

        // The same construction with a far-future expiry is accepted. This shows the
        // rejection above is caused by `exp`, not by a malformed token.
        let live = TokenPayload {
            sub: "testuser".to_string(),
            iat: 0,
            exp: u64::MAX,
            jti: "live".to_string(),
        };
        assert_eq!(
            auth.validate_token(&sign_payload(TEST_SECRET, &live)),
            Some("testuser".to_string())
        );
    }

    #[test]
    fn test_token_from_other_secret_rejected() {
        let issuer = AuthService::new(None, Some("other_secret"), Some(3600), true).unwrap();
        let token = issuer.generate_token("testuser").unwrap();

        assert!(test_auth().validate_token(&token).is_none());
    }

    #[test]
    fn test_disabled_auth() {
        let auth = AuthService::disabled();
        assert!(!auth.is_enabled());
    }

    #[test]
    fn test_token_tampering() {
        let auth = test_auth();

        let token = auth.generate_token("testuser").unwrap();
        let (payload_b64, signature_b64) = token.split_once('.').unwrap();

        // Tamper with payload
        let tampered_payload = URL_SAFE_NO_PAD
            .encode(b"{\"sub\":\"admin\",\"iat\":0,\"exp\":9999999999,\"jti\":\"fake\"}");
        let tampered_token = format!("{}.{}", tampered_payload, signature_b64);
        assert!(auth.validate_token(&tampered_token).is_none());

        // Tamper with signature
        let mut signature = URL_SAFE_NO_PAD.decode(signature_b64).unwrap();
        signature[0] ^= 0x01;
        let tampered_token = format!("{}.{}", payload_b64, URL_SAFE_NO_PAD.encode(&signature));
        assert!(auth.validate_token(&tampered_token).is_none());
    }
}