Skip to main content

trojan_auth/
memory.rs

1//! In-memory authentication backend.
2
3use std::collections::HashMap;
4
5use async_trait::async_trait;
6
7use crate::error::AuthError;
8use crate::hash::sha224_hex;
9use crate::result::AuthResult;
10use crate::traits::AuthBackend;
11
/// Simple in-memory authentication backend backed by a hash map.
///
/// Each registered SHA224 hex hash maps to an optional user ID, which is
/// handed back in [`AuthResult`] when that hash verifies successfully.
///
/// This is suitable for small deployments with a fixed set of users.
/// For dynamic user management or large user bases, consider using
/// a database-backed backend.
///
/// The `Debug` output intentionally omits the stored hashes and user IDs:
/// verification accepts a registered hash as-is, so the hashes are usable
/// credentials and must not end up in logs. Only the number of registered
/// users is shown.
#[derive(Clone)]
pub struct MemoryAuth {
    /// Map from SHA224 hex hash to optional user ID.
    users: HashMap<String, Option<String>>,
}

impl std::fmt::Debug for MemoryAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Redacted on purpose: never print the credential hashes themselves.
        f.debug_struct("MemoryAuth")
            .field("user_count", &self.users.len())
            .finish()
    }
}
22
23impl MemoryAuth {
24    /// Create a new empty auth backend.
25    #[inline]
26    pub fn new() -> Self {
27        Self {
28            users: HashMap::new(),
29        }
30    }
31
32    /// Create from pre-computed SHA224 hashes.
33    ///
34    /// # Example
35    /// ```
36    /// use trojan_auth::MemoryAuth;
37    ///
38    /// let auth = MemoryAuth::from_hashes(["abc123...", "def456..."]);
39    /// ```
40    pub fn from_hashes<I, S>(hashes: I) -> Self
41    where
42        I: IntoIterator<Item = S>,
43        S: Into<String>,
44    {
45        let users = hashes.into_iter().map(|h| (h.into(), None)).collect();
46        Self { users }
47    }
48
49    /// Create from plaintext passwords (will be hashed).
50    ///
51    /// # Example
52    /// ```
53    /// use trojan_auth::MemoryAuth;
54    ///
55    /// let auth = MemoryAuth::from_passwords(["password1", "password2"]);
56    /// ```
57    pub fn from_passwords<I, S>(passwords: I) -> Self
58    where
59        I: IntoIterator<Item = S>,
60        S: AsRef<str>,
61    {
62        let users = passwords
63            .into_iter()
64            .map(|p| (sha224_hex(p.as_ref()), None))
65            .collect();
66        Self { users }
67    }
68
69    /// Create from password-to-user-id pairs.
70    ///
71    /// # Example
72    /// ```
73    /// use trojan_auth::MemoryAuth;
74    ///
75    /// let auth = MemoryAuth::from_passwords_with_ids([
76    ///     ("password1", "user1"),
77    ///     ("password2", "user2"),
78    /// ]);
79    /// ```
80    pub fn from_passwords_with_ids<I, P, U>(pairs: I) -> Self
81    where
82        I: IntoIterator<Item = (P, U)>,
83        P: AsRef<str>,
84        U: Into<String>,
85    {
86        let users = pairs
87            .into_iter()
88            .map(|(p, u)| (sha224_hex(p.as_ref()), Some(u.into())))
89            .collect();
90        Self { users }
91    }
92
93    /// Add a user with a plaintext password.
94    #[inline]
95    pub fn add_password(&mut self, password: &str, user_id: Option<String>) {
96        self.users.insert(sha224_hex(password), user_id);
97    }
98
99    /// Add a user with a pre-computed hash.
100    #[inline]
101    pub fn add_hash(&mut self, hash: String, user_id: Option<String>) {
102        self.users.insert(hash, user_id);
103    }
104
105    /// Remove a user by hash.
106    #[inline]
107    pub fn remove_hash(&mut self, hash: &str) -> bool {
108        self.users.remove(hash).is_some()
109    }
110
111    /// Get the number of registered users.
112    #[inline]
113    pub fn len(&self) -> usize {
114        self.users.len()
115    }
116
117    /// Check if no users are registered.
118    #[inline]
119    pub fn is_empty(&self) -> bool {
120        self.users.is_empty()
121    }
122
123    /// Check if a hash is registered.
124    #[inline]
125    pub fn contains(&self, hash: &str) -> bool {
126        self.users.contains_key(hash)
127    }
128}
129
130impl Default for MemoryAuth {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136#[async_trait]
137impl AuthBackend for MemoryAuth {
138    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
139        match self.users.get(hash) {
140            Some(user_id) => Ok(AuthResult {
141                user_id: user_id.clone(),
142                metadata: None,
143            }),
144            None => Err(AuthError::Invalid),
145        }
146    }
147}
148
// Backward compatibility alias: `HashSetAuth` is an older name for
// `MemoryAuth`, kept so existing callers continue to compile. It is hidden
// from the generated docs so that new code reaches for `MemoryAuth` directly.
#[doc(hidden)]
pub type HashSetAuth = MemoryAuth;
152
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_from_passwords() {
        let auth = MemoryAuth::from_passwords(["test123", "password"]);
        assert_eq!(auth.len(), 2);

        let hash = sha224_hex("test123");
        let result = auth.verify(&hash).await.unwrap();
        // Plain password registration carries no user ID.
        assert_eq!(result.user_id, None);

        let wrong_hash = sha224_hex("wrong");
        assert!(matches!(
            auth.verify(&wrong_hash).await,
            Err(AuthError::Invalid)
        ));
    }

    #[tokio::test]
    async fn test_from_hashes() {
        let hash = sha224_hex("precomputed");
        let auth = MemoryAuth::from_hashes([hash.clone()]);
        assert_eq!(auth.len(), 1);
        assert!(auth.contains(&hash));

        let result = auth.verify(&hash).await.unwrap();
        assert_eq!(result.user_id, None);

        assert!(auth.verify(&sha224_hex("other")).await.is_err());
    }

    #[tokio::test]
    async fn test_with_user_ids() {
        let auth = MemoryAuth::from_passwords_with_ids([("pass1", "user1"), ("pass2", "user2")]);
        assert_eq!(auth.len(), 2);

        let result = auth.verify(&sha224_hex("pass1")).await.unwrap();
        assert_eq!(result.user_id, Some("user1".to_string()));

        let result = auth.verify(&sha224_hex("pass2")).await.unwrap();
        assert_eq!(result.user_id, Some("user2".to_string()));
    }

    #[tokio::test]
    async fn test_reregistering_replaces_user_id() {
        let mut auth = MemoryAuth::new();
        let hash = sha224_hex("shared");

        auth.add_hash(hash.clone(), Some("old".to_string()));
        auth.add_password("shared", Some("new".to_string()));
        // Same password hash, so still a single entry with the newer ID.
        assert_eq!(auth.len(), 1);

        let result = auth.verify(&hash).await.unwrap();
        assert_eq!(result.user_id, Some("new".to_string()));
    }

    #[test]
    fn test_add_remove() {
        let mut auth = MemoryAuth::default();
        assert!(auth.is_empty());

        auth.add_password("test", Some("user".to_string()));
        assert_eq!(auth.len(), 1);

        let hash = sha224_hex("test");
        assert!(auth.contains(&hash));

        assert!(auth.remove_hash(&hash));
        assert!(auth.is_empty());
        assert!(!auth.contains(&hash));

        // Removing again finds nothing and must say so.
        assert!(!auth.remove_hash(&hash));
    }
}