1use anyhow::{Result, anyhow};
21use lru::LruCache;
22use serde::{Deserialize, Serialize};
23use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
24use std::num::NonZeroUsize;
25
/// Capacity of each LRU table held by `BootstrapIpLimiter`.
///
/// The same capacity is used for the per-IP table and the per-subnet table. Once a table is
/// full, its least-recently-used entry is evicted to make room for a new key.
const BOOTSTRAP_MAX_TRACKED_SUBNETS: usize = 50_000;

/// Default number of tracked peers allowed per exact (canonicalized) IP address.
///
/// Applies when `IPDiversityConfig::max_per_ip` is `None`.
pub const IP_EXACT_LIMIT: usize = 2;

/// `k` value used by the test-only constructors of `BootstrapIpLimiter`.
#[cfg(test)]
const DEFAULT_K_VALUE: usize = 20;
38
/// Reduces an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) to its plain IPv4 form.
///
/// Every other address, IPv4 or IPv6, is returned unchanged. This includes the deprecated
/// IPv4-compatible form `::a.b.c.d` and `::1`. The effect is that one peer reached over
/// either address family is counted against the same key.
pub fn canonicalize_ip(ip: IpAddr) -> IpAddr {
    let IpAddr::V6(v6) = ip else {
        // Already IPv4: nothing to normalize.
        return ip;
    };
    v6.to_ipv4_mapped().map_or(ip, IpAddr::V4)
}
51
/// Default per-subnet peer limit derived from `k`: `k / 4` (integer division), but never
/// less than 1. For example, `k = 20` gives 5, and any `k < 8` gives 1.
pub const fn ip_subnet_limit(k: usize) -> usize {
    match k / 4 {
        // A quarter that rounds down to zero would forbid every subnet; floor it at one.
        0 => 1,
        quarter => quarter,
    }
}
57
/// Optional overrides for the IP-diversity limits enforced by `BootstrapIpLimiter`.
///
/// A `None` field means "use the built-in default":
/// - `max_per_ip` falls back to [`IP_EXACT_LIMIT`].
/// - `max_per_subnet` falls back to [`ip_subnet_limit`] of the limiter's `k` value.
///
/// `None` fields are omitted when serializing, and missing fields deserialize as `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IPDiversityConfig {
    /// Maximum number of tracked peers sharing one exact (canonicalized) IP address.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_per_ip: Option<usize>,

    /// Maximum number of tracked peers sharing one subnet (/24 for IPv4, /48 for IPv6).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_per_subnet: Option<usize>,
}
81
82impl IPDiversityConfig {
83 #[must_use]
99 pub fn testnet() -> Self {
100 Self::permissive()
101 }
102
103 #[must_use]
108 pub fn permissive() -> Self {
109 Self {
110 max_per_ip: Some(usize::MAX),
111 max_per_subnet: Some(usize::MAX),
112 }
113 }
114
115 pub fn validate(&self) -> Result<()> {
119 if let Some(limit) = self.max_per_ip
120 && limit < 1
121 {
122 anyhow::bail!("max_per_ip must be >= 1 (got {limit})");
123 }
124 if let Some(limit) = self.max_per_subnet
125 && limit < 1
126 {
127 anyhow::bail!("max_per_subnet must be >= 1 (got {limit})");
128 }
129 Ok(())
130 }
131}
132
/// Admission limiter that caps how many tracked peers may share one IP address or subnet.
///
/// Counts live in two bounded LRU caches, each sized by [`BOOTSTRAP_MAX_TRACKED_SUBNETS`]
/// in the constructor. When a cache is full, its least-recently-used entry is evicted, and
/// a key with no entry is treated as having a count of zero.
#[derive(Debug)]
pub struct BootstrapIpLimiter {
    /// Limit overrides; `None` fields fall back to the built-in defaults.
    config: IPDiversityConfig,
    /// If `true`, loopback addresses bypass all limits; if `false`, they are always rejected.
    allow_loopback: bool,
    /// `k` parameter used to derive the default subnet limit via `ip_subnet_limit`
    /// (presumably the routing bucket size — confirm with callers).
    k_value: usize,
    /// Tracked-peer count per canonicalized IP address.
    ip_counts: LruCache<IpAddr, usize>,
    /// Tracked-peer count per subnet key (/24 for IPv4, /48 for IPv6).
    subnet_counts: LruCache<IpAddr, usize>,
}
154
155impl BootstrapIpLimiter {
156 #[cfg(test)]
162 pub fn new(config: IPDiversityConfig) -> Self {
163 Self::with_loopback(config, false)
164 }
165
166 #[cfg(test)]
173 pub fn with_loopback(config: IPDiversityConfig, allow_loopback: bool) -> Self {
174 Self::with_loopback_and_k(config, allow_loopback, DEFAULT_K_VALUE)
175 }
176
177 pub fn with_loopback_and_k(
182 config: IPDiversityConfig,
183 allow_loopback: bool,
184 k_value: usize,
185 ) -> Self {
186 let cache_size =
187 NonZeroUsize::new(BOOTSTRAP_MAX_TRACKED_SUBNETS).unwrap_or(NonZeroUsize::MIN);
188 Self {
189 config,
190 allow_loopback,
191 k_value,
192 ip_counts: LruCache::new(cache_size),
193 subnet_counts: LruCache::new(cache_size),
194 }
195 }
196
197 fn subnet_key(ip: IpAddr) -> IpAddr {
199 match ip {
200 IpAddr::V4(v4) => {
201 let o = v4.octets();
202 IpAddr::V4(Ipv4Addr::new(o[0], o[1], o[2], 0))
203 }
204 IpAddr::V6(v6) => {
205 let mut o = v6.octets();
206 for b in &mut o[6..] {
208 *b = 0;
209 }
210 IpAddr::V6(Ipv6Addr::from(o))
211 }
212 }
213 }
214
215 pub fn can_accept(&self, ip: IpAddr) -> bool {
217 let ip = canonicalize_ip(ip);
218
219 if ip.is_loopback() {
221 return self.allow_loopback;
222 }
223
224 if ip.is_unspecified() || ip.is_multicast() {
226 return false;
227 }
228
229 let ip_limit = self.config.max_per_ip.unwrap_or(IP_EXACT_LIMIT);
230 let subnet_limit = self
231 .config
232 .max_per_subnet
233 .unwrap_or(ip_subnet_limit(self.k_value));
234
235 if let Some(&count) = self.ip_counts.peek(&ip)
237 && count >= ip_limit
238 {
239 return false;
240 }
241
242 let subnet = Self::subnet_key(ip);
244 if let Some(&count) = self.subnet_counts.peek(&subnet)
245 && count >= subnet_limit
246 {
247 return false;
248 }
249
250 true
251 }
252
253 pub fn track(&mut self, ip: IpAddr) -> Result<()> {
257 let ip = canonicalize_ip(ip);
258 if !self.can_accept(ip) {
259 return Err(anyhow!("IP diversity limits exceeded"));
260 }
261
262 let count = self.ip_counts.get(&ip).copied().unwrap_or(0) + 1;
263 self.ip_counts.put(ip, count);
264
265 let subnet = Self::subnet_key(ip);
266 let count = self.subnet_counts.get(&subnet).copied().unwrap_or(0) + 1;
267 self.subnet_counts.put(subnet, count);
268
269 Ok(())
270 }
271
272 #[allow(dead_code)]
274 pub fn untrack(&mut self, ip: IpAddr) {
275 let ip = canonicalize_ip(ip);
276 if let Some(count) = self.ip_counts.peek_mut(&ip) {
277 *count = count.saturating_sub(1);
278 if *count == 0 {
279 self.ip_counts.pop(&ip);
280 }
281 }
282
283 let subnet = Self::subnet_key(ip);
284 if let Some(count) = self.subnet_counts.peek_mut(&subnet) {
285 *count = count.saturating_sub(1);
286 if *count == 0 {
287 self.subnet_counts.pop(&subnet);
288 }
289 }
290 }
291}
292
#[cfg(test)]
impl BootstrapIpLimiter {
    /// Test-only accessor for the limit overrides this limiter was built with.
    // Not every test build calls this accessor.
    #[allow(dead_code)]
    pub fn config(&self) -> &IPDiversityConfig {
        let Self { config, .. } = self;
        config
    }
}
300
/// Source of network-origin metadata (ASN, country, hosting/VPN flags) for an address.
///
/// NOTE(review): `lookup` only accepts `Ipv6Addr`. This file does not show how IPv4 peers
/// would be queried (perhaps through their IPv4-mapped form — confirm with implementors).
#[allow(dead_code)]
pub trait GeoProvider: std::fmt::Debug {
    /// Returns the metadata for `ip`.
    ///
    /// The signature is infallible, so missing data can only be conveyed through the
    /// `Option` fields of [`GeoInfo`].
    fn lookup(&self, ip: Ipv6Addr) -> GeoInfo;
}
310
/// Metadata about an address's network origin, as returned by [`GeoProvider::lookup`].
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct GeoInfo {
    /// Autonomous system number, if known.
    pub asn: Option<u32>,
    /// Country of the address, if known. The format is not defined here (presumably an
    /// ISO country code — confirm with providers).
    pub country: Option<String>,
    /// Whether the address is attributed to a hosting/datacenter provider.
    pub is_hosting_provider: bool,
    /// Whether the address is attributed to a VPN provider.
    pub is_vpn_provider: bool,
}
324
/// Unit tests for the IP-diversity config and `BootstrapIpLimiter`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ip_diversity_config_default() {
        let config = IPDiversityConfig::default();

        assert!(config.max_per_ip.is_none());
        assert!(config.max_per_subnet.is_none());
    }

    #[test]
    fn test_bootstrap_ip_limiter_creation() {
        let config = IPDiversityConfig {
            max_per_ip: None,
            max_per_subnet: Some(1),
        };
        let enforcer = BootstrapIpLimiter::with_loopback(config.clone(), true);

        assert_eq!(enforcer.config.max_per_subnet, config.max_per_subnet);
    }

    #[test]
    fn test_can_accept_basic() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);

        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        assert!(enforcer.can_accept(ip));
    }

    #[test]
    fn test_ip_limit_enforcement() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        assert!(enforcer.can_accept(ip));
        enforcer.track(ip).unwrap();

        assert!(!enforcer.can_accept(ip));
        assert!(enforcer.track(ip).is_err());
    }

    #[test]
    fn test_subnet_limit_enforcement_ipv4() {
        let config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(2),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip1: IpAddr = "10.0.1.1".parse().unwrap();
        let ip2: IpAddr = "10.0.1.2".parse().unwrap();
        let ip3: IpAddr = "10.0.1.3".parse().unwrap();

        enforcer.track(ip1).unwrap();
        enforcer.track(ip2).unwrap();

        assert!(!enforcer.can_accept(ip3));
        assert!(enforcer.track(ip3).is_err());

        let ip_other: IpAddr = "10.0.2.1".parse().unwrap();
        assert!(enforcer.can_accept(ip_other));
    }

    #[test]
    fn test_subnet_limit_enforcement_ipv6() {
        let config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(1),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip1: IpAddr = "2001:db8:85a3:1234::1".parse().unwrap();
        let ip2: IpAddr = "2001:db8:85a3:5678::2".parse().unwrap();

        enforcer.track(ip1).unwrap();

        assert!(!enforcer.can_accept(ip2));

        let ip_other: IpAddr = "2001:db8:aaaa::1".parse().unwrap();
        assert!(enforcer.can_accept(ip_other));
    }

    #[test]
    fn test_track_and_untrack() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        enforcer.track(ip).unwrap();
        assert!(!enforcer.can_accept(ip));

        enforcer.untrack(ip);
        assert!(enforcer.can_accept(ip));

        enforcer.track(ip).unwrap();
        assert!(!enforcer.can_accept(ip));
    }

    #[test]
    fn test_loopback_bypass() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(1),
        };

        let enforcer = BootstrapIpLimiter::with_loopback(config.clone(), true);
        let loopback_v4: IpAddr = "127.0.0.1".parse().unwrap();
        let loopback_v6: IpAddr = "::1".parse().unwrap();
        assert!(enforcer.can_accept(loopback_v4));
        assert!(enforcer.can_accept(loopback_v6));

        let enforcer_no_lb = BootstrapIpLimiter::new(config);
        assert!(
            !enforcer_no_lb.can_accept(loopback_v4),
            "loopback should be rejected when allow_loopback=false"
        );
        assert!(
            !enforcer_no_lb.can_accept(loopback_v6),
            "loopback IPv6 should be rejected when allow_loopback=false"
        );
    }

    #[test]
    fn test_subnet_key_ipv4() {
        let ip: IpAddr = "192.168.42.100".parse().unwrap();
        let subnet = BootstrapIpLimiter::subnet_key(ip);
        let expected: IpAddr = "192.168.42.0".parse().unwrap();
        assert_eq!(subnet, expected);
    }

    #[test]
    fn test_subnet_key_ipv6() {
        let ip: IpAddr = "2001:db8:85a3:1234:5678:8a2e:0370:7334".parse().unwrap();
        let subnet = BootstrapIpLimiter::subnet_key(ip);
        let expected: IpAddr = "2001:db8:85a3::".parse().unwrap();
        assert_eq!(subnet, expected);
    }

    #[test]
    fn test_default_ip_limit_is_two() {
        let config = IPDiversityConfig::default();
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip1: IpAddr = "10.0.0.1".parse().unwrap();

        enforcer.track(ip1).unwrap();
        enforcer.track(ip1).unwrap();

        assert!(!enforcer.can_accept(ip1));
    }

    #[test]
    fn test_default_subnet_limit_matches_k() {
        let config = IPDiversityConfig::default();
        let mut enforcer = BootstrapIpLimiter::new(config);

        for i in 1..=5 {
            let ip: IpAddr = format!("10.0.1.{i}").parse().unwrap();
            enforcer.track(ip).unwrap();
        }

        let ip6: IpAddr = "10.0.1.6".parse().unwrap();
        assert!(
            !enforcer.can_accept(ip6),
            "6th peer in same /24 should exceed K/4=5 subnet limit"
        );
    }

    #[test]
    fn test_ipv4_mapped_ipv6_counts_as_ipv4() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ipv4: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ipv4).unwrap();

        let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
        assert!(
            !enforcer.can_accept(mapped),
            "IPv4-mapped IPv6 should be canonicalized and hit the IPv4 limit"
        );
    }

    #[test]
    fn test_multicast_rejected() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);

        let multicast_v4: IpAddr = "224.0.0.1".parse().unwrap();
        assert!(!enforcer.can_accept(multicast_v4));

        let multicast_v6: IpAddr = "ff02::1".parse().unwrap();
        assert!(!enforcer.can_accept(multicast_v6));
    }

    #[test]
    fn test_unspecified_rejected() {
        let config = IPDiversityConfig::default();
        let enforcer = BootstrapIpLimiter::new(config);

        let unspec_v4: IpAddr = "0.0.0.0".parse().unwrap();
        assert!(!enforcer.can_accept(unspec_v4));

        let unspec_v6: IpAddr = "::".parse().unwrap();
        assert!(!enforcer.can_accept(unspec_v6));
    }

    #[test]
    fn test_untrack_ipv4_mapped_ipv6() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ipv4: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ipv4).unwrap();
        assert!(!enforcer.can_accept(ipv4));

        let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
        enforcer.untrack(mapped);
        assert!(
            enforcer.can_accept(ipv4),
            "untrack via mapped form should decrement the IPv4 counter"
        );
    }

    // `validate` had no coverage: zero limits must be rejected, while every preset and the
    // default config must pass.
    #[test]
    fn test_validate_rejects_zero_limits() {
        let zero_ip = IPDiversityConfig {
            max_per_ip: Some(0),
            max_per_subnet: None,
        };
        assert!(zero_ip.validate().is_err());

        let zero_subnet = IPDiversityConfig {
            max_per_ip: None,
            max_per_subnet: Some(0),
        };
        assert!(zero_subnet.validate().is_err());

        assert!(IPDiversityConfig::default().validate().is_ok());
        assert!(IPDiversityConfig::permissive().validate().is_ok());
        assert!(IPDiversityConfig::testnet().validate().is_ok());
    }

    #[test]
    fn test_ip_subnet_limit_has_floor_of_one() {
        assert_eq!(ip_subnet_limit(0), 1);
        assert_eq!(ip_subnet_limit(3), 1);
        assert_eq!(ip_subnet_limit(8), 2);
        assert_eq!(ip_subnet_limit(20), 5);
    }

    #[test]
    fn test_canonicalize_ip_passthrough() {
        let v4: IpAddr = "192.0.2.7".parse().unwrap();
        let v6: IpAddr = "2001:db8::1".parse().unwrap();
        let mapped: IpAddr = "::ffff:192.0.2.7".parse().unwrap();

        assert_eq!(canonicalize_ip(v4), v4);
        assert_eq!(canonicalize_ip(v6), v6);
        assert_eq!(canonicalize_ip(mapped), v4);
    }

    #[test]
    fn test_custom_k_sets_default_subnet_limit() {
        // k = 8 gives a default subnet limit of 8 / 4 = 2.
        let mut enforcer =
            BootstrapIpLimiter::with_loopback_and_k(IPDiversityConfig::default(), false, 8);

        enforcer.track("10.0.1.1".parse().unwrap()).unwrap();
        enforcer.track("10.0.1.2".parse().unwrap()).unwrap();

        assert!(!enforcer.can_accept("10.0.1.3".parse().unwrap()));
    }

    #[test]
    fn test_rejected_track_leaves_counts_unchanged() {
        let config = IPDiversityConfig {
            max_per_ip: Some(1),
            max_per_subnet: Some(usize::MAX),
        };
        let mut enforcer = BootstrapIpLimiter::new(config);

        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        enforcer.track(ip).unwrap();
        assert!(enforcer.track(ip).is_err());

        // If the failed `track` had incremented the count, one `untrack` would not free it.
        enforcer.untrack(ip);
        assert!(
            enforcer.can_accept(ip),
            "a rejected track must not increment counters"
        );
    }

    #[test]
    fn test_permissive_config_allows_many_peers_per_subnet() {
        let mut enforcer = BootstrapIpLimiter::new(IPDiversityConfig::permissive());

        for i in 1..=10 {
            let ip: IpAddr = format!("10.0.1.{i}").parse().unwrap();
            enforcer.track(ip).unwrap();
        }

        assert!(enforcer.can_accept("10.0.1.11".parse().unwrap()));
    }
}