1pub mod users;
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use base64::Engine;
13use base64::engine::general_purpose::STANDARD as BASE64;
14use hmac::{Hmac, Mac};
15use parking_lot::RwLock;
16use sha2::Sha256;
17use uuid::Uuid;
18use zeroize::Zeroize;
19
20use haystack_core::auth::{
21 DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, extract_client_nonce,
22 format_auth_info, format_www_authenticate, generate_nonce, server_first_message,
23 server_verify_final,
24};
25
26use users::{UserRecord, load_users_from_str, load_users_from_toml};
27
/// An authenticated principal: the username it logged in as and the
/// permission names granted to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Name the user authenticated as.
    pub username: String,
    /// Granted permission names (for example `read`, `write`, `admin`).
    pub permissions: Vec<String>,
}
34
/// Maximum age of a pending SCRAM handshake; older handshakes are purged and
/// can no longer be completed.
const HANDSHAKE_TTL: Duration = Duration::from_secs(60);
37
38pub struct AuthManager {
43 users: HashMap<String, UserRecord>,
45 handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
47 tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
49 token_ttl: Duration,
51 server_secret: [u8; 32],
54 cleanup_counter: AtomicU64,
56}
57
58impl Drop for AuthManager {
59 fn drop(&mut self) {
60 self.server_secret.zeroize();
61 }
62}
63
64impl AuthManager {
65 pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
67 let mut server_secret = [0u8; 32];
68 rand::RngExt::fill(&mut rand::rng(), &mut server_secret);
69 Self {
70 users,
71 handshakes: RwLock::new(HashMap::new()),
72 tokens: RwLock::new(HashMap::new()),
73 token_ttl,
74 server_secret,
75 cleanup_counter: AtomicU64::new(0),
76 }
77 }
78
79 pub fn empty() -> Self {
81 Self::new(HashMap::new(), Duration::from_secs(3600))
82 }
83
84 pub fn with_token_ttl(mut self, duration: Duration) -> Self {
86 self.token_ttl = duration;
87 self
88 }
89
90 pub fn from_toml(path: &str) -> Result<Self, String> {
92 let users = load_users_from_toml(path)?;
93 Ok(Self::new(users, Duration::from_secs(3600)))
94 }
95
96 pub fn from_toml_str(content: &str) -> Result<Self, String> {
98 let users = load_users_from_str(content)?;
99 Ok(Self::new(users, Duration::from_secs(3600)))
100 }
101
102 pub fn is_enabled(&self) -> bool {
104 !self.users.is_empty()
105 }
106
107 fn fake_credentials(&self, username: &str) -> ScramCredentials {
113 let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
114 .expect("HMAC accepts keys of any size");
115 mac.update(username.as_bytes());
116 let fake_salt = mac.finalize().into_bytes();
117
118 derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
122 }
123
124 pub fn handle_hello(
134 &self,
135 username: &str,
136 client_first_b64: Option<&str>,
137 ) -> Result<String, String> {
138 let owned_fake;
139 let credentials: &ScramCredentials = match self.users.get(username) {
140 Some(user_record) => &user_record.credentials,
141 None => {
142 owned_fake = self.fake_credentials(username);
143 &owned_fake
144 }
145 };
146
147 let client_nonce = match client_first_b64 {
149 Some(data) => {
150 extract_client_nonce(data).map_err(|e| format!("invalid client-first data: {e}"))?
151 }
152 None => generate_nonce(),
153 };
154
155 let (handshake, server_first_b64) =
157 server_first_message(username, &client_nonce, credentials);
158
159 {
161 let now = Instant::now();
162 self.handshakes
163 .write()
164 .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
165 }
166
167 let handshake_token = Uuid::new_v4().to_string();
169 self.handshakes
170 .write()
171 .insert(handshake_token.clone(), (handshake, Instant::now()));
172
173 let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
175 Ok(www_auth)
176 }
177
178 pub fn handle_scram(
182 &self,
183 handshake_token: &str,
184 data: &str,
185 ) -> Result<(String, String), String> {
186 let (handshake, created_at) = self
188 .handshakes
189 .write()
190 .remove(handshake_token)
191 .ok_or_else(|| "invalid or expired handshake token".to_string())?;
192 if created_at.elapsed() > HANDSHAKE_TTL {
193 return Err("handshake token expired".to_string());
194 }
195
196 let username = handshake.username.clone();
197
198 let server_sig = server_verify_final(&handshake, data)
200 .map_err(|e| format!("SCRAM verification failed: {e}"))?;
201
202 let auth_token = Uuid::new_v4().to_string();
204
205 let permissions = self
207 .users
208 .get(&username)
209 .map(|r| r.permissions.clone())
210 .unwrap_or_default();
211
212 self.tokens.write().insert(
214 auth_token.clone(),
215 (
216 AuthUser {
217 username,
218 permissions,
219 },
220 Instant::now(),
221 ),
222 );
223
224 let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
226 let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
227 let auth_info = format_auth_info(&auth_token, &server_final_b64);
228
229 Ok((auth_token, auth_info))
230 }
231
232 pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
238 let count = self.cleanup_counter.fetch_add(1, Ordering::Relaxed);
240 if count.is_multiple_of(100) {
241 let mut tokens = self.tokens.write();
242 let ttl = self.token_ttl;
243 tokens.retain(|_, (_, created)| created.elapsed() < ttl);
244 }
245
246 let mut tokens = self.tokens.write();
247 match tokens.get(token) {
248 Some((user, created_at)) => {
249 if created_at.elapsed() <= self.token_ttl {
250 Some(user.clone())
251 } else {
252 tokens.remove(token);
254 None
255 }
256 }
257 None => None,
258 }
259 }
260
261 pub fn revoke_token(&self, token: &str) -> bool {
263 self.tokens.write().remove(token).is_some()
264 }
265
266 #[doc(hidden)]
269 pub fn inject_token(&self, token: String, user: AuthUser) {
270 self.tokens.write().insert(token, (user, Instant::now()));
271 }
272
273 pub fn check_permission(user: &AuthUser, required: &str) -> bool {
275 if user.permissions.contains(&"admin".to_string()) {
277 return true;
278 }
279 user.permissions.contains(&required.to_string())
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Builds a manager with two users sharing the password `s3cret`:
    /// `admin` (read, write, admin) and `viewer` (read only).
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    /// Shorthand for an `AuthUser` with the given name and permissions.
    fn user(name: &str, permissions: &[&str]) -> AuthUser {
        AuthUser {
            username: name.to_string(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
        }
    }

    #[test]
    fn empty_manager_is_disabled() {
        assert!(!AuthManager::empty().is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        assert!(make_test_manager().is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let mgr = make_test_manager();
        let www_auth = mgr
            .handle_hello("nonexistent", None)
            .expect("unknown users must still get a challenge");
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let mgr = make_test_manager();
        let www_auth = mgr
            .handle_hello("admin", None)
            .expect("known user hello should succeed");
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let known = mgr.handle_hello("admin", None).unwrap();
        let unknown = mgr.handle_hello("nonexistent", None).unwrap();

        for challenge in [&known, &unknown] {
            assert!(challenge.starts_with("SCRAM handshakeToken="));
            assert!(challenge.contains("hash=SHA-256"));
            assert!(challenge.contains("data="));
        }
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();
        let first = mgr.fake_credentials("ghost");
        let second = mgr.fake_credentials("ghost");
        assert_eq!(first.salt, second.salt);
        assert_eq!(first.stored_key, second.stored_key);
        assert_eq!(first.server_key, second.server_key);

        let other = mgr.fake_credentials("phantom");
        assert_ne!(first.salt, other.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let mgr = make_test_manager();
        assert!(mgr.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let admin = user("admin", &["admin"]);
        for perm in ["read", "write", "admin"] {
            assert!(AuthManager::check_permission(&admin, perm));
        }
    }

    #[test]
    fn check_permission_viewer_limited() {
        let viewer = user("viewer", &["read"]);
        assert!(AuthManager::check_permission(&viewer, "read"));
        assert!(!AuthManager::check_permission(&viewer, "write"));
        assert!(!AuthManager::check_permission(&viewer, "admin"));
    }

    #[test]
    fn check_permission_without_permissions_denies_everything() {
        let nobody = user("nobody", &[]);
        for perm in ["read", "write", "admin"] {
            assert!(!AuthManager::check_permission(&nobody, perm));
        }
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let mgr = make_test_manager();
        assert!(!mgr.revoke_token("nonexistent-token"));
    }

    #[test]
    fn revoke_token_removes_injected_token() {
        let mgr = make_test_manager();
        mgr.inject_token("live-token".to_string(), user("viewer", &["read"]));

        let validated = mgr
            .validate_token("live-token")
            .expect("freshly injected token should be valid");
        assert_eq!(validated.username, "viewer");
        assert_eq!(validated.permissions, vec!["read".to_string()]);

        assert!(mgr.revoke_token("live-token"));
        assert!(mgr.validate_token("live-token").is_none());
        assert!(!mgr.revoke_token("live-token"));
    }

    #[test]
    fn handle_scram_rejects_unknown_handshake_token() {
        let mgr = make_test_manager();
        assert_eq!(
            mgr.handle_scram("no-such-handshake", "ignored"),
            Err("invalid or expired handshake token".to_string())
        );
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();
        mgr.inject_token("good-token".to_string(), user("admin", &["admin"]));
        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // A zero TTL makes every token expired as soon as it is checked.
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));
        mgr.inject_token("expired-token".to_string(), user("admin", &["admin"]));

        assert!(mgr.validate_token("expired-token").is_none());
        // The expired entry must also be purged from the table.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}