Skip to main content

saorsa_core/
rate_limit.rs

1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
/// Tuning parameters applied to every bucket owned by an [`Engine`].
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Length of the fixed accounting window; `max_requests` tokens are
    /// refilled evenly over this period.
    pub window: Duration,
    /// Maximum number of requests admitted within one window.
    pub max_requests: u32,
    /// Upper bound on the number of tokens a bucket may hold at once.
    pub burst_size: u32,
}

/// Limiter state for a single key: a token bucket combined with a
/// fixed-window request counter.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available (fractional while refilling).
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
    /// Requests admitted since `window_start`.
    requests_in_window: u32,
    /// Start of the current fixed window.
    window_start: Instant,
}

impl Bucket {
    /// Creates a bucket holding `initial_tokens`, with both the refill clock
    /// and the fixed window starting at the moment of creation.
    fn new(initial_tokens: f64) -> Self {
        let created_at = Instant::now();
        Self {
            window_start: created_at,
            last_update: created_at,
            requests_in_window: 0,
            tokens: initial_tokens,
        }
    }

    /// Attempts to admit one request under `cfg`.
    ///
    /// Returns `true` and consumes one token when at least one whole token is
    /// available *and* fewer than `cfg.max_requests` requests were admitted in
    /// the current window; otherwise returns `false` and consumes nothing.
    /// The refill and window bookkeeping is updated on every call, whether or
    /// not the request is admitted.
    fn try_consume(&mut self, cfg: &EngineConfig) -> bool {
        let now = Instant::now();

        // Start a new fixed window once the previous one has fully elapsed.
        let window_expired = now.duration_since(self.window_start) > cfg.window;
        if window_expired {
            self.requests_in_window = 0;
            self.window_start = now;
        }

        // Refill proportionally to the time since the last update, capped at
        // the burst size.
        let capacity = cfg.burst_size as f64;
        let tokens_per_sec = cfg.max_requests as f64 / cfg.window.as_secs_f64();
        let idle_secs = now.duration_since(self.last_update).as_secs_f64();
        self.tokens = (self.tokens + idle_secs * tokens_per_sec).min(capacity);
        self.last_update = now;

        let has_token = self.tokens >= 1.0;
        let under_window_cap = self.requests_in_window < cfg.max_requests;
        if !(has_token && under_window_cap) {
            return false;
        }

        self.tokens -= 1.0;
        self.requests_in_window += 1;
        true
    }
}
53
54#[derive(Debug)]
55pub struct Engine<K: Eq + Hash + Clone + ToString> {
56    cfg: EngineConfig,
57    global: Mutex<Bucket>,
58    keyed: RwLock<HashMap<K, Bucket>>,
59}
60
61impl<K: Eq + Hash + Clone + ToString> Engine<K> {
62    pub fn new(cfg: EngineConfig) -> Self {
63        let burst_size = cfg.burst_size as f64;
64        Self {
65            cfg,
66            global: Mutex::new(Bucket::new(burst_size)),
67            keyed: RwLock::new(HashMap::new()),
68        }
69    }
70
71    pub fn try_consume_global(&self) -> bool {
72        match self.global.lock() {
73            Ok(mut guard) => guard.try_consume(&self.cfg),
74            Err(_poisoned) => {
75                // Treat poisoned mutex as a denial to maintain safety
76                // and avoid panicking in production code.
77                false
78            }
79        }
80    }
81
82    pub fn try_consume_key(&self, key: &K) -> bool {
83        let mut map = self.keyed.write();
84        let bucket = map
85            .entry(key.clone())
86            .or_insert_with(|| Bucket::new(self.cfg.burst_size as f64));
87        bucket.try_consume(&self.cfg)
88    }
89}
90
91pub type SharedEngine<K> = Arc<Engine<K>>;
92
93// ============================================================================
94// Join Rate Limiting for Sybil Protection
95// ============================================================================
96
97use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
98use thiserror::Error;
99
100/// Error types for join rate limiting
101#[derive(Debug, Error)]
102pub enum JoinRateLimitError {
103    /// Global join limit exceeded (network is under high load)
104    #[error("global join rate limit exceeded: max {max_per_minute} joins per minute")]
105    GlobalLimitExceeded { max_per_minute: u32 },
106
107    /// Per-subnet /64 limit exceeded (potential Sybil attack)
108    #[error("subnet /64 join rate limit exceeded: max {max_per_hour} joins per hour from this /64")]
109    Subnet64LimitExceeded { max_per_hour: u32 },
110
111    /// Per-subnet /48 limit exceeded (potential coordinated attack)
112    #[error("subnet /48 join rate limit exceeded: max {max_per_hour} joins per hour from this /48")]
113    Subnet48LimitExceeded { max_per_hour: u32 },
114
115    /// Per-subnet /24 limit exceeded (IPv4 Sybil attack)
116    #[error("subnet /24 join rate limit exceeded: max {max_per_hour} joins per hour from this /24")]
117    Subnet24LimitExceeded { max_per_hour: u32 },
118}
119
/// Limits applied by the join rate limiter.
///
/// Subnet limits are expressed per hour; the global limit is expressed per
/// minute and has a separate burst allowance.
#[derive(Debug, Clone)]
pub struct JoinRateLimiterConfig {
    /// Joins allowed per IPv6 /64 per hour (default: 1).
    ///
    /// The strictest of the limits; it is the main defence against Sybil
    /// attacks.
    pub max_joins_per_64_per_hour: u32,

