1use crate::error::{OverlayError, Result};
7use ipnet::IpNet;
8use serde::{Deserialize, Serialize};
9use std::collections::HashSet;
10use std::net::{IpAddr, Ipv6Addr};
11use std::path::Path;
12
13#[derive(Debug, Clone)]
18pub struct IpAllocator {
19 network: IpNet,
21 allocated: HashSet<IpAddr>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct IpAllocatorState {
28 pub cidr: String,
30 pub allocated: Vec<IpAddr>,
32}
33
/// Returns `base + offset`, treating the IPv6 address as a 128-bit integer.
///
/// Yields `None` when the sum would overflow past the last IPv6 address.
fn ipv6_add(base: Ipv6Addr, offset: u128) -> Option<Ipv6Addr> {
    let sum = u128::from(base).checked_add(offset)?;
    Some(Ipv6Addr::from(sum))
}
41
/// Number of allocatable host addresses in a network of the given family and
/// prefix length.
///
/// * IPv6: every address except the network address itself counts, i.e.
///   `2^(128 - prefix_len) - 1`. A `/0` yields `u128::MAX` and a `/128`
///   yields `0`.
/// * IPv4: the network and broadcast addresses are excluded, i.e.
///   `2^(32 - prefix_len) - 2`. `/31` and `/32` yield `0`.
///
/// A `prefix_len` larger than the address width is treated as "no host bits"
/// and yields `0` rather than underflowing.
fn host_count(is_ipv6: bool, prefix_len: u8) -> u128 {
    let address_bits: u32 = if is_ipv6 { 128 } else { 32 };
    // Saturate so an out-of-range prefix cannot underflow the subtraction
    // (which would panic in debug builds and produce a bogus shift in release).
    let host_bits = address_bits.saturating_sub(u32::from(prefix_len));

    if is_ipv6 {
        match host_bits {
            0 => 0,
            // `1 << 128` does not fit in a u128; 2^128 - 1 is exactly u128::MAX.
            128 => u128::MAX,
            bits => (1u128 << bits) - 1,
        }
    } else if host_bits <= 1 {
        0
    } else {
        (1u128 << host_bits) - 2
    }
}
72
73impl IpAllocator {
74 pub fn new(cidr: &str) -> Result<Self> {
93 let network: IpNet = cidr
94 .parse()
95 .map_err(|e| OverlayError::InvalidCidr(format!("{cidr}: {e}")))?;
96
97 Ok(Self {
98 network,
99 allocated: HashSet::new(),
100 })
101 }
102
103 pub fn from_state(state: IpAllocatorState) -> Result<Self> {
109 let mut allocator = Self::new(&state.cidr)?;
110 for ip in state.allocated {
111 allocator.mark_allocated(ip)?;
112 }
113 Ok(allocator)
114 }
115
116 #[must_use]
118 pub fn to_state(&self) -> IpAllocatorState {
119 IpAllocatorState {
120 cidr: self.network.to_string(),
121 allocated: self.allocated.iter().copied().collect(),
122 }
123 }
124
125 pub async fn load(path: &Path) -> Result<Self> {
131 let contents = tokio::fs::read_to_string(path).await?;
132 let state: IpAllocatorState = serde_json::from_str(&contents)?;
133 Self::from_state(state)
134 }
135
136 pub async fn save(&self, path: &Path) -> Result<()> {
142 let state = self.to_state();
143 let contents = serde_json::to_string_pretty(&state)?;
144 tokio::fs::write(path, contents).await?;
145 Ok(())
146 }
147
148 pub fn allocate(&mut self) -> Option<IpAddr> {
164 match self.network {
165 IpNet::V4(v4net) => {
166 for ip in v4net.hosts() {
168 let addr = IpAddr::V4(ip);
169 if !self.allocated.contains(&addr) {
170 self.allocated.insert(addr);
171 return Some(addr);
172 }
173 }
174 None
175 }
176 IpNet::V6(v6net) => {
177 let base = v6net.network();
180 let total = host_count(true, v6net.prefix_len());
181
182 for offset in 1..=total {
183 if let Some(candidate) = ipv6_add(base, offset) {
184 let addr = IpAddr::V6(candidate);
185 if !self.allocated.contains(&addr) {
186 self.allocated.insert(addr);
187 return Some(addr);
188 }
189 } else {
190 break;
191 }
192 }
193 None
194 }
195 }
196 }
197
198 pub fn allocate_specific(&mut self, ip: IpAddr) -> Result<()> {
204 if !self.network.contains(&ip) {
205 return Err(OverlayError::IpNotInRange(ip, self.network.to_string()));
206 }
207
208 if self.allocated.contains(&ip) {
209 return Err(OverlayError::IpAlreadyAllocated(ip));
210 }
211
212 self.allocated.insert(ip);
213 Ok(())
214 }
215
216 pub fn allocate_first(&mut self) -> Result<IpAddr> {
231 let first_ip = self.first_host().ok_or(OverlayError::NoAvailableIps)?;
232
233 if self.allocated.contains(&first_ip) {
234 return Err(OverlayError::IpAlreadyAllocated(first_ip));
235 }
236
237 self.allocated.insert(first_ip);
238 Ok(first_ip)
239 }
240
241 fn first_host(&self) -> Option<IpAddr> {
246 match self.network {
247 IpNet::V4(v4net) => v4net.hosts().next().map(IpAddr::V4),
248 IpNet::V6(v6net) => {
249 let base = v6net.network();
250 ipv6_add(base, 1).map(IpAddr::V6)
251 }
252 }
253 }
254
255 pub fn mark_allocated(&mut self, ip: IpAddr) -> Result<()> {
261 if !self.network.contains(&ip) {
262 return Err(OverlayError::IpNotInRange(ip, self.network.to_string()));
263 }
264 self.allocated.insert(ip);
265 Ok(())
266 }
267
268 pub fn release(&mut self, ip: IpAddr) -> bool {
272 self.allocated.remove(&ip)
273 }
274
275 #[must_use]
277 pub fn is_allocated(&self, ip: IpAddr) -> bool {
278 self.allocated.contains(&ip)
279 }
280
281 #[must_use]
283 pub fn contains(&self, ip: IpAddr) -> bool {
284 self.network.contains(&ip)
285 }
286
287 #[must_use]
289 pub fn allocated_count(&self) -> usize {
290 self.allocated.len()
291 }
292
293 #[must_use]
297 #[allow(clippy::cast_possible_truncation)]
298 pub fn total_hosts(&self) -> u32 {
299 let is_v6 = matches!(self.network, IpNet::V6(_));
300 let count = host_count(is_v6, self.network.prefix_len());
301 if count > u128::from(u32::MAX) {
303 u32::MAX
304 } else {
305 count as u32
306 }
307 }
308
309 #[must_use]
311 #[allow(clippy::cast_possible_truncation)]
312 pub fn available_count(&self) -> u32 {
313 self.total_hosts()
314 .saturating_sub(self.allocated.len() as u32)
315 }
316
317 #[must_use]
319 pub fn cidr(&self) -> String {
320 self.network.to_string()
321 }
322
323 #[must_use]
325 pub fn network_addr(&self) -> IpAddr {
326 self.network.network()
327 }
328
329 #[must_use]
333 pub fn broadcast_addr(&self) -> IpAddr {
334 self.network.broadcast()
335 }
336
337 #[must_use]
339 pub fn prefix_len(&self) -> u8 {
340 self.network.prefix_len()
341 }
342
343 #[must_use]
345 pub fn host_prefix_len(&self) -> u8 {
346 self.network.max_prefix_len()
347 }
348
349 #[must_use]
351 pub fn allocated_ips(&self) -> Vec<IpAddr> {
352 self.allocated.iter().copied().collect()
353 }
354}
355
356pub fn first_ip_from_cidr(cidr: &str) -> Result<IpAddr> {
364 let network: IpNet = cidr
365 .parse()
366 .map_err(|e| OverlayError::InvalidCidr(format!("{cidr}: {e}")))?;
367
368 match network {
369 IpNet::V4(v4net) => v4net
370 .hosts()
371 .next()
372 .map(IpAddr::V4)
373 .ok_or(OverlayError::NoAvailableIps),
374 IpNet::V6(v6net) => {
375 let base = v6net.network();
376 ipv6_add(base, 1)
377 .map(IpAddr::V6)
378 .ok_or(OverlayError::NoAvailableIps)
379 }
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Test-only IPv4 counterpart of `ipv6_add`: `base + offset`, or `None`
    /// on overflow.
    fn ipv4_add(base: Ipv4Addr, offset: u32) -> Option<Ipv4Addr> {
        let sum = u32::from(base).checked_add(offset)?;
        Some(Ipv4Addr::from(sum))
    }

    /// Builds an allocator for a CIDR that is known to be valid.
    fn pool_for(cidr: &str) -> IpAllocator {
        IpAllocator::new(cidr).unwrap()
    }

    /// Parses an address literal that is known to be valid.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    // ---- IPv4 ----------------------------------------------------------

    #[test]
    fn test_allocator_new() {
        let pool = pool_for("10.200.0.0/24");
        assert_eq!(pool.cidr(), "10.200.0.0/24");
        assert_eq!(pool.allocated_count(), 0);
    }

    #[test]
    fn test_allocator_invalid_cidr() {
        assert!(IpAllocator::new("invalid").is_err());
    }

    #[test]
    fn test_allocate_sequential() {
        let mut pool = pool_for("10.200.0.0/30");

        let first = pool.allocate().unwrap();
        let second = pool.allocate().unwrap();
        assert_eq!(first, addr("10.200.0.1"));
        assert_eq!(second, addr("10.200.0.2"));

        // A /30 only has two usable hosts.
        assert_eq!(pool.allocate(), None);
    }

    #[test]
    fn test_allocate_first() {
        let mut pool = pool_for("10.200.0.0/24");

        assert_eq!(pool.allocate_first().unwrap(), addr("10.200.0.1"));

        // The first host is taken now, so a second attempt must fail.
        assert!(pool.allocate_first().is_err());
    }

    #[test]
    fn test_allocate_specific() {
        let mut pool = pool_for("10.200.0.0/24");
        let wanted = addr("10.200.0.50");

        pool.allocate_specific(wanted).unwrap();
        assert!(pool.is_allocated(wanted));

        // Requesting the same address twice is an error.
        assert!(pool.allocate_specific(wanted).is_err());
    }

    #[test]
    fn test_allocate_specific_out_of_range() {
        let mut pool = pool_for("10.200.0.0/24");
        assert!(pool.allocate_specific(addr("192.168.1.1")).is_err());
    }

    #[test]
    fn test_release() {
        let mut pool = pool_for("10.200.0.0/24");

        let leased = pool.allocate().unwrap();
        assert!(pool.is_allocated(leased));

        assert!(pool.release(leased));
        assert!(!pool.is_allocated(leased));

        // The released address is the lowest free one again.
        assert_eq!(pool.allocate().unwrap(), leased);
    }

    #[test]
    fn test_mark_allocated() {
        let mut pool = pool_for("10.200.0.0/24");
        let marked = addr("10.200.0.100");

        pool.mark_allocated(marked).unwrap();

        assert!(pool.is_allocated(marked));
    }

    #[test]
    fn test_contains() {
        let pool = pool_for("10.200.0.0/24");

        assert!(pool.contains(addr("10.200.0.50")));
        assert!(!pool.contains(addr("10.201.0.50")));
    }

    #[test]
    fn test_total_hosts() {
        assert_eq!(pool_for("10.200.0.0/24").total_hosts(), 254);
        assert_eq!(pool_for("10.200.0.0/30").total_hosts(), 2);
    }

    #[test]
    fn test_available_count() {
        let mut pool = pool_for("10.200.0.0/30");
        assert_eq!(pool.available_count(), 2);

        for remaining in [1, 0] {
            pool.allocate();
            assert_eq!(pool.available_count(), remaining);
        }
    }

    #[test]
    fn test_state_roundtrip() {
        let mut pool = pool_for("10.200.0.0/24");
        pool.allocate();
        pool.allocate();

        let restored = IpAllocator::from_state(pool.to_state()).unwrap();

        assert_eq!(pool.cidr(), restored.cidr());
        assert_eq!(pool.allocated_count(), restored.allocated_count());
    }

    #[test]
    fn test_first_ip_from_cidr() {
        assert_eq!(
            first_ip_from_cidr("10.200.0.0/24").unwrap(),
            addr("10.200.0.1")
        );
    }

    #[test]
    fn test_network_addr_v4() {
        let pool = pool_for("10.200.0.0/24");
        assert_eq!(pool.network_addr(), addr("10.200.0.0"));
    }

    #[test]
    fn test_broadcast_addr_v4() {
        let pool = pool_for("10.200.0.0/24");
        assert_eq!(pool.broadcast_addr(), addr("10.200.0.255"));
    }

    #[test]
    fn test_host_prefix_len_v4() {
        assert_eq!(pool_for("10.200.0.0/24").host_prefix_len(), 32);
    }

    // ---- IPv6 ----------------------------------------------------------

    #[test]
    fn test_allocator_new_v6() {
        let pool = pool_for("fd00::/48");
        assert_eq!(pool.cidr(), "fd00::/48");
        assert_eq!(pool.allocated_count(), 0);
    }

    #[test]
    fn test_allocate_sequential_v6() {
        let mut pool = pool_for("fd00::/126");

        let first = pool.allocate().unwrap();
        let second = pool.allocate().unwrap();
        let third = pool.allocate().unwrap();
        assert_eq!(first, addr("fd00::1"));
        assert_eq!(second, addr("fd00::2"));
        assert_eq!(third, addr("fd00::3"));

        // A /126 has three usable addresses (no broadcast in IPv6).
        assert_eq!(pool.allocate(), None);
    }

    #[test]
    fn test_allocate_first_v6() {
        let mut pool = pool_for("fd00::/48");

        assert_eq!(pool.allocate_first().unwrap(), addr("fd00::1"));

        // The first host is taken now, so a second attempt must fail.
        assert!(pool.allocate_first().is_err());
    }

    #[test]
    fn test_allocate_specific_v6() {
        let mut pool = pool_for("fd00::/48");
        let wanted = addr("fd00::beef");

        pool.allocate_specific(wanted).unwrap();
        assert!(pool.is_allocated(wanted));

        // Requesting the same address twice is an error.
        assert!(pool.allocate_specific(wanted).is_err());
    }

    #[test]
    fn test_allocate_specific_out_of_range_v6() {
        let mut pool = pool_for("fd00::/48");
        assert!(pool.allocate_specific(addr("fe80::1")).is_err());
    }

    #[test]
    fn test_release_v6() {
        let mut pool = pool_for("fd00::/48");

        let leased = pool.allocate().unwrap();
        assert!(pool.is_allocated(leased));

        assert!(pool.release(leased));
        assert!(!pool.is_allocated(leased));

        // The released address is the lowest free one again.
        assert_eq!(pool.allocate().unwrap(), leased);
    }

    #[test]
    fn test_mark_allocated_v6() {
        let mut pool = pool_for("fd00::/48");
        let marked = addr("fd00::ff");

        pool.mark_allocated(marked).unwrap();

        assert!(pool.is_allocated(marked));
    }

    #[test]
    fn test_contains_v6() {
        let pool = pool_for("fd00::/48");

        assert!(pool.contains(addr("fd00::50")));
        assert!(!pool.contains(addr("fe80::1")));
    }

    #[test]
    fn test_total_hosts_v6_small() {
        assert_eq!(pool_for("fd00::/126").total_hosts(), 3);
        assert_eq!(pool_for("fd00::/127").total_hosts(), 1);
    }

    #[test]
    fn test_total_hosts_v6_large() {
        // Far more hosts than fit in a u32, so the count is clamped.
        assert_eq!(pool_for("fd00::/48").total_hosts(), u32::MAX);
    }

    #[test]
    fn test_available_count_v6() {
        let mut pool = pool_for("fd00::/126");
        assert_eq!(pool.available_count(), 3);

        for remaining in [2, 1, 0] {
            pool.allocate();
            assert_eq!(pool.available_count(), remaining);
        }
    }

    #[test]
    fn test_state_roundtrip_v6() {
        let mut pool = pool_for("fd00::/48");
        pool.allocate();
        pool.allocate();

        let state = pool.to_state();

        // IPv6 addresses must survive JSON serialization in textual form.
        let json = serde_json::to_string_pretty(&state).unwrap();
        for expected in ["fd00::1", "fd00::2"] {
            assert!(json.contains(expected));
        }

        let restored = IpAllocator::from_state(state).unwrap();

        assert_eq!(pool.cidr(), restored.cidr());
        assert_eq!(pool.allocated_count(), restored.allocated_count());
    }

    #[test]
    fn test_first_ip_from_cidr_v6() {
        assert_eq!(first_ip_from_cidr("fd00::/48").unwrap(), addr("fd00::1"));
    }

    #[test]
    fn test_network_addr_v6() {
        let pool = pool_for("fd00::/48");
        assert_eq!(pool.network_addr(), addr("fd00::"));
    }

    #[test]
    fn test_broadcast_addr_v6() {
        let pool = pool_for("fd00::/126");
        assert_eq!(pool.broadcast_addr(), addr("fd00::3"));
    }

    #[test]
    fn test_host_prefix_len_v6() {
        assert_eq!(pool_for("fd00::/48").host_prefix_len(), 128);
    }

    // ---- Mixed address families ----------------------------------------

    #[test]
    fn test_v4_and_v6_allocators_independent() {
        let mut v4 = pool_for("10.200.0.0/30");
        let mut v6 = pool_for("fd00::/126");

        let v4_ip = v4.allocate().unwrap();
        let v6_ip = v6.allocate().unwrap();

        assert!(v4_ip.is_ipv4());
        assert!(v6_ip.is_ipv6());
        assert_eq!(v4_ip, addr("10.200.0.1"));
        assert_eq!(v6_ip, addr("fd00::1"));
    }

    #[test]
    fn test_ipv6_does_not_contain_ipv4() {
        assert!(!pool_for("fd00::/48").contains(addr("10.200.0.1")));
    }

    #[test]
    fn test_ipv4_does_not_contain_ipv6() {
        assert!(!pool_for("10.200.0.0/24").contains(addr("fd00::1")));
    }

    #[test]
    fn test_allocate_specific_wrong_family() {
        let mut v4_pool = pool_for("10.200.0.0/24");
        assert!(v4_pool.allocate_specific(addr("fd00::1")).is_err());

        let mut v6_pool = pool_for("fd00::/48");
        assert!(v6_pool.allocate_specific(addr("10.200.0.1")).is_err());
    }

    // ---- Helpers -------------------------------------------------------

    #[test]
    fn test_ipv4_add() {
        let base = Ipv4Addr::new(10, 0, 0, 0);
        assert_eq!(ipv4_add(base, 1), Some(Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(ipv4_add(base, 256), Some(Ipv4Addr::new(10, 0, 1, 0)));
    }

    #[test]
    fn test_ipv4_add_overflow() {
        let last = Ipv4Addr::new(255, 255, 255, 255);
        assert_eq!(ipv4_add(last, 1), None);
    }

    #[test]
    fn test_ipv6_add() {
        let base: Ipv6Addr = "fd00::".parse().unwrap();
        let plus_one: Ipv6Addr = "fd00::1".parse().unwrap();
        let plus_ffff: Ipv6Addr = "fd00::ffff".parse().unwrap();

        assert_eq!(ipv6_add(base, 1), Some(plus_one));
        assert_eq!(ipv6_add(base, 0xffff), Some(plus_ffff));
    }

    #[test]
    fn test_ipv6_add_overflow() {
        let last: Ipv6Addr = "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff".parse().unwrap();
        assert_eq!(ipv6_add(last, 1), None);
    }

    #[test]
    fn test_host_count_v4() {
        let cases: [(u8, u128); 5] = [(24, 254), (30, 2), (16, 65534), (31, 0), (32, 0)];
        for (prefix_len, expected) in cases {
            assert_eq!(host_count(false, prefix_len), expected);
        }
    }

    #[test]
    fn test_host_count_v6() {
        let cases: [(u8, u128); 4] = [(126, 3), (127, 1), (128, 0), (64, (1u128 << 64) - 1)];
        for (prefix_len, expected) in cases {
            assert_eq!(host_count(true, prefix_len), expected);
        }
    }
}