rustango/
account_lockout.rs1use std::sync::Arc;
45use std::time::Duration;
46
47use crate::cache::Cache;
48
/// Default number of recorded failures at which [`Lockout`] locks an account.
pub const DEFAULT_MAX_ATTEMPTS: u32 = 5;
/// Default lock duration, in seconds (15 minutes).
pub const DEFAULT_LOCKOUT_DURATION_SECS: u64 = 15 * 60;
53
54pub struct Lockout {
56 cache: Arc<dyn Cache>,
57 max_attempts: u32,
58 lockout_duration: Duration,
59 counter_ttl: Duration,
60 key_prefix: String,
61}
62
63impl Lockout {
64 #[must_use]
67 pub fn new(cache: Arc<dyn Cache>) -> Self {
68 Self {
69 cache,
70 max_attempts: DEFAULT_MAX_ATTEMPTS,
71 lockout_duration: Duration::from_secs(DEFAULT_LOCKOUT_DURATION_SECS),
72 counter_ttl: Duration::from_secs(3600),
73 key_prefix: "lockout:".to_owned(),
74 }
75 }
76
77 #[must_use]
79 pub fn max_attempts(mut self, n: u32) -> Self {
80 self.max_attempts = n.max(1);
81 self
82 }
83
84 #[must_use]
86 pub fn lockout_duration(mut self, d: Duration) -> Self {
87 self.lockout_duration = d;
88 self
89 }
90
91 #[must_use]
95 pub fn counter_ttl(mut self, d: Duration) -> Self {
96 self.counter_ttl = d;
97 self
98 }
99
100 #[must_use]
104 pub fn key_prefix(mut self, p: impl Into<String>) -> Self {
105 self.key_prefix = p.into();
106 self
107 }
108
109 pub async fn is_locked(&self, account: &str) -> bool {
112 self.cache
113 .exists(&self.lock_key(account))
114 .await
115 .unwrap_or(false)
116 }
117
118 pub async fn record_failure(&self, account: &str) -> u32 {
122 let counter_key = self.counter_key(account);
123 let current: u32 = self
124 .cache
125 .get(&counter_key)
126 .await
127 .ok()
128 .flatten()
129 .and_then(|s| s.parse().ok())
130 .unwrap_or(0);
131 let next = current + 1;
132 let _ = self
133 .cache
134 .set(&counter_key, &next.to_string(), Some(self.counter_ttl))
135 .await;
136 if next >= self.max_attempts {
137 let _ = self
139 .cache
140 .set(&self.lock_key(account), "1", Some(self.lockout_duration))
141 .await;
142 }
143 next
144 }
145
146 pub async fn clear(&self, account: &str) {
149 let _ = self.cache.delete(&self.counter_key(account)).await;
150 let _ = self.cache.delete(&self.lock_key(account)).await;
151 }
152
153 pub async fn attempt_count(&self, account: &str) -> u32 {
155 self.cache
156 .get(&self.counter_key(account))
157 .await
158 .ok()
159 .flatten()
160 .and_then(|s| s.parse().ok())
161 .unwrap_or(0)
162 }
163
164 pub async fn force_lock(&self, account: &str) {
166 let _ = self
167 .cache
168 .set(&self.lock_key(account), "1", Some(self.lockout_duration))
169 .await;
170 }
171
172 fn counter_key(&self, account: &str) -> String {
173 format!("{}attempts:{}", self.key_prefix, account)
174 }
175
176 fn lock_key(&self, account: &str) -> String {
177 format!("{}locked:{}", self.key_prefix, account)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::InMemoryCache;

    /// Builds a `Lockout` over a fresh in-memory cache with threshold `max`
    /// and a 60-second lock.
    fn lockout(max: u32) -> Lockout {
        let backing: Arc<dyn Cache> = Arc::new(InMemoryCache::new());
        Lockout::new(backing)
            .max_attempts(max)
            .lockout_duration(Duration::from_secs(60))
    }

    /// Records `times` failed attempts for `who`, discarding the returned counts.
    async fn fail_n(guard: &Lockout, who: &str, times: u32) {
        for _ in 0..times {
            guard.record_failure(who).await;
        }
    }

    #[tokio::test]
    async fn fresh_account_not_locked() {
        let guard = lockout(5);
        assert!(!guard.is_locked("alice").await);
        assert_eq!(guard.attempt_count("alice").await, 0);
    }

    #[tokio::test]
    async fn record_failure_increments_count() {
        let guard = lockout(5);
        for expected in 1..=2 {
            assert_eq!(guard.record_failure("alice").await, expected);
        }
        assert_eq!(guard.attempt_count("alice").await, 2);
        assert!(!guard.is_locked("alice").await);
    }

    #[tokio::test]
    async fn locks_at_threshold() {
        let guard = lockout(3);
        fail_n(&guard, "alice", 2).await;
        assert!(!guard.is_locked("alice").await);
        fail_n(&guard, "alice", 1).await;
        assert!(guard.is_locked("alice").await);
    }

    #[tokio::test]
    async fn clear_resets_counter_and_lock() {
        let guard = lockout(2);
        fail_n(&guard, "alice", 2).await;
        assert!(guard.is_locked("alice").await);
        guard.clear("alice").await;
        assert!(!guard.is_locked("alice").await);
        assert_eq!(guard.attempt_count("alice").await, 0);
    }

    #[tokio::test]
    async fn force_lock_works_without_failures() {
        let guard = lockout(5);
        guard.force_lock("alice").await;
        assert!(guard.is_locked("alice").await);
    }

    #[tokio::test]
    async fn lockout_expires() {
        let backing: Arc<dyn Cache> = Arc::new(InMemoryCache::new());
        let guard = Lockout::new(backing)
            .max_attempts(2)
            .lockout_duration(Duration::from_millis(100));
        fail_n(&guard, "alice", 2).await;
        assert!(guard.is_locked("alice").await);
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(!guard.is_locked("alice").await);
    }

    #[tokio::test]
    async fn separate_accounts_dont_share_state() {
        let guard = lockout(2);
        fail_n(&guard, "alice", 2).await;
        assert!(guard.is_locked("alice").await);
        assert!(!guard.is_locked("bob").await);
        assert_eq!(guard.attempt_count("bob").await, 0);
    }

    #[tokio::test]
    async fn key_prefix_isolates_namespaces() {
        let shared: Arc<dyn Cache> = Arc::new(InMemoryCache::new());
        let login = Lockout::new(Arc::clone(&shared))
            .key_prefix("login:")
            .max_attempts(2);
        let mfa = Lockout::new(shared).key_prefix("mfa:").max_attempts(2);
        fail_n(&login, "alice", 2).await;
        assert!(login.is_locked("alice").await);
        assert!(
            !mfa.is_locked("alice").await,
            "MFA namespace shouldn't be locked"
        );
    }

    #[tokio::test]
    async fn max_attempts_floors_at_1() {
        let guard = lockout(0);
        fail_n(&guard, "alice", 1).await;
        assert!(
            guard.is_locked("alice").await,
            "max_attempts(0) should be treated as 1"
        );
    }
}