Skip to main content

saorsa_core/
rate_limit.rs

1use lru::LruCache;
2use parking_lot::RwLock;
3use std::hash::Hash;
4use std::num::NonZeroUsize;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Upper bound on the number of distinct keys a rate-limiting [`Engine`] tracks at once.
///
/// Per-key buckets live in an LRU cache of this capacity; once it is full the
/// least recently used key is evicted, so cycling through many source
/// addresses cannot grow memory without bound.
const MAX_RATE_LIMIT_KEYS: usize = 100_000;

/// Limits applied by a rate-limiting [`Engine`] to each of its buckets.
///
/// A request is admitted only when both of these hold:
/// * the token bucket — refilled continuously at `max_requests / window`
///   tokens per second and capped at `burst_size` — holds a whole token, and
/// * fewer than `max_requests` requests were admitted in the current fixed
///   `window`.
#[derive(Clone, Debug)]
pub struct EngineConfig {
    /// Length of the fixed counting window; also the period over which
    /// `max_requests` tokens are refilled.
    pub window: Duration,
    /// Requests admitted per `window`, and tokens refilled per `window`.
    pub max_requests: u32,
    /// Capacity of the token bucket.
    pub burst_size: u32,
}

/// Combined token-bucket / fixed-window state for one rate-limited subject.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available (may be fractional between requests).
    tokens: f64,
    /// Instant of the most recent refill.
    last_update: Instant,
    /// Requests admitted since `window_start`.
    requests_in_window: u32,
    /// Start of the current fixed window.
    window_start: Instant,
}

impl Bucket {
    /// Creates a bucket holding `initial_tokens`, with both clocks starting now.
    fn new(initial_tokens: f64) -> Self {
        let created = Instant::now();
        Bucket {
            tokens: initial_tokens,
            requests_in_window: 0,
            last_update: created,
            window_start: created,
        }
    }