    /// Joins allowed per IPv6 /48 per hour (default: 5).
    pub max_joins_per_48_per_hour: u32,

    /// Joins allowed per IPv4 /24 per hour (default: 3).
    pub max_joins_per_24_per_hour: u32,

    /// Joins allowed network-wide per minute (default: 100).
    ///
    /// Protects against network-wide flooding.
    pub max_global_joins_per_minute: u32,

    /// Burst allowance for the global limit (default: 10).
    pub global_burst_size: u32,
}

impl Default for JoinRateLimiterConfig {
    /// Returns the default limits: 1 join/hour per /64, 5 per /48, 3 per /24,
    /// 100 joins/minute globally with a burst of 10.
    fn default() -> Self {
        Self {
            // Global budget.
            max_global_joins_per_minute: 100,
            global_burst_size: 10,
            // IPv6 subnet budgets.
            max_joins_per_64_per_hour: 1,
            max_joins_per_48_per_hour: 5,
            // IPv4 subnet budget.
            max_joins_per_24_per_hour: 3,
        }
    }
}
152
153/// Join rate limiter for Sybil attack protection
154///
155/// Implements multi-level rate limiting to prevent attackers from flooding
156/// the network with Sybil identities:
157///
158/// - **Global limit**: Protects against network-wide flooding attacks
159/// - **Per-subnet /64 limit**: Prevents single residential/small org Sybil attacks
160/// - **Per-subnet /48 limit**: Prevents coordinated attacks from larger organizations
161/// - **Per-subnet /24 limit**: IPv4-specific protection
162///
163/// # Example
164///
165/// ```rust,ignore
166/// use saorsa_core::rate_limit::{JoinRateLimiter, JoinRateLimiterConfig};
167/// use std::net::IpAddr;
168///
169/// let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
170///
171/// let ip: IpAddr = "2001:db8::1".parse().unwrap();
172/// match limiter.check_join_allowed(&ip) {
173///     Ok(()) => println!("Join allowed"),
174///     Err(e) => println!("Join denied: {}", e),
175/// }
176/// ```
177#[derive(Debug)]
178pub struct JoinRateLimiter {
179    config: JoinRateLimiterConfig,
180    /// Per /64 subnet rate limiter (1 hour window)
181    per_subnet_64: Engine<Ipv6Addr>,
182    /// Per /48 subnet rate limiter (1 hour window)
183    per_subnet_48: Engine<Ipv6Addr>,
184    /// Per /24 subnet rate limiter for IPv4 (1 hour window)
185    per_subnet_24: Engine<Ipv4Addr>,
186    /// Global rate limiter (1 minute window) - uses u8 key with constant 0
187    global: Engine<u8>,
188}
189
190impl JoinRateLimiter {
191    /// Create a new join rate limiter with the given configuration
192    pub fn new(config: JoinRateLimiterConfig) -> Self {
193        // /64 subnet limiter: max_joins_per_64_per_hour over 1 hour
194        let subnet_64_config = EngineConfig {
195            window: Duration::from_secs(3600), // 1 hour
196            max_requests: config.max_joins_per_64_per_hour,
197            burst_size: config.max_joins_per_64_per_hour, // Allow configured limit as burst
198        };
199
200        // /48 subnet limiter: max_joins_per_48_per_hour over 1 hour
201        let subnet_48_config = EngineConfig {
202            window: Duration::from_secs(3600), // 1 hour
203            max_requests: config.max_joins_per_48_per_hour,
204            burst_size: config.max_joins_per_48_per_hour, // Allow configured limit as burst
205        };
206
207        // /24 subnet limiter for IPv4
208        let subnet_24_config = EngineConfig {
209            window: Duration::from_secs(3600), // 1 hour
210            max_requests: config.max_joins_per_24_per_hour,
211            burst_size: config.max_joins_per_24_per_hour, // Allow full burst up to limit
212        };
213
214        // Global limiter: max_global_joins_per_minute over 1 minute
215        let global_config = EngineConfig {
216            window: Duration::from_secs(60), // 1 minute
217            max_requests: config.max_global_joins_per_minute,
218            burst_size: config.global_burst_size,
219        };
220
221        Self {
222            config,
223            per_subnet_64: Engine::new(subnet_64_config),
224            per_subnet_48: Engine::new(subnet_48_config),
225            per_subnet_24: Engine::new(subnet_24_config),
226            global: Engine::new(global_config),
227        }
228    }
229
230    /// Check if a join request from the given IP is allowed
231    ///
232    /// Returns `Ok(())` if the join is allowed, or `Err(JoinRateLimitError)`
233    /// if any rate limit is exceeded.
234    ///
235    /// # Rate Limit Checks (in order)
236    ///
237    /// 1. Global rate limit (protects against network flooding)
238    /// 2. Per-subnet limits based on IP version:
239    ///    - IPv6: /64 and /48 subnet limits
240    ///    - IPv4: /24 subnet limit
241    pub fn check_join_allowed(&self, ip: &IpAddr) -> Result<(), JoinRateLimitError> {
242        // 1. Check global limit first (uses constant key 0)
243        if !self.global.try_consume_key(&0u8) {
244            return Err(JoinRateLimitError::GlobalLimitExceeded {
245                max_per_minute: self.config.max_global_joins_per_minute,
246            });
247        }
248
249        // 2. Check per-subnet limits based on IP version
250        match ip {
251            IpAddr::V6(ipv6) => {
252                // Check /64 subnet limit (strictest for Sybil protection)
253                let subnet_64 = extract_ipv6_subnet_64(ipv6);
254                if !self.per_subnet_64.try_consume_key(&subnet_64) {
255                    return Err(JoinRateLimitError::Subnet64LimitExceeded {
256                        max_per_hour: self.config.max_joins_per_64_per_hour,
257                    });
258                }
259
260                // Check /48 subnet limit
261                let subnet_48 = extract_ipv6_subnet_48(ipv6);
262                if !self.per_subnet_48.try_consume_key(&subnet_48) {
263                    return Err(JoinRateLimitError::Subnet48LimitExceeded {
264                        max_per_hour: self.config.max_joins_per_48_per_hour,
265                    });
266                }
267            }
268            IpAddr::V4(ipv4) => {
269                // Check /24 subnet limit for IPv4
270                let subnet_24 = extract_ipv4_subnet_24(ipv4);
271                if !self.per_subnet_24.try_consume_key(&subnet_24) {
272                    return Err(JoinRateLimitError::Subnet24LimitExceeded {
273                        max_per_hour: self.config.max_joins_per_24_per_hour,
274                    });
275                }
276            }
277        }
278
279        Ok(())
280    }
281
282    /// Get the current configuration
283    pub fn config(&self) -> &JoinRateLimiterConfig {
284        &self.config
285    }
286}
287
/// Returns the /64 network prefix of `addr`.
///
/// The first four 16-bit segments (64 bits) are kept and the interface
/// identifier (the remaining 64 bits) is zeroed.
#[inline]
pub fn extract_ipv6_subnet_64(addr: &Ipv6Addr) -> Ipv6Addr {
    let [s0, s1, s2, s3, ..] = addr.segments();
    Ipv6Addr::new(s0, s1, s2, s3, 0, 0, 0, 0)
}
299
/// Returns the /48 network prefix of `addr`.
///
/// The leading 48 bits are kept and the remaining 80 bits are zeroed.
#[inline]
pub fn extract_ipv6_subnet_48(addr: &Ipv6Addr) -> Ipv6Addr {
    // Ones in the top 48 bits, zeros below.
    const PREFIX_MASK: u128 = !(u128::MAX >> 48);
    Ipv6Addr::from(u128::from(*addr) & PREFIX_MASK)
}
310
/// Returns the /32 network prefix of `addr`.
///
/// The leading 32 bits (four octets) are kept and everything after them is
/// zeroed.
#[inline]
pub fn extract_ipv6_subnet_32(addr: &Ipv6Addr) -> Ipv6Addr {
    let mut prefix = addr.octets();
    prefix[4..].fill(0);
    Ipv6Addr::from(prefix)
}
321
/// Returns the /24 network prefix of `addr`.
///
/// The first three octets are kept and the host octet is zeroed.
#[inline]
pub fn extract_ipv4_subnet_24(addr: &Ipv4Addr) -> Ipv4Addr {
    let [a, b, c, _host] = addr.octets();
    Ipv4Addr::new(a, b, c, 0)
}
330
/// Returns the /16 network prefix of `addr`.
///
/// The first two octets are kept and the last two are zeroed.
#[inline]
pub fn extract_ipv4_subnet_16(addr: &Ipv4Addr) -> Ipv4Addr {
    // Ones in the top 16 bits, zeros below.
    const PREFIX_MASK: u32 = 0xFFFF_0000;
    Ipv4Addr::from(u32::from(*addr) & PREFIX_MASK)
}
337
/// Returns the /8 network prefix of `addr`.
///
/// Only the first octet is kept; the other three are zeroed.
#[inline]
pub fn extract_ipv4_subnet_8(addr: &Ipv4Addr) -> Ipv4Addr {
    let mut prefix = addr.octets();
    prefix[1..].fill(0);
    Ipv4Addr::from(prefix)
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP address literal used by a test.
    fn ip(text: &str) -> IpAddr {
        text.parse().expect("test literal must be a valid IP address")
    }

    // ---- prefix extraction -------------------------------------------------

    #[test]
    fn test_extract_ipv6_subnet_64() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_64(&addr);
        assert_eq!(subnet.to_string(), "2001:db8:85a3:1234::");
    }

    #[test]
    fn test_extract_ipv6_subnet_48() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_48(&addr);
        assert_eq!(subnet.to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_extract_ipv6_subnet_32() {
        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
        let subnet = extract_ipv6_subnet_32(&addr);
        assert_eq!(subnet.to_string(), "2001:db8::");
    }

    #[test]
    fn test_extract_ipv4_subnet_24() {
        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
        let subnet = extract_ipv4_subnet_24(&addr);
        assert_eq!(subnet.to_string(), "192.168.1.0");
    }

    #[test]
    fn test_extract_ipv4_subnet_16() {
        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
        let subnet = extract_ipv4_subnet_16(&addr);
        assert_eq!(subnet.to_string(), "192.168.0.0");
    }

    #[test]
    fn test_extract_ipv4_subnet_8() {
        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
        let subnet = extract_ipv4_subnet_8(&addr);
        assert_eq!(subnet.to_string(), "192.0.0.0");
    }

    // ---- join rate limiter -------------------------------------------------

    #[test]
    fn test_join_rate_limiter_allows_first_join() {
        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
        assert!(limiter.check_join_allowed(&ip("2001:db8::1")).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_blocks_second_from_same_64() {
        let config = JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        // First join should succeed
        assert!(limiter.check_join_allowed(&ip("2001:db8::1")).is_ok());

        // Second join from same /64 should fail
        let result = limiter.check_join_allowed(&ip("2001:db8::2"));
        assert!(matches!(
            result,
            Err(JoinRateLimitError::Subnet64LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_allows_different_subnets() {
        let config = JoinRateLimiterConfig {
            max_joins_per_64_per_hour: 1,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        // First join from one /64
        assert!(limiter.check_join_allowed(&ip("2001:db8:1::1")).is_ok());

        // Second join from different /64 should succeed
        assert!(limiter.check_join_allowed(&ip("2001:db8:2::1")).is_ok());
    }

    #[test]
    fn test_join_rate_limiter_blocks_after_48_limit() {
        let config = JoinRateLimiterConfig {
            max_joins_per_48_per_hour: 2,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        // Distinct /64s inside the same /48, so only the /48 limit can trip.
        assert!(limiter.check_join_allowed(&ip("2001:db8:1:1::1")).is_ok());
        assert!(limiter.check_join_allowed(&ip("2001:db8:1:2::1")).is_ok());

        let result = limiter.check_join_allowed(&ip("2001:db8:1:3::1"));
        assert!(matches!(
            result,
            Err(JoinRateLimitError::Subnet48LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_ipv4() {
        let config = JoinRateLimiterConfig {
            max_joins_per_24_per_hour: 2,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        // First two joins should succeed
        assert!(limiter.check_join_allowed(&ip("192.168.1.1")).is_ok());
        assert!(limiter.check_join_allowed(&ip("192.168.1.2")).is_ok());

        // Third join from same /24 should fail
        let result = limiter.check_join_allowed(&ip("192.168.1.3"));
        assert!(matches!(
            result,
            Err(JoinRateLimitError::Subnet24LimitExceeded { .. })
        ));
    }

    #[test]
    fn test_join_rate_limiter_global_limit() {
        let config = JoinRateLimiterConfig {
            max_global_joins_per_minute: 2,
            global_burst_size: 2,
            ..Default::default()
        };
        let limiter = JoinRateLimiter::new(config);

        // Distinct /24s, so only the global limit can trip.
        assert!(limiter.check_join_allowed(&ip("10.0.1.1")).is_ok());
        assert!(limiter.check_join_allowed(&ip("10.0.2.1")).is_ok());

        let result = limiter.check_join_allowed(&ip("10.0.3.1"));
        assert!(matches!(
            result,
            Err(JoinRateLimitError::GlobalLimitExceeded { .. })
        ));
    }
}
437}