1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
/// Limits enforced by an [`Engine`] on each of its buckets.
///
/// Every bucket is governed by two limits at once: a continuously refilled
/// token bucket and a fixed-window request counter.
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Length of the fixed accounting window. `max_requests` tokens are
    /// refilled over this period.
    pub window: Duration,
    /// Maximum admissions per `window`. This is also the numerator of the
    /// refill rate (`max_requests / window` tokens per second).
    pub max_requests: u32,
    /// Token-bucket capacity: how many requests may be admitted back-to-back
    /// before the refill rate becomes the bottleneck.
    pub burst_size: u32,
}
13
/// Rate-limiting state for a single key, or for an engine's global limit.
///
/// Combines a continuously refilled token bucket with a fixed-window counter.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available. Fractional because refill is continuous.
    tokens: f64,
    /// When `tokens` was last refilled.
    last_update: Instant,
    /// Admissions counted in the current fixed window.
    requests_in_window: u32,
    /// Start of the current fixed window.
    window_start: Instant,
}
21
22impl Bucket {
23 fn new(initial_tokens: f64) -> Self {
24 let now = Instant::now();
25 Self {
26 tokens: initial_tokens,
27 last_update: now,
28 requests_in_window: 0,
29 window_start: now,
30 }
31 }
32
33 fn try_consume(&mut self, cfg: &EngineConfig) -> bool {
34 let now = Instant::now();
35 if now.duration_since(self.window_start) > cfg.window {
36 self.window_start = now;
37 self.requests_in_window = 0;
38 }
39 let elapsed = now.duration_since(self.last_update).as_secs_f64();
40 let refill_rate = cfg.max_requests as f64 / cfg.window.as_secs_f64();
41 self.tokens += elapsed * refill_rate;
42 self.tokens = self.tokens.min(cfg.burst_size as f64);
43 self.last_update = now;
44 if self.tokens >= 1.0 && self.requests_in_window < cfg.max_requests {
45 self.tokens -= 1.0;
46 self.requests_in_window += 1;
47 true
48 } else {
49 false
50 }
51 }
52}
53
54#[derive(Debug)]
55pub struct Engine<K: Eq + Hash + Clone + ToString> {
56 cfg: EngineConfig,
57 global: Mutex<Bucket>,
58 keyed: RwLock<HashMap<K, Bucket>>,
59}
60
61impl<K: Eq + Hash + Clone + ToString> Engine<K> {
62 pub fn new(cfg: EngineConfig) -> Self {
63 let burst_size = cfg.burst_size as f64;
64 Self {
65 cfg,
66 global: Mutex::new(Bucket::new(burst_size)),
67 keyed: RwLock::new(HashMap::new()),
68 }
69 }
70
71 pub fn try_consume_global(&self) -> bool {
72 match self.global.lock() {
73 Ok(mut guard) => guard.try_consume(&self.cfg),
74 Err(_poisoned) => {
75 false
78 }
79 }
80 }
81
82 pub fn try_consume_key(&self, key: &K) -> bool {
83 let mut map = self.keyed.write();
84 let bucket = map
85 .entry(key.clone())
86 .or_insert_with(|| Bucket::new(self.cfg.burst_size as f64));
87 bucket.try_consume(&self.cfg)
88 }
89}
90
91pub type SharedEngine<K> = Arc<Engine<K>>;
92
93use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
98use thiserror::Error;
99
100#[derive(Debug, Error)]
102pub enum JoinRateLimitError {
103 #[error("global join rate limit exceeded: max {max_per_minute} joins per minute")]
105 GlobalLimitExceeded { max_per_minute: u32 },
106
107 #[error("subnet /64 join rate limit exceeded: max {max_per_hour} joins per hour from this /64")]
109 Subnet64LimitExceeded { max_per_hour: u32 },
110
111 #[error("subnet /48 join rate limit exceeded: max {max_per_hour} joins per hour from this /48")]
113 Subnet48LimitExceeded { max_per_hour: u32 },
114
115 #[error("subnet /24 join rate limit exceeded: max {max_per_hour} joins per hour from this /24")]
117 Subnet24LimitExceeded { max_per_hour: u32 },
118}
119
/// Limits applied by `JoinRateLimiter`.
#[derive(Debug, Clone)]
pub struct JoinRateLimiterConfig {
    /// Joins allowed per IPv6 /64 per hour. Also used as that limiter's burst size.
    pub max_joins_per_64_per_hour: u32,

    /// Joins allowed per IPv6 /48 per hour. Also used as that limiter's burst size.
    pub max_joins_per_48_per_hour: u32,

    /// Joins allowed per IPv4 /24 per hour. Also used as that limiter's burst size.
    pub max_joins_per_24_per_hour: u32,

    /// Joins allowed per minute across all addresses.
    pub max_global_joins_per_minute: u32,

