1pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17
18use haystack_core::auth::{
19 DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, extract_client_nonce,
20 format_auth_info, format_www_authenticate, generate_nonce, server_first_message,
21 server_verify_final,
22};
23
24use users::{UserRecord, load_users_from_str, load_users_from_toml};
25
/// An authenticated principal: the username and the permission strings granted to it.
///
/// Instances are produced when a SCRAM exchange succeeds and are returned when a bearer token
/// is validated. `PartialEq`/`Eq` are derived so callers and tests can compare users directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Name of the account this principal authenticated as.
    pub username: String,
    /// Permission strings granted to the account (e.g. `"read"`, `"write"`, `"admin"`).
    pub permissions: Vec<String>,
}
32
/// Maximum age of a pending SCRAM handshake (one minute). Handshakes older than this are
/// rejected when completed and pruned when new handshakes are started.
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
35
36pub struct AuthManager {
41 users: HashMap<String, UserRecord>,
43 handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
45 tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
47 token_ttl: Duration,
49 server_secret: [u8; 32],
52}
53
54impl AuthManager {
55 pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
57 let mut server_secret = [0u8; 32];
58 rand::Rng::fill(&mut rand::rng(), &mut server_secret);
59 Self {
60 users,
61 handshakes: RwLock::new(HashMap::new()),
62 tokens: RwLock::new(HashMap::new()),
63 token_ttl,
64 server_secret,
65 }
66 }
67
68 pub fn empty() -> Self {
70 Self::new(HashMap::new(), Duration::from_secs(3600))
71 }
72
73 pub fn with_token_ttl(mut self, duration: Duration) -> Self {
75 self.token_ttl = duration;
76 self
77 }
78
79 pub fn from_toml(path: &str) -> Result<Self, String> {
81 let users = load_users_from_toml(path)?;
82 Ok(Self::new(users, Duration::from_secs(3600)))
83 }
84
85 pub fn from_toml_str(content: &str) -> Result<Self, String> {
87 let users = load_users_from_str(content)?;
88 Ok(Self::new(users, Duration::from_secs(3600)))
89 }
90
91 pub fn is_enabled(&self) -> bool {
93 !self.users.is_empty()
94 }
95
96 fn fake_credentials(&self, username: &str) -> ScramCredentials {
102 let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
103 .expect("HMAC accepts keys of any size");
104 mac.update(username.as_bytes());
105 let fake_salt = mac.finalize().into_bytes();
106
107 derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
111 }
112
113 pub fn handle_hello(
123 &self,
124 username: &str,
125 client_first_b64: Option<&str>,
126 ) -> Result<String, String> {
127 let credentials = match self.users.get(username) {
128 Some(user_record) => user_record.credentials.clone(),
129 None => self.fake_credentials(username),
130 };
131
132 let client_nonce = match client_first_b64 {
134 Some(data) => {
135 extract_client_nonce(data).map_err(|e| format!("invalid client-first data: {e}"))?
136 }
137 None => generate_nonce(),
138 };
139
140 let (handshake, server_first_b64) =
142 server_first_message(username, &client_nonce, &credentials);
143
144 {
146 let now = Instant::now();
147 self.handshakes
148 .write()
149 .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
150 }
151
152 let handshake_token = Uuid::new_v4().to_string();
154 self.handshakes
155 .write()
156 .insert(handshake_token.clone(), (handshake, Instant::now()));
157
158 let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
160 Ok(www_auth)
161 }
162
163 pub fn handle_scram(
167 &self,
168 handshake_token: &str,
169 data: &str,
170 ) -> Result<(String, String), String> {
171 let (handshake, created_at) = self
173 .handshakes
174 .write()
175 .remove(handshake_token)
176 .ok_or_else(|| "invalid or expired handshake token".to_string())?;
177 if created_at.elapsed() > HANDSHAKE_TTL {
178 return Err("handshake token expired".to_string());
179 }
180
181 let username = handshake.username.clone();
182
183 let server_sig = server_verify_final(&handshake, data)
185 .map_err(|e| format!("SCRAM verification failed: {e}"))?;
186
187 let auth_token = Uuid::new_v4().to_string();
189
190 let permissions = self
192 .users
193 .get(&username)
194 .map(|r| r.permissions.clone())
195 .unwrap_or_default();
196
197 self.tokens.write().insert(
199 auth_token.clone(),
200 (
201 AuthUser {
202 username,
203 permissions,
204 },
205 Instant::now(),
206 ),
207 );
208
209 let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
211 let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
212 let auth_info = format_auth_info(&auth_token, &server_final_b64);
213
214 Ok((auth_token, auth_info))
215 }
216
217 pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
222 {
224 let tokens = self.tokens.read();
225 match tokens.get(token) {
226 Some((user, created_at)) => {
227 if created_at.elapsed() <= self.token_ttl {
228 return Some(user.clone());
229 }
230 }
232 None => return None,
233 }
234 }
235 self.tokens.write().remove(token);
237 None
238 }
239
240 pub fn revoke_token(&self, token: &str) -> bool {
242 self.tokens.write().remove(token).is_some()
243 }
244
245 #[doc(hidden)]
248 pub fn inject_token(&self, token: String, user: AuthUser) {
249 self.tokens.write().insert(token, (user, Instant::now()));
250 }
251
252 pub fn check_permission(user: &AuthUser, required: &str) -> bool {
254 if user.permissions.contains(&"admin".to_string()) {
256 return true;
257 }
258 user.permissions.contains(&required.to_string())
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Builds a manager with an `admin` and a `viewer` account, both using password "s3cret".
    fn manager_with_two_users() -> AuthManager {
        let password_hash = hash_password("s3cret");
        let config = format!(
            r#"
[users.admin]
password_hash = "{password_hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{password_hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&config).unwrap()
    }

    /// Convenience constructor for an [`AuthUser`] with the given permission names.
    fn user_with(name: &str, permissions: &[&str]) -> AuthUser {
        AuthUser {
            username: name.to_string(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
        }
    }

    /// Asserts that a `WWW-Authenticate` value is a SCRAM SHA-256 challenge.
    fn assert_scram_challenge(challenge: &str) {
        assert!(challenge.contains("SCRAM"));
        assert!(challenge.contains("SHA-256"));
    }

    #[test]
    fn empty_manager_is_disabled() {
        assert!(!AuthManager::empty().is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        assert!(manager_with_two_users().is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let manager = manager_with_two_users();
        let outcome = manager.handle_hello("nonexistent", None);
        assert!(outcome.is_ok());
        assert_scram_challenge(&outcome.unwrap());
    }

    #[test]
    fn hello_known_user_succeeds() {
        let manager = manager_with_two_users();
        let outcome = manager.handle_hello("admin", None);
        assert!(outcome.is_ok());
        assert_scram_challenge(&outcome.unwrap());
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let manager = manager_with_two_users();
        let challenges = [
            manager.handle_hello("admin", None).unwrap(),
            manager.handle_hello("nonexistent", None).unwrap(),
        ];

        for challenge in &challenges {
            assert!(challenge.starts_with("SCRAM handshakeToken="));
            assert!(challenge.contains("hash=SHA-256"));
            assert!(challenge.contains("data="));
        }
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let manager = manager_with_two_users();

        let first = manager.fake_credentials("ghost");
        let second = manager.fake_credentials("ghost");
        assert_eq!(first.salt, second.salt);
        assert_eq!(first.stored_key, second.stored_key);
        assert_eq!(first.server_key, second.server_key);

        let other = manager.fake_credentials("phantom");
        assert_ne!(first.salt, other.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let manager = manager_with_two_users();
        assert!(manager.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let admin = user_with("admin", &["admin"]);
        for permission in ["read", "write", "admin"] {
            assert!(AuthManager::check_permission(&admin, permission));
        }
    }

    #[test]
    fn check_permission_viewer_limited() {
        let viewer = user_with("viewer", &["read"]);
        assert!(AuthManager::check_permission(&viewer, "read"));
        for permission in ["write", "admin"] {
            assert!(!AuthManager::check_permission(&viewer, permission));
        }
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let manager = manager_with_two_users();
        assert!(!manager.revoke_token("nonexistent-token"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let manager = manager_with_two_users();
        manager.tokens.write().insert(
            "good-token".to_string(),
            (user_with("admin", &["admin"]), Instant::now()),
        );

        assert!(manager.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // A zero TTL means every token is already expired.
        let manager = manager_with_two_users().with_token_ttl(Duration::from_secs(0));
        manager.tokens.write().insert(
            "expired-token".to_string(),
            (user_with("admin", &["admin"]), Instant::now()),
        );

        assert!(manager.validate_token("expired-token").is_none());
        // Validation of an expired token must also evict it.
        assert!(manager.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let custom = Duration::from_secs(120);
        let manager = AuthManager::empty().with_token_ttl(custom);
        assert_eq!(manager.token_ttl, custom);
    }
}
425}