1use lru::LruCache;
2use parking_lot::RwLock;
3use std::hash::Hash;
4use std::num::NonZeroUsize;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Upper bound on the number of distinct keys a single engine tracks at once.
///
/// `Engine::new` uses this value as the capacity of its per-key LRU cache.
const MAX_RATE_LIMIT_KEYS: usize = 100_000;
10
/// Tuning parameters applied to every bucket owned by a rate-limiting engine.
#[derive(Clone, Debug)]
pub struct EngineConfig {
    /// Length of one accounting window. Also the period over which
    /// `max_requests` tokens are refilled.
    pub window: Duration,
    /// Maximum number of requests admitted within a single `window`.
    pub max_requests: u32,
    /// Maximum number of tokens a bucket may hold after refilling.
    pub burst_size: u32,
}
17
/// Mutable rate-limiting state for one key (or for the global limiter).
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; one token is spent per admitted request.
    tokens: f64,
    /// Instant at which `tokens` was last refilled.
    last_update: Instant,
    /// Number of requests admitted since `window_start`.
    requests_in_window: u32,
    /// Instant at which the current accounting window began.
    window_start: Instant,
}
25
26impl Bucket {
27 fn new(initial_tokens: f64) -> Self {
28 let now = Instant::now();
29 Self {
30 tokens: initial_tokens,
31 last_update: now,
32 requests_in_window: 0,
33 window_start: now,
34 }
35 }
36
37 fn try_consume(&mut self, cfg: &EngineConfig) -> bool {
38 let now = Instant::now();
39 if now.duration_since(self.window_start) > cfg.window {
40 self.window_start = now;
41 self.requests_in_window = 0;
42 }
43 let elapsed = now.duration_since(self.last_update).as_secs_f64();
44 let refill_rate = cfg.max_requests as f64 / cfg.window.as_secs_f64();
45 self.tokens += elapsed * refill_rate;
46 self.tokens = self.tokens.min(cfg.burst_size as f64);
47 self.last_update = now;
48 if self.tokens >= 1.0 && self.requests_in_window < cfg.max_requests {
49 self.tokens -= 1.0;
50 self.requests_in_window += 1;
51 true
52 } else {
53 false
54 }
55 }
56}
57
58#[derive(Debug)]
59pub struct Engine<K: Eq + Hash + Clone + ToString> {
60 cfg: EngineConfig,
61 global: Mutex<Bucket>,
62 keyed: RwLock<LruCache<K, Bucket>>,
64}
65
66impl<K: Eq + Hash + Clone + ToString> Engine<K> {
67 pub fn new(cfg: EngineConfig) -> Self {
68 let burst_size = cfg.burst_size as f64;
69 let cache_size = NonZeroUsize::new(MAX_RATE_LIMIT_KEYS).unwrap_or(NonZeroUsize::MIN);
71 Self {
72 cfg,
73 global: Mutex::new(Bucket::new(burst_size)),
74 keyed: RwLock::new(LruCache::new(cache_size)),
75 }
76 }
77
78 pub fn try_consume_global(&self) -> bool {
79 match self.global.lock() {
80 Ok(mut guard) => guard.try_consume(&self.cfg),
81 Err(_poisoned) => {
82 false
85 }
86 }
87 }
88
89 pub fn try_consume_key(&self, key: &K) -> bool {
90 let mut map = self.keyed.write();
91 if let Some(bucket) = map.get_mut(key) {
93 bucket.try_consume(&self.cfg)
94 } else {
95 let mut bucket = Bucket::new(self.cfg.burst_size as f64);
96 let result = bucket.try_consume(&self.cfg);
97 map.put(key.clone(), bucket);
98 result
99 }
100 }
101}
102
103pub type SharedEngine<K> = Arc<Engine<K>>;
104
105use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
110use thiserror::Error;
111
112#[derive(Debug, Error)]
114pub enum JoinRateLimitError {
115 #[error("global join rate limit exceeded: max {max_per_minute} joins per minute")]
117 GlobalLimitExceeded { max_per_minute: u32 },
118
119 #[error("subnet /64 join rate limit exceeded: max {max_per_hour} joins per hour from this /64")]
121 Subnet64LimitExceeded { max_per_hour: u32 },
122
123 #[error("subnet /48 join rate limit exceeded: max {max_per_hour} joins per hour from this /48")]
125 Subnet48LimitExceeded { max_per_hour: u32 },
126
127 #[error("subnet /24 join rate limit exceeded: max {max_per_hour} joins per hour from this /24")]
129 Subnet24LimitExceeded { max_per_hour: u32 },
130}
131
/// Limits enforced by [`JoinRateLimiter`].
#[derive(Clone, Debug)]
pub struct JoinRateLimiterConfig {
    /// Joins admitted per IPv6 /64 subnet in each one-hour window.
    pub max_joins_per_64_per_hour: u32,