    /// Tries to admit one request under `cfg`; returns whether it was admitted.
    ///
    /// The fixed-window counter is reset once more than `cfg.window` has passed
    /// since `window_start`. Tokens are then refilled for the time since the
    /// previous call (capped at `cfg.burst_size`). A request is admitted,
    /// spending one token and one window slot, only if a whole token is
    /// available and the window still has room.
    fn try_consume(&mut self, cfg: &EngineConfig) -> bool {
        let now = Instant::now();

        let window_expired = now.duration_since(self.window_start) > cfg.window;
        if window_expired {
            self.requests_in_window = 0;
            self.window_start = now;
        }

        // Continuous refill: `max_requests` tokens per `window`, never above the burst cap.
        let tokens_per_sec = f64::from(cfg.max_requests) / cfg.window.as_secs_f64();
        let since_last = now.duration_since(self.last_update).as_secs_f64();
        self.tokens = (self.tokens + since_last * tokens_per_sec).min(f64::from(cfg.burst_size));
        self.last_update = now;

        let has_token = self.tokens >= 1.0;
        let window_has_room = self.requests_in_window < cfg.max_requests;
        if !(has_token && window_has_room) {
            return false;
        }

        self.tokens -= 1.0;
        self.requests_in_window += 1;
        true
    }
}
57
58#[derive(Debug)]
59pub struct Engine<K: Eq + Hash + Clone + ToString> {
60    cfg: EngineConfig,
61    global: Mutex<Bucket>,
62    /// LRU cache with max 100k entries to prevent memory DoS from many IPs
63    keyed: RwLock<LruCache<K, Bucket>>,
64}
65
66impl<K: Eq + Hash + Clone + ToString> Engine<K> {
67    pub fn new(cfg: EngineConfig) -> Self {
68        let burst_size = cfg.burst_size as f64;
69        // Safety: MAX_RATE_LIMIT_KEYS is a const > 0, so unwrap_or with MIN (=1) is safe
70        let cache_size = NonZeroUsize::new(MAX_RATE_LIMIT_KEYS).unwrap_or(NonZeroUsize::MIN);
71        Self {
72            cfg,
73            global: Mutex::new(Bucket::new(burst_size)),
74            keyed: RwLock::new(LruCache::new(cache_size)),
75        }
76    }
77
78    pub fn try_consume_global(&self) -> bool {
79        match self.global.lock() {
80            Ok(mut guard) => guard.try_consume(&self.cfg),
81            Err(_poisoned) => {
82                // Treat poisoned mutex as a denial to maintain safety
83                // and avoid panicking in production code.
84                false
85            }
86        }
87    }
88
89    pub fn try_consume_key(&self, key: &K) -> bool {
90        let mut map = self.keyed.write();
91        // Get or insert with LRU cache (automatically evicts oldest if at capacity)
92        if let Some(bucket) = map.get_mut(key) {
93            bucket.try_consume(&self.cfg)
94        } else {
95            let mut bucket = Bucket::new(self.cfg.burst_size as f64);
96            let result = bucket.try_consume(&self.cfg);
97            map.put(key.clone(), bucket);
98            result
99        }
100    }
101}
102
103pub type SharedEngine<K> = Arc<Engine<K>>;
104
105// ============================================================================
106// Join Rate Limiting for Sybil Protection
107// ============================================================================
108
109use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
110use thiserror::Error;
111
112/// Error types for join rate limiting
113#[derive(Debug, Error)]
114pub enum JoinRateLimitError {
115    /// Global join limit exceeded (network is under high load)
116    #[error("global join rate limit exceeded: max {max_per_minute} joins per minute")]
117    GlobalLimitExceeded { max_per_minute: u32 },
118
119    /// Per-subnet /64 limit exceeded (potential Sybil attack)
120    #[error("subnet /64 join rate limit exceeded: max {max_per_hour} joins per hour from this /64")]
121    Subnet64LimitExceeded { max_per_hour: u32 },
122
123    /// Per-subnet /48 limit exceeded (potential coordinated attack)
124    #[error("subnet /48 join rate limit exceeded: max {max_per_hour} joins per hour from this /48")]
125    Subnet48LimitExceeded { max_per_hour: u32 },
126
127    /// Per-subnet /24 limit exceeded (IPv4 Sybil attack)
128    #[error("subnet /24 join rate limit exceeded: max {max_per_hour} joins per hour from this /24")]
129    Subnet24LimitExceeded { max_per_hour: u32 },
130}
131
132/// Configuration for join rate limiting
133#[derive(Debug, Clone)]
134pub struct JoinRateLimiterConfig {
135    /// Maximum joins per /64 subnet per hour (default: 1)
136    /// This is the strictest limit to prevent Sybil attacks
137    pub max_joins_per_64_per_hour: u32,
138
139    /// Maximum joins per /48 subnet per hour (default: 5)
140    pub max_joins_per_48_per_hour: u32,
141
142    /// Maximum joins per /24 subnet per hour for IPv4 (default: 3)
143    pub max_joins_per_24_per_hour: u32,
144
145    /// Maximum global joins per minute (default: 100)
146    /// This protects against network-wide flooding
147    pub max_global_joins_per_minute: u32,
148
149    /// Burst allowance for global limit (default: 10)
150    pub global_burst_size: u32,
151}
152
153impl Default for JoinRateLimiterConfig {
154    fn default() -> Self {
155        Self {
156            max_joins_per_64_per_hour: 1,
157            max_joins_per_48_per_hour: 5,
158            max_joins_per_24_per_hour: 3,
159            max_global_joins_per_minute: 100,
160            global_burst_size: 10,
161        }
162    }
163}
164
165/// Join rate limiter for Sybil attack protection
166///
167/// Implements multi-level rate limiting to prevent attackers from flooding
168/// the network with Sybil identities:
169///
170/// - **Global limit**: Protects against network-wide flooding attacks
171/// - **Per-subnet /64 limit**: Prevents single residential/small org Sybil attacks
172/// - **Per-subnet /48 limit**: Prevents coordinated attacks from larger organizations
173/// - **Per-subnet /24 limit**: IPv4-specific protection
174///
175/// # Example
176///
177/// ```rust,ignore
178/// use saorsa_core::rate_limit::{JoinRateLimiter, JoinRateLimiterConfig};
179/// use std::net::IpAddr;
180///
181/// let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
182///
183/// let ip: IpAddr = "2001:db8::1".parse().unwrap();
184/// match limiter.check_join_allowed(&ip) {
185///     Ok(()) => println!("Join allowed"),
186///     Err(e) => println!("Join denied: {}", e),
187/// }
188/// ```
189#[derive(Debug)]
190pub struct JoinRateLimiter {
191    config: JoinRateLimiterConfig,
192    /// Per /64 subnet rate limiter (1 hour window)
193    per_subnet_64: Engine<Ipv6Addr>,
194    /// Per /48 subnet rate limiter (1 hour window)
195    per_subnet_48: Engine<Ipv6Addr>,
196    /// Per /24 subnet rate limiter for IPv4 (1 hour window)
197    per_subnet_24: Engine<Ipv4Addr>,
198    /// Global rate limiter (1 minute window) - uses u8 key with constant 0
199    global: Engine<u8>,
200}
201
202impl JoinRateLimiter {
203    /// Create a new join rate limiter with the given configuration
204    pub fn new(config: JoinRateLimiterConfig) -> Self {
205        // /64 subnet limiter: max_joins_per_64_per_hour over 1 hour
206        let subnet_64_config = EngineConfig {
207            window: Duration::from_secs(3600), // 1 hour
208            max_requests: config.max_joins_per_64_per_hour,
209            burst_size: config.max_joins_per_64_per_hour, // Allow configured limit as burst
210        };
211
212        // /48 subnet limiter: max_joins_per_48_per_hour over 1 hour
213        let subnet_48_config = EngineConfig {
214            window: Duration::from_secs(3600), // 1 hour
215            max_requests: config.max_joins_per_48_per_hour,
216            burst_size: config.max_joins_per_48_per_hour, // Allow configured limit as burst
217        };
218
219        // /24 subnet limiter for IPv4
220        let subnet_24_config = EngineConfig {
221            window: Duration::from_secs(3600), // 1 hour
222            max_requests: config.max_joins_per_24_per_hour,
223            burst_size: config.max_joins_per_24_per_hour, // Allow full burst up to limit
224        };
225
226        // Global limiter: max_global_joins_per_minute over 1 minute
227        let global_config = EngineConfig {
228            window: Duration::from_secs(60), // 1 minute
229            max_requests: config.max_global_joins_per_minute,
230            burst_size: config.global_burst_size,
231        };
232
233        Self {
234            config,
235            per_subnet_64: Engine::new(subnet_64_config),
236            per_subnet_48: Engine::new(subnet_48_config),
237            per_subnet_24: Engine::new(subnet_24_config),
238            global: Engine::new(global_config),
239        }
240    }
241
242    /// Check if a join request from the given IP is allowed
243    ///
244    /// Returns `Ok(())` if the join is allowed, or `Err(JoinRateLimitError)`
245    /// if any rate limit is exceeded.
246    ///
247    /// # Rate Limit Checks (in order)
248    ///
249    /// 1. Global rate limit (protects against network flooding)
250    /// 2. Per-subnet limits based on IP version:
251    ///    - IPv6: /64 and /48 subnet limits
252    ///    - IPv4: /24 subnet limit
253    pub fn check_join_allowed(&self, ip: &IpAddr) -> Result<(), JoinRateLimitError> {
254        // 1. Check global limit first (uses constant key 0)
255        if !self.global.try_consume_key(&0u8) {
256            return Err(JoinRateLimitError::GlobalLimitExceeded {
257                max_per_minute: self.config.max_global_joins_per_minute,
258            });
259        }
260
261        // 2. Check per-subnet limits based on IP version
262        match ip {
263            IpAddr::V6(ipv6) => {
264                // Check /64 subnet limit (strictest for Sybil protection)
265                let subnet_64 = extract_ipv6_subnet_64(ipv6);
266                if !self.per_subnet_64.try_consume_key(&subnet_64) {
267                    return Err(JoinRateLimitError::Subnet64LimitExceeded {
268                        max_per_hour: self.config.max_joins_per_64_per_hour,
269                    });
270                }
271
272                // Check /48 subnet limit
273                let subnet_48 = extract_ipv6_subnet_48(ipv6);
274                if !self.per_subnet_48.try_consume_key(&subnet_48) {
275                    return Err(JoinRateLimitError::Subnet48LimitExceeded {
276                        max_per_hour: self.config.max_joins_per_48_per_hour,
277                    });
278                }
279            }
280            IpAddr::V4(ipv4) => {
281                // Check /24 subnet limit for IPv4
282                let subnet_24 = extract_ipv4_subnet_24(ipv4);
283                if !self.per_subnet_24.try_consume_key(&subnet_24) {
284                    return Err(JoinRateLimitError::Subnet24LimitExceeded {
285                        max_per_hour: self.config.max_joins_per_24_per_hour,
286                    });
287                }
288            }
289        }
290
291        Ok(())
292    }
293
294    /// Get the current configuration
295    pub fn config(&self) -> &JoinRateLimiterConfig {
296        &self.config
297    }
298}
299
300/// Extract /64 subnet prefix from an IPv6 address
301///
302/// Returns an IPv6 address with only the first 64 bits preserved (network portion),
303/// with the remaining 64 bits zeroed (interface identifier).
304#[inline]
305pub fn extract_ipv6_subnet_64(addr: &Ipv6Addr) -> Ipv6Addr {
306    let octets = addr.octets();
307    let mut subnet = [0u8; 16];
308    subnet[..8].copy_from_slice(&octets[..8]); // Keep first 64 bits
309    Ipv6Addr::from(subnet)
310}
311
312/// Extract /48 subnet prefix from an IPv6 address
313///
314/// Returns an IPv6 address with only the first 48 bits preserved.
315#[inline]
316pub fn extract_ipv6_subnet_48(addr: &Ipv6Addr) -> Ipv6Addr {
317    let octets = addr.octets();
318    let mut subnet = [0u8; 16];
319    subnet[..6].copy_from_slice(&octets[..6]); // Keep first 48 bits
320    Ipv6Addr::from(subnet)
321}
322
323/// Extract /32 subnet prefix from an IPv6 address
324///
325/// Returns an IPv6 address with only the first 32 bits preserved.
326#[inline]
327pub fn extract_ipv6_subnet_32(addr: &Ipv6Addr) -> Ipv6Addr {
328    let octets = addr.octets();
329    let mut subnet = [0u8; 16];
330    subnet[..4].copy_from_slice(&octets[..4]); // Keep first 32 bits
331    Ipv6Addr::from(subnet)
332}
333
334/// Extract /24 subnet prefix from an IPv4 address
335///
336/// Returns an IPv4 address with only the first 24 bits preserved.
337#[inline]
338pub fn extract_ipv4_subnet_24(addr: &Ipv4Addr) -> Ipv4Addr {
339    let octets = addr.octets();
340    Ipv4Addr::new(octets[0], octets[1], octets[2], 0)
341}
342
343/// Extract /16 subnet prefix from an IPv4 address
344#[inline]
345pub fn extract_ipv4_subnet_16(addr: &Ipv4Addr) -> Ipv4Addr {
346    let octets = addr.octets();
347    Ipv4Addr::new(octets[0], octets[1], 0, 0)
348}
349
350/// Extract /8 subnet prefix from an IPv4 address
351#[inline]
352pub fn extract_ipv4_subnet_8(addr: &Ipv4Addr) -> Ipv4Addr {
353    let octets = addr.octets();
354    Ipv4Addr::new(octets[0], 0, 0, 0)
355}
356
357#[cfg(test)]
358mod tests {
359    use super::*;
360
361    #[test]
362    fn test_extract_ipv6_subnet_64() {
363        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
364        let subnet = extract_ipv6_subnet_64(&addr);
365        assert_eq!(subnet.to_string(), "2001:db8:85a3:1234::");
366    }
367
368    #[test]
369    fn test_extract_ipv6_subnet_48() {
370        let addr: Ipv6Addr = "2001:db8:85a3:1234:8a2e:370:7334:1234".parse().unwrap();
371        let subnet = extract_ipv6_subnet_48(&addr);
372        assert_eq!(subnet.to_string(), "2001:db8:85a3::");
373    }
374
375    #[test]
376    fn test_extract_ipv4_subnet_24() {
377        let addr: Ipv4Addr = "192.168.1.100".parse().unwrap();
378        let subnet = extract_ipv4_subnet_24(&addr);
379        assert_eq!(subnet.to_string(), "192.168.1.0");
380    }
381
382    #[test]
383    fn test_join_rate_limiter_allows_first_join() {
384        let limiter = JoinRateLimiter::new(JoinRateLimiterConfig::default());
385        let ip: IpAddr = "2001:db8::1".parse().unwrap();
386        assert!(limiter.check_join_allowed(&ip).is_ok());
387    }
388
389    #[test]
390    fn test_join_rate_limiter_blocks_second_from_same_64() {
391        let config = JoinRateLimiterConfig {
392            max_joins_per_64_per_hour: 1,
393            ..Default::default()
394        };
395        let limiter = JoinRateLimiter::new(config);
396
397        // First join should succeed
398        let ip1: IpAddr = "2001:db8::1".parse().unwrap();
399        assert!(limiter.check_join_allowed(&ip1).is_ok());
400
401        // Second join from same /64 should fail
402        let ip2: IpAddr = "2001:db8::2".parse().unwrap();
403        let result = limiter.check_join_allowed(&ip2);
404        assert!(matches!(
405            result,
406            Err(JoinRateLimitError::Subnet64LimitExceeded { .. })
407        ));
408    }
409
410    #[test]
411    fn test_join_rate_limiter_allows_different_subnets() {
412        let config = JoinRateLimiterConfig {
413            max_joins_per_64_per_hour: 1,
414            ..Default::default()
415        };
416        let limiter = JoinRateLimiter::new(config);
417
418        // First join from one /64
419        let ip1: IpAddr = "2001:db8:1::1".parse().unwrap();
420        assert!(limiter.check_join_allowed(&ip1).is_ok());
421
422        // Second join from different /64 should succeed
423        let ip2: IpAddr = "2001:db8:2::1".parse().unwrap();
424        assert!(limiter.check_join_allowed(&ip2).is_ok());
425    }
426
427    #[test]
428    fn test_join_rate_limiter_ipv4() {
429        let config = JoinRateLimiterConfig {
430            max_joins_per_24_per_hour: 2,
431            ..Default::default()
432        };
433        let limiter = JoinRateLimiter::new(config);
434
435        // First two joins should succeed
436        let ip1: IpAddr = "192.168.1.1".parse().unwrap();
437        let ip2: IpAddr = "192.168.1.2".parse().unwrap();
438        assert!(limiter.check_join_allowed(&ip1).is_ok());
439        assert!(limiter.check_join_allowed(&ip2).is_ok());
440
441        // Third join from same /24 should fail
442        let ip3: IpAddr = "192.168.1.3".parse().unwrap();
443        let result = limiter.check_join_allowed(&ip3);
444        assert!(matches!(
445            result,
446            Err(JoinRateLimitError::Subnet24LimitExceeded { .. })
447        ));
448    }
449}