Skip to main content

haystack_server/auth/
mod.rs

1//! Server-side authentication manager using SCRAM SHA-256.
2//!
3//! Manages user records, in-flight SCRAM handshakes, and active bearer
4//! tokens.
5
6pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17
18use haystack_core::auth::{
19    DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, format_auth_info,
20    format_www_authenticate, generate_nonce, server_first_message, server_verify_final,
21};
22
23use users::{UserRecord, load_users_from_str, load_users_from_toml};
24
/// An authenticated user together with the permissions granted to them.
///
/// Instances are produced on successful SCRAM authentication and returned by
/// bearer-token validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Login name the user authenticated as.
    pub username: String,
    /// Permission names held by the user (e.g. `"read"`, `"write"`,
    /// `"admin"`).
    pub permissions: Vec<String>,
}
31
/// How long a SCRAM handshake token stays redeemable after it is issued
/// (sixty seconds).
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
34
35/// Server-side authentication manager.
36///
37/// Holds user credentials, in-flight SCRAM handshakes, and active
38/// bearer tokens.
39pub struct AuthManager {
40    /// Username -> pre-computed SCRAM credentials + permissions.
41    users: HashMap<String, UserRecord>,
42    /// In-flight SCRAM handshakes: handshake_token -> (ScramHandshake, created_at).
43    handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
44    /// Active bearer tokens: auth_token -> (AuthUser, created_at).
45    tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
46    /// Time-to-live for bearer tokens.
47    token_ttl: Duration,
48    /// Secret used to derive fake SCRAM challenges for unknown users,
49    /// preventing username enumeration attacks.
50    server_secret: [u8; 32],
51}
52
53impl AuthManager {
54    /// Create a new AuthManager with the given user records and token TTL.
55    pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
56        let mut server_secret = [0u8; 32];
57        rand::Rng::fill(&mut rand::rng(), &mut server_secret);
58        Self {
59            users,
60            handshakes: RwLock::new(HashMap::new()),
61            tokens: RwLock::new(HashMap::new()),
62            token_ttl,
63            server_secret,
64        }
65    }
66
67    /// Create an AuthManager with no users (auth effectively disabled).
68    pub fn empty() -> Self {
69        Self::new(HashMap::new(), Duration::from_secs(3600))
70    }
71
72    /// Builder method to configure the token TTL.
73    pub fn with_token_ttl(mut self, duration: Duration) -> Self {
74        self.token_ttl = duration;
75        self
76    }
77
78    /// Create an AuthManager from a TOML file.
79    pub fn from_toml(path: &str) -> Result<Self, String> {
80        let users = load_users_from_toml(path)?;
81        Ok(Self::new(users, Duration::from_secs(3600)))
82    }
83
84    /// Create an AuthManager from TOML content string.
85    pub fn from_toml_str(content: &str) -> Result<Self, String> {
86        let users = load_users_from_str(content)?;
87        Ok(Self::new(users, Duration::from_secs(3600)))
88    }
89
90    /// Returns true if authentication is enabled (there are registered users).
91    pub fn is_enabled(&self) -> bool {
92        !self.users.is_empty()
93    }
94
95    /// Derive deterministic fake SCRAM credentials for an unknown username.
96    ///
97    /// Uses HMAC(server_secret, username) so the same unknown username always
98    /// produces the same salt, making the response indistinguishable from a
99    /// real user's challenge to an outside observer.
100    fn fake_credentials(&self, username: &str) -> ScramCredentials {
101        let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
102            .expect("HMAC accepts keys of any size");
103        mac.update(username.as_bytes());
104        let fake_salt = mac.finalize().into_bytes();
105
106        // Derive credentials using a throwaway password; the handshake will
107        // always fail at the `handle_scram` step because the attacker does
108        // not know a valid password, but the challenge itself looks normal.
109        derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
110    }
111
112    /// Handle a HELLO request: look up user, create SCRAM handshake.
113    ///
114    /// Returns the `WWW-Authenticate` header value for the 401 response.
115    /// Unknown users receive a fake but plausible challenge to prevent
116    /// username enumeration.
117    pub fn handle_hello(&self, username: &str) -> Result<String, String> {
118        let credentials = match self.users.get(username) {
119            Some(user_record) => user_record.credentials.clone(),
120            None => self.fake_credentials(username),
121        };
122
123        // Generate client nonce stand-in (server extracts from client-first)
124        let client_nonce = generate_nonce();
125
126        // Create server-first-message
127        let (handshake, server_first_b64) =
128            server_first_message(username, &client_nonce, &credentials);
129
130        // Lazy cleanup: remove expired handshakes before inserting.
131        {
132            let now = Instant::now();
133            self.handshakes
134                .write()
135                .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
136        }
137
138        // Store handshake with a unique token and timestamp.
139        let handshake_token = Uuid::new_v4().to_string();
140        self.handshakes
141            .write()
142            .insert(handshake_token.clone(), (handshake, Instant::now()));
143
144        // Format the WWW-Authenticate header
145        let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
146        Ok(www_auth)
147    }
148
149    /// Handle a SCRAM request: verify client proof, issue auth token.
150    ///
151    /// Returns `(auth_token, authentication_info_header_value)`.
152    pub fn handle_scram(
153        &self,
154        handshake_token: &str,
155        data: &str,
156    ) -> Result<(String, String), String> {
157        // Remove the handshake (one-time use) and check expiry.
158        let (handshake, created_at) = self
159            .handshakes
160            .write()
161            .remove(handshake_token)
162            .ok_or_else(|| "invalid or expired handshake token".to_string())?;
163        if created_at.elapsed() > HANDSHAKE_TTL {
164            return Err("handshake token expired".to_string());
165        }
166
167        let username = handshake.username.clone();
168
169        // Verify client proof
170        let server_sig = server_verify_final(&handshake, data)
171            .map_err(|e| format!("SCRAM verification failed: {e}"))?;
172
173        // Issue auth token
174        let auth_token = Uuid::new_v4().to_string();
175
176        // Look up permissions
177        let permissions = self
178            .users
179            .get(&username)
180            .map(|r| r.permissions.clone())
181            .unwrap_or_default();
182
183        // Store token -> (user, created_at) mapping
184        self.tokens.write().insert(
185            auth_token.clone(),
186            (
187                AuthUser {
188                    username,
189                    permissions,
190                },
191                Instant::now(),
192            ),
193        );
194
195        // Format the server-final data (v=<server_signature>)
196        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
197        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
198        let auth_info = format_auth_info(&auth_token, &server_final_b64);
199
200        Ok((auth_token, auth_info))
201    }
202
203    /// Validate a bearer token and return the associated user.
204    ///
205    /// Returns `None` if the token is unknown or has expired. Expired
206    /// tokens are automatically removed.
207    pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
208        // First, check with a read lock.
209        {
210            let tokens = self.tokens.read();
211            match tokens.get(token) {
212                Some((user, created_at)) => {
213                    if created_at.elapsed() <= self.token_ttl {
214                        return Some(user.clone());
215                    }
216                    // Token expired -- fall through to remove it.
217                }
218                None => return None,
219            }
220        }
221        // Expired: remove under a write lock.
222        self.tokens.write().remove(token);
223        None
224    }
225
226    /// Remove a bearer token (logout / close).
227    pub fn revoke_token(&self, token: &str) -> bool {
228        self.tokens.write().remove(token).is_some()
229    }
230
231    /// Inject a token directly (for testing). The token is stamped with the
232    /// current instant so it will not be considered expired.
233    #[doc(hidden)]
234    pub fn inject_token(&self, token: String, user: AuthUser) {
235        self.tokens.write().insert(token, (user, Instant::now()));
236    }
237
238    /// Check whether a user has a required permission.
239    pub fn check_permission(user: &AuthUser, required: &str) -> bool {
240        // Admin has all permissions
241        if user.permissions.contains(&"admin".to_string()) {
242            return true;
243        }
244        user.permissions.contains(&required.to_string())
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Manager with two users sharing the password `s3cret`: `admin`
    /// (read/write/admin) and `viewer` (read only).
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    /// Build an `AuthUser` from borrowed strings.
    fn user(name: &str, permissions: &[&str]) -> AuthUser {
        AuthUser {
            username: name.to_string(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
        }
    }

    #[test]
    fn empty_manager_is_disabled() {
        let mgr = AuthManager::empty();
        assert!(!mgr.is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        let mgr = make_test_manager();
        assert!(mgr.is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let mgr = make_test_manager();
        // Unknown users get a plausible SCRAM challenge instead of an
        // error, preventing username enumeration.
        let result = mgr.handle_hello("nonexistent");
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let mgr = make_test_manager();
        let result = mgr.handle_hello("admin");
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let known = mgr.handle_hello("admin").unwrap();
        let unknown = mgr.handle_hello("nonexistent").unwrap();

        // Both responses must have the same structural format so that an
        // attacker cannot distinguish real from fake users.
        assert!(known.starts_with("SCRAM handshakeToken="));
        assert!(unknown.starts_with("SCRAM handshakeToken="));
        assert!(known.contains("hash=SHA-256"));
        assert!(unknown.contains("hash=SHA-256"));
        assert!(known.contains("data="));
        assert!(unknown.contains("data="));
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();
        // The fake salt must be deterministic so that repeated HELLO requests
        // for the same unknown username produce consistent parameters.
        let creds1 = mgr.fake_credentials("ghost");
        let creds2 = mgr.fake_credentials("ghost");
        assert_eq!(creds1.salt, creds2.salt);
        assert_eq!(creds1.stored_key, creds2.stored_key);
        assert_eq!(creds1.server_key, creds2.server_key);

        // Different usernames produce different fake salts.
        let creds3 = mgr.fake_credentials("phantom");
        assert_ne!(creds1.salt, creds3.salt);
    }

    #[test]
    fn handle_scram_rejects_unknown_handshake_token() {
        let mgr = make_test_manager();
        let err = mgr.handle_scram("no-such-handshake", "data").unwrap_err();
        assert_eq!(err, "invalid or expired handshake token");
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let mgr = make_test_manager();
        assert!(mgr.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let admin = user("admin", &["admin"]);
        assert!(AuthManager::check_permission(&admin, "read"));
        assert!(AuthManager::check_permission(&admin, "write"));
        assert!(AuthManager::check_permission(&admin, "admin"));
    }

    #[test]
    fn check_permission_viewer_limited() {
        let viewer = user("viewer", &["read"]);
        assert!(AuthManager::check_permission(&viewer, "read"));
        assert!(!AuthManager::check_permission(&viewer, "write"));
        assert!(!AuthManager::check_permission(&viewer, "admin"));
    }

    #[test]
    fn check_permission_without_permissions_denies_all() {
        let nobody = user("nobody", &[]);
        assert!(!AuthManager::check_permission(&nobody, "read"));
        assert!(!AuthManager::check_permission(&nobody, "admin"));
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let mgr = make_test_manager();
        assert!(!mgr.revoke_token("nonexistent-token"));
    }

    #[test]
    fn revoked_token_no_longer_validates() {
        let mgr = make_test_manager();
        mgr.inject_token("session".to_string(), user("viewer", &["read"]));

        let validated = mgr.validate_token("session").expect("fresh token is valid");
        assert_eq!(validated.username, "viewer");

        assert!(mgr.revoke_token("session"));
        assert!(mgr.validate_token("session").is_none());
        // A second revocation reports that nothing was removed.
        assert!(!mgr.revoke_token("session"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();
        // Manually insert a token stamped now (fresh, not expired).
        mgr.tokens.write().insert(
            "good-token".to_string(),
            (user("admin", &["admin"]), Instant::now()),
        );

        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // Use a zero TTL so the token is expired as soon as any time passes.
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));

        mgr.tokens.write().insert(
            "expired-token".to_string(),
            (user("admin", &["admin"]), Instant::now()),
        );
        // On platforms with a coarse monotonic clock, `elapsed()` can read
        // exactly zero immediately after insertion. Wait briefly so the test
        // does not depend on clock resolution.
        std::thread::sleep(Duration::from_millis(5));

        // Even though the token exists, it should be reported as expired.
        assert!(mgr.validate_token("expired-token").is_none());

        // The expired token should have been removed from the map.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}
407    fn with_token_ttl_sets_custom_duration() {
408        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
409        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
410    }
411}