1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::net::IpAddr;
9use tracing::debug;
10
/// Outcome of evaluating an IP address against access rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessDecision {
    /// The address is permitted.
    Allow,
    /// The address is rejected.
    Deny,
    /// No verdict was reached.
    ///
    /// Nothing visible in this module constructs this variant;
    /// `AccessListManager::check` treats it as "fall back to the global
    /// decision".
    NoMatch,
}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
24#[serde(rename_all = "lowercase")]
25pub enum AccessAction {
26 Allow,
27 #[default]
28 Deny,
29}
30
/// An IPv4 or IPv6 network expressed as base address plus prefix length.
#[derive(Debug, Clone)]
pub struct CidrRange {
    /// Base address exactly as parsed; host bits are not cleared here.
    network: IpAddr,
    /// Number of leading bits of `network` that an address must share.
    prefix_len: u8,
}
39
40impl CidrRange {
41 pub fn parse(cidr: &str) -> Result<Self, AccessError> {
43 use std::str::FromStr;
44 Self::from_str(cidr)
45 }
46
47 pub fn contains(&self, ip: &IpAddr) -> bool {
49 match (&self.network, ip) {
50 (IpAddr::V4(net), IpAddr::V4(addr)) => {
51 let net_bits = u32::from_be_bytes(net.octets());
52 let addr_bits = u32::from_be_bytes(addr.octets());
53 let mask = if self.prefix_len == 0 {
54 0
55 } else {
56 !0u32 << (32 - self.prefix_len)
57 };
58 (addr_bits & mask) == (net_bits & mask)
59 }
60 (IpAddr::V6(net), IpAddr::V6(addr)) => {
61 let net_bits = u128::from_be_bytes(net.octets());
62 let addr_bits = u128::from_be_bytes(addr.octets());
63 let mask = if self.prefix_len == 0 {
64 0
65 } else {
66 !0u128 << (128 - self.prefix_len)
67 };
68 (addr_bits & mask) == (net_bits & mask)
69 }
70 _ => false,
72 }
73 }
74}
75
76impl std::str::FromStr for CidrRange {
77 type Err = AccessError;
78
79 fn from_str(cidr: &str) -> Result<Self, Self::Err> {
80 let (addr_str, prefix_str) = if let Some(idx) = cidr.find('/') {
81 (&cidr[..idx], Some(&cidr[idx + 1..]))
82 } else {
83 (cidr, None)
84 };
85
86 let network: IpAddr = addr_str.parse().map_err(|_| AccessError::InvalidCidr {
87 cidr: cidr.to_string(),
88 reason: "invalid IP address".to_string(),
89 })?;
90
91 let max_prefix = match network {
92 IpAddr::V4(_) => 32,
93 IpAddr::V6(_) => 128,
94 };
95
96 let prefix_len: u8 = match prefix_str {
97 Some(s) => s.parse().map_err(|_| AccessError::InvalidCidr {
98 cidr: cidr.to_string(),
99 reason: "invalid prefix length".to_string(),
100 })?,
101 None => max_prefix,
102 };
103
104 if prefix_len > max_prefix {
105 return Err(AccessError::InvalidCidr {
106 cidr: cidr.to_string(),
107 reason: format!(
108 "prefix length {} exceeds maximum {}",
109 prefix_len, max_prefix
110 ),
111 });
112 }
113
114 Ok(Self {
115 network,
116 prefix_len,
117 })
118 }
119}
120
121#[derive(Debug, Clone)]
123pub struct AccessRule {
124 pub cidr: CidrRange,
126 pub action: AccessAction,
128 pub comment: Option<String>,
130}
131
132impl AccessRule {
133 pub fn allow(cidr: &str) -> Result<Self, AccessError> {
135 Ok(Self {
136 cidr: CidrRange::parse(cidr)?,
137 action: AccessAction::Allow,
138 comment: None,
139 })
140 }
141
142 pub fn deny(cidr: &str) -> Result<Self, AccessError> {
144 Ok(Self {
145 cidr: CidrRange::parse(cidr)?,
146 action: AccessAction::Deny,
147 comment: None,
148 })
149 }
150
151 pub fn with_comment(mut self, comment: &str) -> Self {
153 self.comment = Some(comment.to_string());
154 self
155 }
156
157 pub fn matches(&self, ip: &IpAddr) -> bool {
159 self.cidr.contains(ip)
160 }
161}
162
163#[derive(Debug, Default)]
165pub struct AccessList {
166 rules: Vec<AccessRule>,
168 default_action: AccessAction,
170}
171
172impl AccessList {
173 pub fn new() -> Self {
175 Self {
176 rules: Vec::new(),
177 default_action: AccessAction::Deny,
178 }
179 }
180
181 pub fn allow_all() -> Self {
183 Self {
184 rules: Vec::new(),
185 default_action: AccessAction::Allow,
186 }
187 }
188
189 pub fn deny_all() -> Self {
191 Self {
192 rules: Vec::new(),
193 default_action: AccessAction::Deny,
194 }
195 }
196
197 pub fn add_rule(&mut self, rule: AccessRule) {
199 self.rules.push(rule);
200 }
201
202 pub fn allow(&mut self, cidr: &str) -> Result<(), AccessError> {
204 self.rules.push(AccessRule::allow(cidr)?);
205 Ok(())
206 }
207
208 pub fn deny(&mut self, cidr: &str) -> Result<(), AccessError> {
210 self.rules.push(AccessRule::deny(cidr)?);
211 Ok(())
212 }
213
214 pub fn set_default(&mut self, action: AccessAction) {
216 self.default_action = action;
217 }
218
219 pub fn check(&self, ip: &IpAddr) -> AccessDecision {
221 for rule in &self.rules {
222 if rule.matches(ip) {
223 debug!(
224 "IP {} matched rule {:?} -> {:?}",
225 ip, rule.cidr.network, rule.action
226 );
227 return match rule.action {
228 AccessAction::Allow => AccessDecision::Allow,
229 AccessAction::Deny => AccessDecision::Deny,
230 };
231 }
232 }
233
234 debug!(
235 "IP {} no match, using default {:?}",
236 ip, self.default_action
237 );
238 match self.default_action {
239 AccessAction::Allow => AccessDecision::Allow,
240 AccessAction::Deny => AccessDecision::Deny,
241 }
242 }
243
244 pub fn is_allowed(&self, ip: &IpAddr) -> bool {
246 matches!(self.check(ip), AccessDecision::Allow)
247 }
248
249 pub fn rule_count(&self) -> usize {
251 self.rules.len()
252 }
253}
254
255#[derive(Debug, Default)]
257pub struct AccessListManager {
258 lists: HashMap<String, AccessList>,
260 global: AccessList,
262}
263
264impl AccessListManager {
265 pub fn new() -> Self {
267 Self {
268 lists: HashMap::new(),
269 global: AccessList::allow_all(),
270 }
271 }
272
273 pub fn set_global(&mut self, list: AccessList) {
275 self.global = list;
276 }
277
278 pub fn add_site(&mut self, hostname: &str, list: AccessList) {
280 self.lists.insert(hostname.to_lowercase(), list);
281 }
282
283 pub fn remove_site(&mut self, hostname: &str) {
285 self.lists.remove(&hostname.to_lowercase());
286 }
287
288 pub fn check(&self, hostname: &str, ip: &IpAddr) -> AccessDecision {
296 let global_decision = self.global.check(ip);
298 if matches!(global_decision, AccessDecision::Deny) {
299 return AccessDecision::Deny;
300 }
301
302 let normalized = hostname.to_lowercase();
304 if let Some(site_list) = self.lists.get(&normalized) {
305 let site_decision = site_list.check(ip);
306 if !matches!(site_decision, AccessDecision::NoMatch) {
307 return site_decision;
308 }
309 }
310
311 global_decision
313 }
314
315 pub fn is_allowed(&self, hostname: &str, ip: &IpAddr) -> bool {
317 matches!(self.check(hostname, ip), AccessDecision::Allow)
318 }
319
320 pub fn site_count(&self) -> usize {
322 self.lists.len()
323 }
324
325 pub fn add_deny_ip(&mut self, ip: &IpAddr, comment: Option<&str>) -> Result<(), AccessError> {
336 let cidr = match ip {
337 IpAddr::V4(_) => format!("{}/32", ip),
338 IpAddr::V6(_) => format!("{}/128", ip),
339 };
340
341 let mut rule = AccessRule::deny(&cidr)?;
342 if let Some(c) = comment {
343 rule = rule.with_comment(c);
344 }
345
346 self.global.add_rule(rule);
347 tracing::info!(ip = %ip, comment = ?comment, "Added dynamic deny rule");
348 Ok(())
349 }
350
351 pub fn remove_deny_ip(&mut self, ip: &IpAddr) -> usize {
361 let ip_str = ip.to_string();
362
363 let before_count = self.global.rules.len();
364 self.global.rules.retain(|rule| {
365 let network_str = match rule.cidr.network {
367 IpAddr::V4(v4) => v4.to_string(),
368 IpAddr::V6(v6) => v6.to_string(),
369 };
370 !(network_str == ip_str && matches!(rule.action, AccessAction::Deny))
371 });
372 let removed = before_count - self.global.rules.len();
373
374 if removed > 0 {
375 tracing::info!(ip = %ip, removed = removed, "Removed dynamic deny rules");
376 }
377
378 removed
379 }
380
381 pub fn list_sites(&self) -> Vec<String> {
383 self.lists.keys().cloned().collect()
384 }
385
386 pub fn global_list(&self) -> &AccessList {
388 &self.global
389 }
390
391 pub fn global_list_mut(&mut self) -> &mut AccessList {
393 &mut self.global
394 }
395}
396
397#[derive(Debug, thiserror::Error)]
399pub enum AccessError {
400 #[error("invalid CIDR '{cidr}': {reason}")]
401 InvalidCidr { cidr: String, reason: String },
402}
403
404pub fn parse_ip(s: &str) -> Result<IpAddr, AccessError> {
406 let s = s.trim_start_matches('[').trim_end_matches(']');
408
409 s.parse().map_err(|_| AccessError::InvalidCidr {
410 cidr: s.to_string(),
411 reason: "invalid IP address format".to_string(),
412 })
413}
414
/// Returns the embedded IPv4 address of an IPv4-mapped IPv6 address.
///
/// Yields `Some` only for addresses of the form `::ffff:a.b.c.d`; plain IPv4
/// addresses and every other IPv6 address yield `None`.
pub fn extract_mapped_ipv4(ip: &IpAddr) -> Option<std::net::Ipv4Addr> {
    match ip {
        IpAddr::V6(v6) => v6.to_ipv4_mapped(),
        IpAddr::V4(_) => None,
    }
}
449
/// Returns `true` for addresses in `10.0.0.0/8`, `172.16.0.0/12` or
/// `192.168.0.0/16`.
fn is_private_ipv4(ip: &std::net::Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..]
    )
}
472
/// Returns `true` for IPv4 `127.0.0.0/8` and for the IPv6 loopback `::1`.
fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => matches!(v4.octets(), [127, ..]),
        IpAddr::V6(v6) => v6.is_loopback(),
    }
}
484
/// Returns `true` for IPv4 `169.254.0.0/16` and IPv6 `fe80::/10`.
fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => matches!(v4.octets(), [169, 254, ..]),
        IpAddr::V6(v6) => {
            // Compare only the top ten bits of the first segment.
            let first = v6.segments()[0];
            (first >> 6) == (0xfe80 >> 6)
        }
    }
}
503
/// Returns `true` for the two IPv4 addresses treated as cloud metadata
/// endpoints: `169.254.169.254` and `169.254.170.2`.
///
/// No IPv6 address is recognised.
// NOTE(review): 169.254.170.2 is presumably a container credentials
// endpoint; confirm against the deployment targets.
fn is_cloud_metadata(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => matches!(
            v4.octets(),
            [169, 254, 169, 254] | [169, 254, 170, 2]
        ),
        IpAddr::V6(_) => false,
    }
}
527
528pub fn is_ssrf_target(ip: &IpAddr) -> bool {
559 if let Some(mapped_v4) = extract_mapped_ipv4(ip) {
562 if mapped_v4.octets()[0] == 127 {
564 tracing::warn!(
565 ip = %ip,
566 mapped = %mapped_v4,
567 "SSRF attempt blocked: IPv6-mapped loopback"
568 );
569 return true;
570 }
571 if is_private_ipv4(&mapped_v4) {
572 tracing::warn!(
573 ip = %ip,
574 mapped = %mapped_v4,
575 "SSRF attempt blocked: IPv6-mapped private IP"
576 );
577 return true;
578 }
579 if is_cloud_metadata(&IpAddr::V4(mapped_v4)) {
580 tracing::warn!(
581 ip = %ip,
582 mapped = %mapped_v4,
583 "SSRF attempt blocked: IPv6-mapped cloud metadata"
584 );
585 return true;
586 }
587 if is_link_local(&IpAddr::V4(mapped_v4)) {
588 tracing::warn!(
589 ip = %ip,
590 mapped = %mapped_v4,
591 "SSRF attempt blocked: IPv6-mapped link-local"
592 );
593 return true;
594 }
595 return false;
597 }
598
599 if is_loopback(ip) {
601 tracing::debug!(ip = %ip, "SSRF blocked: loopback address");
602 return true;
603 }
604
605 if is_cloud_metadata(ip) {
606 tracing::warn!(ip = %ip, "SSRF blocked: cloud metadata endpoint");
607 return true;
608 }
609
610 if is_link_local(ip) {
611 tracing::debug!(ip = %ip, "SSRF blocked: link-local address");
612 return true;
613 }
614
615 if let IpAddr::V4(v4) = ip {
617 if is_private_ipv4(v4) {
618 tracing::debug!(ip = %ip, "SSRF blocked: private IPv4");
619 return true;
620 }
621 }
622
623 if let IpAddr::V6(v6) = ip {
625 let segments = v6.segments();
626 if (segments[0] & 0xfe00) == 0xfc00 {
628 tracing::debug!(ip = %ip, "SSRF blocked: IPv6 unique local");
629 return true;
630 }
631 }
632
633 false
634}
635
636#[derive(Debug, Clone, PartialEq, Eq)]
638pub enum SsrfCheckResult {
639 Safe,
641 Loopback,
643 Private,
645 LinkLocal,
647 CloudMetadata,
649 MappedBlocked {
651 mapped_v4: std::net::Ipv4Addr,
652 reason: &'static str,
653 },
654 Ipv6UniqueLocal,
656}
657
658impl SsrfCheckResult {
659 pub fn is_blocked(&self) -> bool {
661 !matches!(self, Self::Safe)
662 }
663}
664
665pub fn check_ssrf(ip: &IpAddr) -> SsrfCheckResult {
670 if let Some(mapped_v4) = extract_mapped_ipv4(ip) {
672 if mapped_v4.octets()[0] == 127 {
673 return SsrfCheckResult::MappedBlocked {
674 mapped_v4,
675 reason: "loopback",
676 };
677 }
678 if is_private_ipv4(&mapped_v4) {
679 return SsrfCheckResult::MappedBlocked {
680 mapped_v4,
681 reason: "private",
682 };
683 }
684 if is_cloud_metadata(&IpAddr::V4(mapped_v4)) {
685 return SsrfCheckResult::MappedBlocked {
686 mapped_v4,
687 reason: "cloud_metadata",
688 };
689 }
690 if is_link_local(&IpAddr::V4(mapped_v4)) {
691 return SsrfCheckResult::MappedBlocked {
692 mapped_v4,
693 reason: "link_local",
694 };
695 }
696 return SsrfCheckResult::Safe;
697 }
698
699 if is_loopback(ip) {
700 return SsrfCheckResult::Loopback;
701 }
702 if is_cloud_metadata(ip) {
703 return SsrfCheckResult::CloudMetadata;
704 }
705 if is_link_local(ip) {
706 return SsrfCheckResult::LinkLocal;
707 }
708 if let IpAddr::V4(v4) = ip {
709 if is_private_ipv4(v4) {
710 return SsrfCheckResult::Private;
711 }
712 }
713 if let IpAddr::V6(v6) = ip {
714 let segments = v6.segments();
715 if (segments[0] & 0xfe00) == 0xfc00 {
716 return SsrfCheckResult::Ipv6UniqueLocal;
717 }
718 }
719
720 SsrfCheckResult::Safe
721}
722
#[cfg(test)]
mod tests {
    use super::*;