    /// Token-bucket capacity of the global limiter: how many joins may
    /// arrive back-to-back before the per-minute refill rate applies.
    pub global_burst_size: u32,
}
140
141impl Default for JoinRateLimiterConfig {
142 fn default() -> Self {
143 Self {
144 max_joins_per_64_per_hour: 1,
145 max_joins_per_48_per_hour: 5,
146 max_joins_per_24_per_hour: 3,
147 max_global_joins_per_minute: 100,
148 global_burst_size: 10,
149 }
150 }
151}
152
153#[derive(Debug)]
178pub struct JoinRateLimiter {
179 config: JoinRateLimiterConfig,
180 per_subnet_64: Engine<Ipv6Addr>,
182 per_subnet_48: Engine<Ipv6Addr>,
184 per_subnet_24: Engine<Ipv4Addr>,
186 global: Engine<u8>,
188}
189
190impl JoinRateLimiter {
191 pub fn new(config: JoinRateLimiterConfig) -> Self {
193 let subnet_64_config = EngineConfig {
195 window: Duration::from_secs(3600), max_requests: config.max_joins_per_64_per_hour,
197 burst_size: config.max_joins_per_64_per_hour, };
199
200 let subnet_48_config = EngineConfig {
202 window: Duration::from_secs(3600), max_requests: config.max_joins_per_48_per_hour,
204 burst_size: config.max_joins_per_48_per_hour, };
206
207 let subnet_24_config = EngineConfig {
209 window: Duration::from_secs(3600), max_requests: config.max_joins_per_24_per_hour,
211 burst_size: config.max_joins_per_24_per_hour, };
213
214 let global_config = EngineConfig {
216 window: Duration::from_secs(60), max_requests: config.max_global_joins_per_minute,
218 burst_size: config.global_burst_size,
219 };
220
221 Self {
222 config,
223 per_subnet_64: Engine::new(subnet_64_config),
224 per_subnet_48: Engine::new(subnet_48_config),
225 per_subnet_24: Engine::new(subnet_24_config),
226 global: Engine::new(global_config),
227 }
228 }
229
230 pub fn check_join_allowed(&self, ip: &IpAddr) -> Result<(), JoinRateLimitError> {
242 if !self.global.try_consume_key(&0u8) {
244 return Err(JoinRateLimitError::GlobalLimitExceeded {
245 max_per_minute: self.config.max_global_joins_per_minute,
246 });
247 }
248
249 match ip {
251 IpAddr::V6(ipv6) => {
252 let subnet_64 = extract_ipv6_subnet_64(ipv6);
254 if !self.per_subnet_64.try_consume_key(&subnet_64) {
255 return Err(JoinRateLimitError::Subnet64LimitExceeded {
256 max_per_hour: self.config.max_joins_per_64_per_hour,
257 });
258 }
259
260 let subnet_48 = extract_ipv6_subnet_48(ipv6);
262 if !self.per_subnet_48.try_consume_key(&subnet_48) {
263 return Err(JoinRateLimitError::Subnet48LimitExceeded {
264 max_per_hour: self.config.max_joins_per_48_per_hour,
265 });
266 }
267 }
268 IpAddr::V4(ipv4) => {
269 let subnet_24 = extract_ipv4_subnet_24(ipv4);
271 if !self.per_subnet_24.try_consume_key(&subnet_24) {
272 return Err(JoinRateLimitError::Subnet24LimitExceeded {
273 max_per_hour: self.config.max_joins_per_24_per_hour,
274 });
275 }
276 }
277 }
278
279 Ok(())
280 }
281
282 pub fn config(&self) -> &JoinRateLimiterConfig {
284 &self.config
285 }
286}
287
/// Returns the /64 network of `addr`. The first 64 bits are kept and the
/// low 64 bits are zeroed.
#[inline]
pub fn extract_ipv6_subnet_64(addr: &Ipv6Addr) -> Ipv6Addr {
    const PREFIX_MASK: u128 = u128::MAX << 64;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
299
/// Returns the /48 network of `addr`. The first 48 bits are kept and the
/// remaining 80 bits are zeroed.
#[inline]
pub fn extract_ipv6_subnet_48(addr: &Ipv6Addr) -> Ipv6Addr {
    const PREFIX_MASK: u128 = u128::MAX << 80;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
310
/// Returns the /32 network of `addr`. The first 32 bits are kept and the
/// remaining 96 bits are zeroed.
#[inline]
pub fn extract_ipv6_subnet_32(addr: &Ipv6Addr) -> Ipv6Addr {
    const PREFIX_MASK: u128 = u128::MAX << 96;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
321
/// Returns the /24 network of `addr`. The last octet is zeroed.
#[inline]
pub fn extract_ipv4_subnet_24(addr: &Ipv4Addr) -> Ipv4Addr {
    const PREFIX_MASK: u32 = 0xFFFF_FF00;
    Ipv4Addr::from(u32::from(*addr) & PREFIX_MASK)
}
330
/// Returns the /16 network of `addr`. The last two octets are zeroed.
#[inline]
pub fn extract_ipv4_subnet_16(addr: &Ipv4Addr) -> Ipv4Addr {
    const PREFIX_MASK: u32 = 0xFFFF_0000;
    Ipv4Addr::from(u32::from(*addr) & PREFIX_MASK)
}
337
/// Returns the /8 network of `addr`. Everything but the first octet is zeroed.
#[inline]
pub fn extract_ipv4_subnet_8(addr: &Ipv4Addr) -> Ipv4Addr {
    const PREFIX_MASK: u32 = 0xFF00_0000;
    Ipv4Addr::from(u32::from(*addr) & PREFIX_MASK)
}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_ipv6_subnet_64() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_64(&addr);
        assert_eq!(subnet.to_string(), "2001:db8:85a3:1234::");
    }

    #[test]
    fn test_extract_ipv6_subnet_48() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_48(&addr);
        assert_eq!(subnet.to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_extract_ipv6_subnet_32() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_32(&addr);
        assert_eq!(subnet.to_string(), "2001:db8::");
    }

    #[test]
    fn test_extract_ipv4_subnet_24() {
        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
        let subnet = extract_ipv4_subnet_24(&addr);
        assert_eq!(subnet.to_string(), "192.168.1.0");
    }

    #[test]
    fn test_extract_ipv4_subnet_16_and_8() {
        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
        assert_eq!(extract_ipv4_subnet_16(&addr).to_string(), "192.168.0.0");
        assert_eq!(extract_ipv4_subnet_8(&addr).to_string(), "192.0.0.0");
    }

    #[test]
    fn test_engine_keys_are_independent() {
        let engine: Engine<u32> = Engine::new(EngineConfig {
            window: Duration::from_secs(3600),
            max_requests: 1,
            burst_size: 1,
        });
        assert!(engine.try_consume_key(&1));
        assert!(!engine.try_consume_key(&1));
        // A different key gets its own, full bucket.
        assert!(engine.try_consume_key(&2));
        // The global bucket is separate from the keyed ones.
        assert!(engine.try_consume_global());
        assert!(!engine.try_consume_global());
    }

    #[test]
    fn test_join_rate_limiter_allows_first_join() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        assert!(limiter.check_join_allowed(&ip).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_blocks_second_from_same_64() {
        let config = JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        let ip1: IpAddr = "2001:db8::1".parse().unwrap();
        assert!(limiter.check_join_allowed(&ip1).is_ok());

        let ip2: IpAddr = "2001:db8::2".parse().unwrap();
        let result = limiter.check_join_allowed(&ip2);
        assert!(matches!(
            result,
            Err(JoinRateLimitError::Subnet64LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_allows_different_subnets() {
        let config = JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        let ip1: IpAddr = "2001:db8:1::1".parse().unwrap();
        assert!(limiter.check_join_allowed(&ip1).is_ok());

        let ip2: IpAddr = "2001:db8:2::1".parse().unwrap();
        assert!(limiter.check_join_allowed(&ip2).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_blocks_sixth_64_in_same_48() {
        // Defaults: 1 join per /64 and 5 per /48 per hour.
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());

        for segment in 1..=5u16 {
            let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0x1, segment, 0, 0, 0, 1));
            assert!(limiter.check_join_allowed(&ip).is_ok());
        }

        let sixth = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0x1, 6, 0, 0, 0, 1));
        assert!(matches!(
            limiter.check_join_allowed(&sixth),
            Err(JoinRateLimitError::Subnet48LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_global_burst() {
        let config = JoinRateLimiterConfig {
            global_burst_size: 2,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        for segment in 1..=2u16 {
            let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, segment, 0, 0, 0, 0, 1));
            assert!(limiter.check_join_allowed(&ip).is_ok());
        }

        let third = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 3, 0, 0, 0, 0, 1));
        assert!(matches!(
            limiter.check_join_allowed(&third),
            Err(JoinRateLimitError::GlobalLimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_ipv4() {
        let config = JoinRateLimiterConfig {
            max_joins_per_24_per_hour: 2,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        let ip1: IpAddr = "192.168.1.1".parse().unwrap();
        let ip2: IpAddr = "192.168.1.2".parse().unwrap();
        assert!(limiter.check_join_allowed(&ip1).is_ok());
        assert!(limiter.check_join_allowed(&ip2).is_ok());

        let ip3: IpAddr = "192.168.1.3".parse().unwrap();
        let result = limiter.check_join_allowed(&ip3);
        assert!(matches!(
            result,
            Err(JoinRateLimitError::Subnet24LimitExceeded { .. })
        ));
    }
}