Skip to main content

rustango/
account_lockout.rs

1//! Per-account login lockout — defends against credential stuffing /
2//! brute-force attacks that bypass per-IP rate limits.
3//!
//! Backed by the cache layer (in-memory or Redis). Each failed login
//! increments a counter; once the counter reaches the threshold, the account
//! is locked for a configurable duration. Successful logins should call
//! `clear`, which resets both the counter and any active lock.
7//!
8//! ## Quick start
9//!
10//! ```ignore
11//! use rustango::account_lockout::Lockout;
12//! use rustango::cache::InMemoryCache;
13//! use std::sync::Arc;
14//! use std::time::Duration;
15//!
16//! let cache: Arc<dyn rustango::cache::Cache> = Arc::new(InMemoryCache::new());
17//! let lockout = Lockout::new(cache)
18//!     .max_attempts(5)
19//!     .lockout_duration(Duration::from_secs(900));    // 15 min
20//!
21//! // Login handler:
22//! let username = "alice";
23//!
24//! if lockout.is_locked(username).await {
25//!     return Err("account temporarily locked — try again later");
26//! }
27//!
28//! if !verify_credentials(username, password).await? {
29//!     lockout.record_failure(username).await;
30//!     return Err("bad credentials");
31//! }
32//!
33//! lockout.clear(username).await;       // success → reset counter
34//! issue_session(username).await
35//! ```
36//!
37//! ## Why per-account, not per-IP?
38//!
39//! Per-IP rate limiting (`RateLimitLayer::per_ip`) catches one attacker
40//! pounding one endpoint. Per-account lockout catches a botnet trying
41//! the same username from thousands of IPs — the *account* is the rate
42//! axis. Both belong in your stack.
43
44use std::sync::Arc;
45use std::time::Duration;
46
47use crate::cache::Cache;
48
/// Default number of recorded failures that triggers a lock.
pub const DEFAULT_MAX_ATTEMPTS: u32 = 5;
/// Default lock length, in seconds (15 minutes).
pub const DEFAULT_LOCKOUT_DURATION_SECS: u64 = 15 * 60;
53
54/// Per-account lockout tracker.
55pub struct Lockout {
56    cache: Arc<dyn Cache>,
57    max_attempts: u32,
58    lockout_duration: Duration,
59    counter_ttl: Duration,
60    key_prefix: String,
61}
62
63impl Lockout {
64    /// New tracker with default thresholds (5 attempts → 15 min lock,
65    /// counter expires after 1 hour of inactivity).
66    #[must_use]
67    pub fn new(cache: Arc<dyn Cache>) -> Self {
68        Self {
69            cache,
70            max_attempts: DEFAULT_MAX_ATTEMPTS,
71            lockout_duration: Duration::from_secs(DEFAULT_LOCKOUT_DURATION_SECS),
72            counter_ttl: Duration::from_secs(3600),
73            key_prefix: "lockout:".to_owned(),
74        }
75    }
76
77    /// Override the attempts threshold.
78    #[must_use]
79    pub fn max_attempts(mut self, n: u32) -> Self {
80        self.max_attempts = n.max(1);
81        self
82    }
83
84    /// Override the lockout duration.
85    #[must_use]
86    pub fn lockout_duration(mut self, d: Duration) -> Self {
87        self.lockout_duration = d;
88        self
89    }
90
91    /// Override how long the failure counter persists between attempts.
92    /// Defaults to 1 hour — counters reset themselves if the user goes
93    /// quiet for a while.
94    #[must_use]
95    pub fn counter_ttl(mut self, d: Duration) -> Self {
96        self.counter_ttl = d;
97        self
98    }
99
100    /// Override the cache-key prefix. Defaults to `"lockout:"`. Useful
101    /// when sharing one cache across multiple lockout namespaces (login
102    /// vs MFA vs API key etc.).
103    #[must_use]
104    pub fn key_prefix(mut self, p: impl Into<String>) -> Self {
105        self.key_prefix = p.into();
106        self
107    }
108
109    /// Check whether `account` is currently locked. Returns `true` to
110    /// reject the login attempt; `false` to proceed with verification.
111    pub async fn is_locked(&self, account: &str) -> bool {
112        self.cache
113            .exists(&self.lock_key(account))
114            .await
115            .unwrap_or(false)
116    }
117
118    /// Record a failed login attempt. Returns the new attempt count.
119    /// When the count reaches `max_attempts`, the account is locked
120    /// for `lockout_duration`.
121    pub async fn record_failure(&self, account: &str) -> u32 {
122        let counter_key = self.counter_key(account);
123        let current: u32 = self
124            .cache
125            .get(&counter_key)
126            .await
127            .ok()
128            .flatten()
129            .and_then(|s| s.parse().ok())
130            .unwrap_or(0);
131        let next = current + 1;
132        let _ = self
133            .cache
134            .set(&counter_key, &next.to_string(), Some(self.counter_ttl))
135            .await;
136        if next >= self.max_attempts {
137            // Set the lock flag with TTL = lockout_duration
138            let _ = self
139                .cache
140                .set(&self.lock_key(account), "1", Some(self.lockout_duration))
141                .await;
142        }
143        next
144    }
145
146    /// Clear the failure counter and any active lock. Call on successful
147    /// authentication.
148    pub async fn clear(&self, account: &str) {
149        let _ = self.cache.delete(&self.counter_key(account)).await;
150        let _ = self.cache.delete(&self.lock_key(account)).await;
151    }
152
153    /// Read the current failure count for an account. 0 when absent.
154    pub async fn attempt_count(&self, account: &str) -> u32 {
155        self.cache
156            .get(&self.counter_key(account))
157            .await
158            .ok()
159            .flatten()
160            .and_then(|s| s.parse().ok())
161            .unwrap_or(0)
162    }
163
164    /// Force-lock an account (e.g. by an admin action).
165    pub async fn force_lock(&self, account: &str) {
166        let _ = self
167            .cache
168            .set(&self.lock_key(account), "1", Some(self.lockout_duration))
169            .await;
170    }
171
172    fn counter_key(&self, account: &str) -> String {
173        format!("{}attempts:{}", self.key_prefix, account)
174    }
175
176    fn lock_key(&self, account: &str) -> String {
177        format!("{}locked:{}", self.key_prefix, account)
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::InMemoryCache;

    /// A brand-new in-memory cache, already erased to the trait object.
    fn fresh_cache() -> Arc<dyn Cache> {
        Arc::new(InMemoryCache::new())
    }

    /// Tracker on its own cache with the given threshold and a 60s lock.
    fn tracker_with_threshold(max: u32) -> Lockout {
        Lockout::new(fresh_cache())
            .max_attempts(max)
            .lockout_duration(Duration::from_secs(60))
    }

    /// Record `times` consecutive failures for `account`.
    async fn fail_n(tracker: &Lockout, account: &str, times: u32) {
        for _ in 0..times {
            tracker.record_failure(account).await;
        }
    }

    #[tokio::test]
    async fn fresh_account_not_locked() {
        let tracker = tracker_with_threshold(5);
        assert!(!tracker.is_locked("alice").await);
        assert_eq!(tracker.attempt_count("alice").await, 0);
    }

    #[tokio::test]
    async fn record_failure_increments_count() {
        let tracker = tracker_with_threshold(5);
        for expected in 1..=2 {
            assert_eq!(tracker.record_failure("alice").await, expected);
        }
        assert_eq!(tracker.attempt_count("alice").await, 2);
        assert!(!tracker.is_locked("alice").await);
    }

    #[tokio::test]
    async fn locks_at_threshold() {
        let tracker = tracker_with_threshold(3);
        fail_n(&tracker, "alice", 2).await;
        assert!(!tracker.is_locked("alice").await);
        fail_n(&tracker, "alice", 1).await;
        assert!(tracker.is_locked("alice").await);
    }

    #[tokio::test]
    async fn clear_resets_counter_and_lock() {
        let tracker = tracker_with_threshold(2);
        fail_n(&tracker, "alice", 2).await;
        assert!(tracker.is_locked("alice").await);

        tracker.clear("alice").await;
        assert!(!tracker.is_locked("alice").await);
        assert_eq!(tracker.attempt_count("alice").await, 0);
    }

    #[tokio::test]
    async fn force_lock_works_without_failures() {
        let tracker = tracker_with_threshold(5);
        tracker.force_lock("alice").await;
        assert!(tracker.is_locked("alice").await);
    }

    #[tokio::test]
    async fn lockout_expires() {
        let tracker = Lockout::new(fresh_cache())
            .max_attempts(2)
            .lockout_duration(Duration::from_millis(100));
        fail_n(&tracker, "alice", 2).await;
        assert!(tracker.is_locked("alice").await);

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(!tracker.is_locked("alice").await);
    }

    #[tokio::test]
    async fn separate_accounts_dont_share_state() {
        let tracker = tracker_with_threshold(2);
        fail_n(&tracker, "alice", 2).await;
        assert!(tracker.is_locked("alice").await);
        assert!(!tracker.is_locked("bob").await);
        assert_eq!(tracker.attempt_count("bob").await, 0);
    }

    #[tokio::test]
    async fn key_prefix_isolates_namespaces() {
        let shared = fresh_cache();
        let login = Lockout::new(Arc::clone(&shared))
            .key_prefix("login:")
            .max_attempts(2);
        let mfa = Lockout::new(shared).key_prefix("mfa:").max_attempts(2);

        fail_n(&login, "alice", 2).await;
        assert!(login.is_locked("alice").await);
        assert!(
            !mfa.is_locked("alice").await,
            "MFA namespace shouldn't be locked"
        );
    }

    #[tokio::test]
    async fn max_attempts_floors_at_1() {
        let tracker = tracker_with_threshold(0);
        fail_n(&tracker, "alice", 1).await;
        assert!(
            tracker.is_locked("alice").await,
            "max_attempts(0) should be treated as 1"
        );
    }
}