// volli_manager/connection/rate_limiter.rs
use std::{collections::VecDeque, net::IpAddr, num::NonZeroUsize};

use lru::LruCache;
use tokio::time::{Duration, Instant};

5pub struct RateLimiter {
6 entries: LruCache<IpAddr, VecDeque<Instant>>,
7 max: usize,
8 interval: Duration,
9}
10
11impl RateLimiter {
12 pub fn new(max: usize, interval: Duration) -> Self {
13 Self {
14 entries: LruCache::new(NonZeroUsize::new(1024).unwrap()),
15 max,
16 interval,
17 }
18 }
19
20 pub fn check(&mut self, ip: IpAddr) -> bool {
21 if self.max == 0 {
22 return true;
23 }
24 let now = Instant::now();
25 let deque = self.entries.get_or_insert_mut(ip, VecDeque::new);
26 while let Some(&ts) = deque.front() {
27 if now.duration_since(ts) > self.interval {
28 deque.pop_front();
29 } else {
30 break;
31 }
32 }
33 if deque.len() >= self.max {
34 return false;
35 }
36 deque.push_back(now);
37 true
38 }
39}
40
41#[derive(Clone, Copy)]
42struct BackoffEntry {
43 failures: u32,
44 next_allowed: Instant,
45}
46
47pub struct AuthBackoff {
48 entries: LruCache<IpAddr, BackoffEntry>,
49 base_delay: Duration,
50 max_delay: Duration,
51}
52
53impl AuthBackoff {
54 pub fn new(base_delay: Duration, max_delay: Duration) -> Self {
55 Self {
56 entries: LruCache::new(NonZeroUsize::new(1024).unwrap()),
57 base_delay,
58 max_delay,
59 }
60 }
61
62 pub fn allow(&mut self, ip: IpAddr) -> bool {
63 match self.entries.get(&ip) {
64 Some(entry) => Instant::now() >= entry.next_allowed,
65 None => true,
66 }
67 }
68
69 pub fn record_failure(&mut self, ip: IpAddr) -> Duration {
70 let now = Instant::now();
71 let failures = self
72 .entries
73 .get(&ip)
74 .map(|e| e.failures.saturating_add(1))
75 .unwrap_or(1);
76 let exp = failures.saturating_sub(1).min(6);
77 let delay = self
78 .base_delay
79 .checked_mul(1u32 << exp)
80 .unwrap_or(self.max_delay)
81 .min(self.max_delay);
82 self.entries.put(
83 ip,
84 BackoffEntry {
85 failures,
86 next_allowed: now + delay,
87 },
88 );
89 delay
90 }
91
92 pub fn record_success(&mut self, ip: IpAddr) {
93 self.entries.pop(&ip);
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::{AuthBackoff, RateLimiter};
    use std::net::IpAddr;
    use std::str::FromStr;
    use tokio::time::Duration;

    /// Parses a literal IP address for use as a test key.
    fn addr(s: &str) -> IpAddr {
        IpAddr::from_str(s).unwrap()
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_blocks_over_limit() {
        let mut limiter = RateLimiter::new(2, Duration::from_millis(50));
        let ip = addr("127.0.0.1");
        assert!(limiter.check(ip));
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_recovers_after_interval() {
        let mut limiter = RateLimiter::new(1, Duration::from_millis(20));
        let ip = addr("127.0.0.1");
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip));
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert!(limiter.check(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_tracks_ips_independently() {
        let mut limiter = RateLimiter::new(1, Duration::from_secs(60));
        let first = addr("127.0.0.1");
        let second = addr("::1");
        assert!(limiter.check(first));
        assert!(!limiter.check(first));
        assert!(limiter.check(second));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_disabled_when_max_is_zero() {
        let mut limiter = RateLimiter::new(0, Duration::from_secs(60));
        let ip = addr("127.0.0.1");
        for _ in 0..100 {
            assert!(limiter.check(ip));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_recovers_after_delay() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = addr("127.0.0.1");
        assert!(backoff.allow(ip));
        backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        tokio::time::sleep(Duration::from_millis(6)).await;
        assert!(backoff.allow(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_clears_on_success() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = addr("127.0.0.1");
        backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        backoff.record_success(ip);
        assert!(backoff.allow(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_doubles_delay_up_to_max() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = addr("127.0.0.1");
        let delays: Vec<Duration> = (0..4).map(|_| backoff.record_failure(ip)).collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_millis(5),
                Duration::from_millis(10),
                Duration::from_millis(20),
                Duration::from_millis(20),
            ]
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_success_resets_delay() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = addr("127.0.0.1");
        backoff.record_failure(ip);
        backoff.record_failure(ip);
        backoff.record_success(ip);
        assert_eq!(backoff.record_failure(ip), Duration::from_millis(5));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_tracks_ips_independently() {
        let mut backoff = AuthBackoff::new(Duration::from_secs(60), Duration::from_secs(60));
        let blocked = addr("127.0.0.1");
        let other = addr("::1");
        backoff.record_failure(blocked);
        assert!(!backoff.allow(blocked));
        assert!(backoff.allow(other));
    }
}