    /// Joins admitted per IPv6 /48 subnet in each one-hour window.
    pub max_joins_per_48_per_hour: u32,

    /// Joins admitted per IPv4 /24 subnet in each one-hour window.
    pub max_joins_per_24_per_hour: u32,

    /// Joins admitted across all addresses in each one-minute window.
    pub max_global_joins_per_minute: u32,

    /// Token capacity (burst size) of the global limiter.
    pub global_burst_size: u32,
}
152
153impl Default for JoinRateLimiterConfig {
154 fn default() -> Self {
155 Self {
156 max_joins_per_64_per_hour: 1,
157 max_joins_per_48_per_hour: 5,
158 max_joins_per_24_per_hour: 3,
159 max_global_joins_per_minute: 100,
160 global_burst_size: 10,
161 }
162 }
163}
164
165#[derive(Debug)]
190pub struct JoinRateLimiter {
191 config: JoinRateLimiterConfig,
192 per_subnet_64: Engine<Ipv6Addr>,
194 per_subnet_48: Engine<Ipv6Addr>,
196 per_subnet_24: Engine<Ipv4Addr>,
198 global: Engine<u8>,
200}
201
202impl JoinRateLimiter {
203 pub fn new(config: JoinRateLimiterConfig) -> Self {
205 let subnet_64_config = EngineConfig {
207 window: Duration::from_secs(3600), max_requests: config.max_joins_per_64_per_hour,
209 burst_size: config.max_joins_per_64_per_hour, };
211
212 let subnet_48_config = EngineConfig {
214 window: Duration::from_secs(3600), max_requests: config.max_joins_per_48_per_hour,
216 burst_size: config.max_joins_per_48_per_hour, };
218
219 let subnet_24_config = EngineConfig {
221 window: Duration::from_secs(3600), max_requests: config.max_joins_per_24_per_hour,
223 burst_size: config.max_joins_per_24_per_hour, };
225
226 let global_config = EngineConfig {
228 window: Duration::from_secs(60), max_requests: config.max_global_joins_per_minute,
230 burst_size: config.global_burst_size,
231 };
232
233 Self {
234 config,
235 per_subnet_64: Engine::new(subnet_64_config),
236 per_subnet_48: Engine::new(subnet_48_config),
237 per_subnet_24: Engine::new(subnet_24_config),
238 global: Engine::new(global_config),
239 }
240 }
241
242 pub fn check_join_allowed(&self, ip: &IpAddr) -> Result<(), JoinRateLimitError> {
254 if !self.global.try_consume_key(&0u8) {
256 return Err(JoinRateLimitError::GlobalLimitExceeded {
257 max_per_minute: self.config.max_global_joins_per_minute,
258 });
259 }
260
261 match ip {
263 IpAddr::V6(ipv6) => {
264 let subnet_64 = extract_ipv6_subnet_64(ipv6);
266 if !self.per_subnet_64.try_consume_key(&subnet_64) {
267 return Err(JoinRateLimitError::Subnet64LimitExceeded {
268 max_per_hour: self.config.max_joins_per_64_per_hour,
269 });
270 }
271
272 let subnet_48 = extract_ipv6_subnet_48(ipv6);
274 if !self.per_subnet_48.try_consume_key(&subnet_48) {
275 return Err(JoinRateLimitError::Subnet48LimitExceeded {
276 max_per_hour: self.config.max_joins_per_48_per_hour,
277 });
278 }
279 }
280 IpAddr::V4(ipv4) => {
281 let subnet_24 = extract_ipv4_subnet_24(ipv4);
283 if !self.per_subnet_24.try_consume_key(&subnet_24) {
284 return Err(JoinRateLimitError::Subnet24LimitExceeded {
285 max_per_hour: self.config.max_joins_per_24_per_hour,
286 });
287 }
288 }
289 }
290
291 Ok(())
292 }
293
294 pub fn config(&self) -> &JoinRateLimiterConfig {
296 &self.config
297 }
298}
299
/// Returns `addr` with everything after the first 64 bits set to zero,
/// i.e. the network address of its /64 subnet.
#[inline]
pub fn extract_ipv6_subnet_64(addr: &Ipv6Addr) -> Ipv6Addr {
    // Keep the upper 64 bits of the 128-bit address.
    const PREFIX_MASK: u128 = !0u128 << 64;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
311
/// Returns `addr` with everything after the first 48 bits set to zero,
/// i.e. the network address of its /48 subnet.
#[inline]
pub fn extract_ipv6_subnet_48(addr: &Ipv6Addr) -> Ipv6Addr {
    // Keep the upper 48 bits of the 128-bit address.
    const PREFIX_MASK: u128 = !0u128 << 80;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
322
/// Returns `addr` with everything after the first 32 bits set to zero,
/// i.e. the network address of its /32 subnet.
#[inline]
pub fn extract_ipv6_subnet_32(addr: &Ipv6Addr) -> Ipv6Addr {
    // Keep the upper 32 bits of the 128-bit address.
    const PREFIX_MASK: u128 = !0u128 << 96;
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
333
/// Returns `addr` with its last octet set to zero, i.e. the network address
/// of its /24 subnet.
#[inline]
pub fn extract_ipv4_subnet_24(addr: &Ipv4Addr) -> Ipv4Addr {
    let [a, b, c, _] = addr.octets();
    Ipv4Addr::from([a, b, c, 0])
}
342
/// Returns `addr` with its last two octets set to zero, i.e. the network
/// address of its /16 subnet.
#[inline]
pub fn extract_ipv4_subnet_16(addr: &Ipv4Addr) -> Ipv4Addr {
    let [a, b, ..] = addr.octets();
    Ipv4Addr::from([a, b, 0, 0])
}
349
/// Returns `addr` with its last three octets set to zero, i.e. the network
/// address of its /8 subnet.
#[inline]
pub fn extract_ipv4_subnet_8(addr: &Ipv4Addr) -> Ipv4Addr {
    let [a, ..] = addr.octets();
    Ipv4Addr::from([a, 0, 0, 0])
}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` as an IP address; test inputs are always well-formed.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_extract_ipv6_subnet_64() {
        let full: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        assert_eq!(
            extract_ipv6_subnet_64(&full).to_string(),
            "2001:db8:85a3:1234::"
        );
    }

    #[test]
    fn test_extract_ipv6_subnet_48() {
        let full: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        assert_eq!(extract_ipv6_subnet_48(&full).to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_extract_ipv4_subnet_24() {
        let full: Ipv4Addr = "192.168.1.100".parse().unwrap();
        assert_eq!(extract_ipv4_subnet_24(&full).to_string(), "192.168.1.0");
    }

    #[test]
    fn test_join_rate_limiter_allows_first_join() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
        assert!(limiter.check_join_allowed(&ip("2001:db8::1")).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_blocks_second_from_same_64() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        });

        // The first host in the /64 uses up the single allowed join.
        assert!(limiter.check_join_allowed(&ip("2001:db8::1")).is_ok());

        // A different host in the same /64 must then hit the /64 limit.
        let second = limiter.check_join_allowed(&ip("2001:db8::2"));
        assert!(matches!(
            second,
            Err(JoinRateLimitError::Subnet64LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_allows_different_subnets() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        });

        // The two addresses differ in their third segment, so they share
        // neither a /64 nor a /48.
        assert!(limiter.check_join_allowed(&ip("2001:db8:1::1")).is_ok());
        assert!(limiter.check_join_allowed(&ip("2001:db8:2::1")).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_ipv4() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig {
            max_joins_per_24_per_hour: 2,
            ..Default::default()
        });

        // Two joins from the same /24 fit within the limit of two.
        assert!(limiter.check_join_allowed(&ip("192.168.1.1")).is_ok());
        assert!(limiter.check_join_allowed(&ip("192.168.1.2")).is_ok());

        // The third join from that /24 is rejected.
        let third = limiter.check_join_allowed(&ip("192.168.1.3"));
        assert!(matches!(
            third,
            Err(JoinRateLimitError::Subnet24LimitExceeded { .. })
        ));
    }
}