// volli_manager/connection/rate_limiter.rs

use lru::LruCache;
use std::{collections::VecDeque, net::IpAddr, num::NonZeroUsize};
use tokio::time::{Duration, Instant};

5pub struct RateLimiter {
6    entries: LruCache<IpAddr, VecDeque<Instant>>,
7    max: usize,
8    interval: Duration,
9}
10
11impl RateLimiter {
12    pub fn new(max: usize, interval: Duration) -> Self {
13        Self {
14            entries: LruCache::new(NonZeroUsize::new(1024).unwrap()),
15            max,
16            interval,
17        }
18    }
19
20    pub fn check(&mut self, ip: IpAddr) -> bool {
21        if self.max == 0 {
22            return true;
23        }
24        let now = Instant::now();
25        let deque = self.entries.get_or_insert_mut(ip, VecDeque::new);
26        while let Some(&ts) = deque.front() {
27            if now.duration_since(ts) > self.interval {
28                deque.pop_front();
29            } else {
30                break;
31            }
32        }
33        if deque.len() >= self.max {
34            return false;
35        }
36        deque.push_back(now);
37        true
38    }
39}
40
41#[derive(Clone, Copy)]
42struct BackoffEntry {
43    failures: u32,
44    next_allowed: Instant,
45}
46
47pub struct AuthBackoff {
48    entries: LruCache<IpAddr, BackoffEntry>,
49    base_delay: Duration,
50    max_delay: Duration,
51}
52
53impl AuthBackoff {
54    pub fn new(base_delay: Duration, max_delay: Duration) -> Self {
55        Self {
56            entries: LruCache::new(NonZeroUsize::new(1024).unwrap()),
57            base_delay,
58            max_delay,
59        }
60    }
61
62    pub fn allow(&mut self, ip: IpAddr) -> bool {
63        match self.entries.get(&ip) {
64            Some(entry) => Instant::now() >= entry.next_allowed,
65            None => true,
66        }
67    }
68
69    pub fn record_failure(&mut self, ip: IpAddr) -> Duration {
70        let now = Instant::now();
71        let failures = self
72            .entries
73            .get(&ip)
74            .map(|e| e.failures.saturating_add(1))
75            .unwrap_or(1);
76        let exp = failures.saturating_sub(1).min(6);
77        let delay = self
78            .base_delay
79            .checked_mul(1u32 << exp)
80            .unwrap_or(self.max_delay)
81            .min(self.max_delay);
82        self.entries.put(
83            ip,
84            BackoffEntry {
85                failures,
86                next_allowed: now + delay,
87            },
88        );
89        delay
90    }
91
92    pub fn record_success(&mut self, ip: IpAddr) {
93        self.entries.pop(&ip);
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::{AuthBackoff, RateLimiter};
    use std::net::IpAddr;
    use std::str::FromStr;
    use tokio::time::Duration;

    fn ip(s: &str) -> IpAddr {
        IpAddr::from_str(s).unwrap()
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_blocks_over_limit() {
        let mut limiter = RateLimiter::new(2, Duration::from_millis(50));
        let ip = ip("127.0.0.1");
        assert!(limiter.check(ip));
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_tracks_ips_independently() {
        let mut limiter = RateLimiter::new(1, Duration::from_secs(60));
        let a = ip("127.0.0.1");
        let b = ip("127.0.0.2");
        assert!(limiter.check(a));
        assert!(!limiter.check(a));
        assert!(limiter.check(b));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_recovers_after_interval() {
        let mut limiter = RateLimiter::new(1, Duration::from_millis(5));
        let ip = ip("127.0.0.1");
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip));
        tokio::time::sleep(Duration::from_millis(6)).await;
        assert!(limiter.check(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_zero_max_is_unlimited() {
        let mut limiter = RateLimiter::new(0, Duration::from_secs(60));
        let ip = ip("127.0.0.1");
        for _ in 0..10 {
            assert!(limiter.check(ip));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_recovers_after_delay() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = ip("127.0.0.1");
        assert!(backoff.allow(ip));
        backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        tokio::time::sleep(Duration::from_millis(6)).await;
        assert!(backoff.allow(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_clears_on_success() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = ip("127.0.0.1");
        backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        backoff.record_success(ip);
        assert!(backoff.allow(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_doubles_up_to_max_delay() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = ip("127.0.0.1");
        assert_eq!(backoff.record_failure(ip), Duration::from_millis(5));
        assert_eq!(backoff.record_failure(ip), Duration::from_millis(10));
        assert_eq!(backoff.record_failure(ip), Duration::from_millis(20));
        assert_eq!(backoff.record_failure(ip), Duration::from_millis(20));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_exponent_is_capped() {
        // max_delay is far above 64x base_delay, so only the exponent cap limits growth.
        let mut backoff = AuthBackoff::new(Duration::from_millis(1), Duration::from_secs(3600));
        let ip = ip("127.0.0.1");
        let delays: Vec<Duration> = (0..8).map(|_| backoff.record_failure(ip)).collect();
        assert_eq!(delays[6], Duration::from_millis(64));
        assert_eq!(delays[7], Duration::from_millis(64));
    }
}