    // ---- CIDR parsing and containment ----

    #[test]
    fn test_cidr_parse_ipv4() {
        let cidr = CidrRange::parse("192.168.1.0/24").unwrap();
        assert!(cidr.contains(&"192.168.1.1".parse().unwrap()));
        assert!(cidr.contains(&"192.168.1.254".parse().unwrap()));
        assert!(!cidr.contains(&"192.168.2.1".parse().unwrap()));
    }

    #[test]
    fn test_cidr_parse_ipv4_single() {
        // A bare address is a single-host range.
        let cidr = CidrRange::parse("10.0.0.1").unwrap();
        assert!(cidr.contains(&"10.0.0.1".parse().unwrap()));
        assert!(!cidr.contains(&"10.0.0.2".parse().unwrap()));
    }

    #[test]
    fn test_cidr_parse_ipv6() {
        let cidr = CidrRange::parse("2001:db8::/32").unwrap();
        assert!(cidr.contains(&"2001:db8::1".parse().unwrap()));
        assert!(cidr.contains(&"2001:db8:ffff::1".parse().unwrap()));
        assert!(!cidr.contains(&"2001:db9::1".parse().unwrap()));
    }

    #[test]
    fn test_cidr_invalid() {
        assert!(CidrRange::parse("not-an-ip").is_err());
        assert!(CidrRange::parse("192.168.1.0/33").is_err());
        assert!(CidrRange::parse("192.168.1.0/abc").is_err());
    }

    // ---- Rules and lists ----

    #[test]
    fn test_access_rule_allow() {
        let rule = AccessRule::allow("10.0.0.0/8").unwrap();
        assert!(rule.matches(&"10.1.2.3".parse().unwrap()));
        assert!(!rule.matches(&"192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn test_access_rule_deny() {
        let rule = AccessRule::deny("192.168.0.0/16").unwrap();
        assert_eq!(rule.action, AccessAction::Deny);
        assert!(rule.matches(&"192.168.1.1".parse().unwrap()));
    }

    #[test]
    fn test_access_list_allow_all() {
        let list = AccessList::allow_all();
        assert!(list.is_allowed(&"1.2.3.4".parse().unwrap()));
        assert!(list.is_allowed(&"::1".parse().unwrap()));
    }

    #[test]
    fn test_access_list_deny_all() {
        let list = AccessList::deny_all();
        assert!(!list.is_allowed(&"1.2.3.4".parse().unwrap()));
        assert!(!list.is_allowed(&"::1".parse().unwrap()));
    }

    #[test]
    fn test_access_list_rules() {
        let mut list = AccessList::deny_all();
        // Order matters: the specific deny precedes the wider allow.
        list.deny("10.0.0.1").unwrap();
        list.allow("10.0.0.0/8").unwrap();
        list.allow("192.168.1.0/24").unwrap();

        assert!(!list.is_allowed(&"10.0.0.1".parse().unwrap()));
        assert!(list.is_allowed(&"10.0.0.2".parse().unwrap()));
        assert!(list.is_allowed(&"192.168.1.100".parse().unwrap()));
        // No rule matches, so the deny default applies.
        assert!(!list.is_allowed(&"8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn test_access_list_manager() {
        let mut manager = AccessListManager::new();

        let mut global = AccessList::allow_all();
        global.deny("1.2.3.4").unwrap();
        manager.set_global(global);

        let mut site_list = AccessList::deny_all();
        site_list.allow("10.0.0.0/8").unwrap();
        manager.add_site("internal.example.com", site_list);

        // A global deny applies to every site.
        assert!(!manager.is_allowed("any.com", &"1.2.3.4".parse().unwrap()));

        // The site list decides for its own hostname.
        assert!(manager.is_allowed("internal.example.com", &"10.0.0.1".parse().unwrap()));
        assert!(!manager.is_allowed("internal.example.com", &"8.8.8.8".parse().unwrap()));

        // Sites without a list fall back to the global decision.
        assert!(manager.is_allowed("public.example.com", &"8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn test_manager_case_insensitive() {
        let mut manager = AccessListManager::new();
        manager.add_site("Example.COM", AccessList::deny_all());

        assert!(!manager.is_allowed("example.com", &"1.2.3.4".parse().unwrap()));
        assert!(!manager.is_allowed("EXAMPLE.COM", &"1.2.3.4".parse().unwrap()));
    }

    #[test]
    fn test_rule_with_comment() {
        let rule = AccessRule::deny("0.0.0.0/0")
            .unwrap()
            .with_comment("Block all by default");

        assert_eq!(rule.comment, Some("Block all by default".to_string()));
    }

    #[test]
    fn test_parse_ip_formats() {
        assert!(parse_ip("192.168.1.1").is_ok());
        assert!(parse_ip("::1").is_ok());
        // Bracketed IPv6 literal.
        assert!(parse_ip("[::1]").is_ok());
        assert!(parse_ip("invalid").is_err());
    }

    #[test]
    fn test_cidr_zero_prefix() {
        let cidr = CidrRange::parse("0.0.0.0/0").unwrap();
        assert!(cidr.contains(&"1.2.3.4".parse().unwrap()));
        assert!(cidr.contains(&"255.255.255.255".parse().unwrap()));
    }

    #[test]
    fn test_rule_count() {
        let mut list = AccessList::new();
        assert_eq!(list.rule_count(), 0);

        list.allow("10.0.0.0/8").unwrap();
        list.deny("192.168.0.0/16").unwrap();

        assert_eq!(list.rule_count(), 2);
    }

    // ---- SSRF protection ----

    #[test]
    fn test_extract_mapped_ipv4() {
        let mapped_localhost: IpAddr = "::ffff:127.0.0.1".parse().unwrap();
        let extracted = extract_mapped_ipv4(&mapped_localhost);
        assert!(extracted.is_some());
        assert_eq!(extracted.unwrap().to_string(), "127.0.0.1");

        let mapped_private: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
        let extracted = extract_mapped_ipv4(&mapped_private);
        assert!(extracted.is_some());
        assert_eq!(extracted.unwrap().to_string(), "192.168.1.1");

        // A regular IPv6 address is not a mapped address.
        let regular_v6: IpAddr = "2001:db8::1".parse().unwrap();
        assert!(extract_mapped_ipv4(&regular_v6).is_none());

        // A plain IPv4 address is not a mapped address either.
        let v4: IpAddr = "127.0.0.1".parse().unwrap();
        assert!(extract_mapped_ipv4(&v4).is_none());

        let mapped_metadata: IpAddr = "::ffff:169.254.169.254".parse().unwrap();
        let extracted = extract_mapped_ipv4(&mapped_metadata);
        assert!(extracted.is_some());
        assert_eq!(extracted.unwrap().to_string(), "169.254.169.254");
    }

    #[test]
    fn test_ssrf_loopback() {
        // The whole 127.0.0.0/8 block.
        assert!(is_ssrf_target(&"127.0.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"127.0.0.2".parse().unwrap()));
        assert!(is_ssrf_target(&"127.255.255.255".parse().unwrap()));

        // IPv6 loopback.
        assert!(is_ssrf_target(&"::1".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_private_ipv4() {
        // 10.0.0.0/8
        assert!(is_ssrf_target(&"10.0.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"10.255.255.255".parse().unwrap()));

        // 172.16.0.0/12, including both neighbours just outside the range.
        assert!(is_ssrf_target(&"172.16.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"172.31.255.255".parse().unwrap()));
        assert!(!is_ssrf_target(&"172.15.0.1".parse().unwrap()));
        assert!(!is_ssrf_target(&"172.32.0.1".parse().unwrap()));

        // 192.168.0.0/16
        assert!(is_ssrf_target(&"192.168.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"192.168.255.255".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_cloud_metadata() {
        assert!(is_ssrf_target(&"169.254.169.254".parse().unwrap()));
        assert!(is_ssrf_target(&"169.254.170.2".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_link_local() {
        // IPv4 link-local.
        assert!(is_ssrf_target(&"169.254.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"169.254.255.255".parse().unwrap()));

        // IPv6 link-local.
        assert!(is_ssrf_target(&"fe80::1".parse().unwrap()));
        assert!(is_ssrf_target(&"fe80::abcd:1234".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_ipv6_mapped_bypass_attempts() {
        // Mapped loopback.
        assert!(is_ssrf_target(&"::ffff:127.0.0.1".parse().unwrap()));

        // Mapped private ranges.
        assert!(is_ssrf_target(&"::ffff:10.0.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"::ffff:172.16.0.1".parse().unwrap()));
        assert!(is_ssrf_target(&"::ffff:192.168.1.1".parse().unwrap()));

        // Mapped cloud metadata.
        assert!(is_ssrf_target(&"::ffff:169.254.169.254".parse().unwrap()));

        // Mapped link-local.
        assert!(is_ssrf_target(&"::ffff:169.254.1.1".parse().unwrap()));

        // A mapped public address stays allowed.
        assert!(!is_ssrf_target(&"::ffff:8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_ipv6_unique_local() {
        assert!(is_ssrf_target(&"fc00::1".parse().unwrap()));
        assert!(is_ssrf_target(&"fd00::1".parse().unwrap()));
        assert!(is_ssrf_target(&"fdab:cdef::1234".parse().unwrap()));
    }

    #[test]
    fn test_ssrf_public_ips_allowed() {
        // Public IPv4.
        assert!(!is_ssrf_target(&"8.8.8.8".parse().unwrap()));
        assert!(!is_ssrf_target(&"1.1.1.1".parse().unwrap()));
        assert!(!is_ssrf_target(&"203.0.113.1".parse().unwrap()));

        // Public IPv6.
        assert!(!is_ssrf_target(&"2001:4860:4860::8888".parse().unwrap()));
        assert!(!is_ssrf_target(&"2606:4700::1111".parse().unwrap()));
    }

    #[test]
    fn test_check_ssrf_detailed() {
        assert_eq!(
            check_ssrf(&"127.0.0.1".parse().unwrap()),
            SsrfCheckResult::Loopback
        );

        assert_eq!(
            check_ssrf(&"10.0.0.1".parse().unwrap()),
            SsrfCheckResult::Private
        );

        assert_eq!(
            check_ssrf(&"169.254.169.254".parse().unwrap()),
            SsrfCheckResult::CloudMetadata
        );

        assert_eq!(
            check_ssrf(&"169.254.1.1".parse().unwrap()),
            SsrfCheckResult::LinkLocal
        );

        assert_eq!(
            check_ssrf(&"fc00::1".parse().unwrap()),
            SsrfCheckResult::Ipv6UniqueLocal
        );

        assert_eq!(
            check_ssrf(&"8.8.8.8".parse().unwrap()),
            SsrfCheckResult::Safe
        );

        // Mapped addresses report the embedded address and the reason.
        let result = check_ssrf(&"::ffff:127.0.0.1".parse().unwrap());
        assert!(result.is_blocked());
        if let SsrfCheckResult::MappedBlocked { mapped_v4, reason } = result {
            assert_eq!(mapped_v4.to_string(), "127.0.0.1");
            assert_eq!(reason, "loopback");
        } else {
            panic!("Expected MappedBlocked");
        }
    }

    #[test]
    fn test_ssrf_check_result_is_blocked() {
        assert!(!SsrfCheckResult::Safe.is_blocked());
        assert!(SsrfCheckResult::Loopback.is_blocked());
        assert!(SsrfCheckResult::Private.is_blocked());
        assert!(SsrfCheckResult::LinkLocal.is_blocked());
        assert!(SsrfCheckResult::CloudMetadata.is_blocked());
        assert!(SsrfCheckResult::Ipv6UniqueLocal.is_blocked());
        assert!(SsrfCheckResult::MappedBlocked {
            mapped_v4: "127.0.0.1".parse().unwrap(),
            reason: "loopback"
        }
        .is_blocked());
    }
}