1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
12
13use thiserror::Error;
14use url::Url;
15
16use crate::metrics::SecurityMetrics;
17
/// Environment variable holding a comma-separated list of hosts that
/// [`SsrfPolicy::from_env`] loads as the allowlist.
///
/// Entries are trimmed and empty entries are dropped. Allowlisted hosts skip
/// the IP-class checks in [`SsrfPolicy::resolve_and_check`]; the denylist is
/// still consulted before the allowlist.
pub const ENV_SSRF_ALLOWLIST: &str = "SSRF_ALLOWLIST";

/// Environment variable holding a comma-separated list of hosts that
/// [`SsrfPolicy::from_env`] loads as the denylist.
///
/// Entries are trimmed and empty entries are dropped. Denylisted hosts are
/// always rejected by [`SsrfPolicy::resolve_and_check`].
pub const ENV_SSRF_DENYLIST: &str = "SSRF_DENYLIST";

/// Boolean environment variable: when set to `1`, `true`, `yes` or `on`
/// (case-insensitive, surrounding whitespace ignored),
/// [`SsrfPolicy::from_env`] permits [`IpClass::Private`] destinations.
/// Unset or any other value leaves them blocked.
pub const ENV_SSRF_ALLOW_PRIVATE: &str = "SSRF_ALLOW_PRIVATE";

/// Boolean environment variable (same accepted values as
/// [`ENV_SSRF_ALLOW_PRIVATE`]) that permits [`IpClass::Loopback`]
/// destinations in [`SsrfPolicy::from_env`].
pub const ENV_SSRF_ALLOW_LOOPBACK: &str = "SSRF_ALLOW_LOOPBACK";

/// Boolean environment variable (same accepted values as
/// [`ENV_SSRF_ALLOW_PRIVATE`]) that permits [`IpClass::LinkLocal`]
/// destinations in [`SsrfPolicy::from_env`].
pub const ENV_SSRF_ALLOW_LINK_LOCAL: &str = "SSRF_ALLOW_LINK_LOCAL";
40
/// Coarse classification of a destination IP address, used to decide whether
/// an outbound request may reach it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpClass {
    /// An address not covered by any of the other classes; not restricted by
    /// any class flag.
    Public,
    /// Private-use space: the IPv4 RFC 1918 ranges and IPv6 unique-local
    /// (`fc00::/7`) / deprecated site-local (`fec0::/10`) addresses.
    Private,
    /// Loopback: `127.0.0.0/8` and `::1`.
    Loopback,
    /// Link-local: `169.254.0.0/16` and `fe80::/10`, except the cloud
    /// metadata address, which is classified as [`IpClass::Reserved`].
    LinkLocal,
    /// Multicast destinations; never permitted by any policy flag.
    Multicast,
    /// Special-purpose or non-routable space (unspecified, broadcast,
    /// documentation, benchmarking, shared address space, cloud metadata
    /// endpoints, ...); never permitted by any policy flag.
    Reserved,
}
64
/// Reasons an outbound destination can be rejected.
#[derive(Debug, Error)]
pub enum SsrfError {
    /// The URL has no host component. `is_safe_url` also reports input that
    /// fails to parse as a URL with this variant; the payload is the input.
    #[error("URL has no host component: {0}")]
    MissingHost(String),

    /// Resolving `host` failed; `source` is the underlying I/O error.
    #[error("DNS resolution failed for host '{host}': {source}")]
    DnsFailure {
        host: String,
        #[source]
        source: std::io::Error,
    },

    /// Resolution succeeded but yielded no addresses.
    #[error("DNS resolution returned no addresses for host '{host}'")]
    NoAddresses { host: String },

    /// The host matched the policy denylist; `ip` is the first resolved
    /// address.
    #[error("host '{host}' (resolved to {ip}) is denylisted")]
    Denylisted { host: String, ip: IpAddr },

    /// The destination address `ip` falls in an [`IpClass`] that is not
    /// permitted.
    #[error("host '{host}' (resolved to {ip}) blocked: {class:?}")]
    BlockedClass {
        host: String,
        ip: IpAddr,
        class: IpClass,
    },
}
97
/// Configurable SSRF guard. It decides which destinations an outbound request
/// may target, based on host allow/deny lists and the [`IpClass`] of the
/// resolved address.
///
/// [`SsrfPolicy::new`] blocks everything that is not [`IpClass::Public`].
/// [`SsrfPolicy::from_env`] reads the `SSRF_*` environment variables. Either
/// can then be adjusted with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct SsrfPolicy {
    // Permit `IpClass::Private` destinations.
    allow_private: bool,
    // Permit `IpClass::Loopback` destinations.
    allow_loopback: bool,
    // Permit `IpClass::LinkLocal` destinations.
    allow_link_local: bool,
    // Hosts that bypass the IP-class checks (the denylist is checked first).
    allowlist: Vec<String>,
    // Hosts that are always rejected.
    denylist: Vec<String>,
    // Optional sink notified of each blocked request.
    metrics: Option<SecurityMetrics>,
}
111
112impl SsrfPolicy {
113 pub fn new() -> Self {
119 Self {
120 allow_private: false,
121 allow_loopback: false,
122 allow_link_local: false,
123 allowlist: Vec::new(),
124 denylist: Vec::new(),
125 metrics: None,
126 }
127 }
128
129 pub fn from_env() -> Self {
138 Self {
139 allow_private: env_bool(ENV_SSRF_ALLOW_PRIVATE),
140 allow_loopback: env_bool(ENV_SSRF_ALLOW_LOOPBACK),
141 allow_link_local: env_bool(ENV_SSRF_ALLOW_LINK_LOCAL),
142 allowlist: env_csv(ENV_SSRF_ALLOWLIST),
143 denylist: env_csv(ENV_SSRF_DENYLIST),
144 metrics: None,
145 }
146 }
147
148 pub fn with_metrics(mut self, metrics: SecurityMetrics) -> Self {
151 self.metrics = Some(metrics);
152 self
153 }
154
155 pub fn with_allowlist(mut self, hosts: Vec<String>) -> Self {
158 self.allowlist = hosts;
159 self
160 }
161
162 pub fn with_denylist(mut self, hosts: Vec<String>) -> Self {
164 self.denylist = hosts;
165 self
166 }
167
168 pub fn with_allow_private(mut self, allow: bool) -> Self {
170 self.allow_private = allow;
171 self
172 }
173
174 pub fn with_allow_loopback(mut self, allow: bool) -> Self {
176 self.allow_loopback = allow;
177 self
178 }
179
180 pub fn with_allow_link_local(mut self, allow: bool) -> Self {
182 self.allow_link_local = allow;
183 self
184 }
185
186 pub fn classify(ip: IpAddr) -> IpClass {
188 match ip {
189 IpAddr::V4(v4) => classify_v4(v4),
190 IpAddr::V6(v6) => classify_v6(v6),
191 }
192 }
193
194 pub async fn resolve_and_check(&self, url: &Url) -> Result<IpAddr, SsrfError> {
207 let host = url
208 .host_str()
209 .ok_or_else(|| SsrfError::MissingHost(url.to_string()))?;
210 let host_lower = host.to_ascii_lowercase();
211
212 let port = url.port_or_known_default().unwrap_or(80);
215 let lookup_target = format!("{host}:{port}");
216 let mut addrs = tokio::net::lookup_host(&lookup_target)
217 .await
218 .map_err(|e| SsrfError::DnsFailure {
219 host: host.to_string(),
220 source: e,
221 })?;
222 let sock_addr = addrs.next().ok_or_else(|| SsrfError::NoAddresses {
223 host: host.to_string(),
224 })?;
225 let ip = sock_addr.ip();
226
227 if list_contains_host(&self.denylist, &host_lower) {
229 self.record_block(IpClass::Reserved);
230 return Err(SsrfError::Denylisted {
231 host: host.to_string(),
232 ip,
233 });
234 }
235
236 if list_contains_host(&self.allowlist, &host_lower) {
238 return Ok(ip);
239 }
240
241 let class = Self::classify(ip);
242 let permitted = match class {
243 IpClass::Public => true,
244 IpClass::Private => self.allow_private,
245 IpClass::Loopback => self.allow_loopback,
246 IpClass::LinkLocal => self.allow_link_local,
247 IpClass::Multicast | IpClass::Reserved => false,
251 };
252
253 if !permitted {
254 self.record_block(class);
255 return Err(SsrfError::BlockedClass {
256 host: host.to_string(),
257 ip,
258 class,
259 });
260 }
261
262 Ok(ip)
263 }
264
265 fn record_block(&self, class: IpClass) {
266 if let Some(m) = &self.metrics {
267 m.record_ssrf_block(class);
268 }
269 }
270}
271
272impl Default for SsrfPolicy {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
278fn classify_v4(v4: Ipv4Addr) -> IpClass {
281 let o = v4.octets();
282
283 if o == [169, 254, 169, 254] {
287 return IpClass::Reserved;
288 }
289
290 if v4.is_unspecified() || v4.is_broadcast() || v4.is_documentation() {
291 return IpClass::Reserved;
292 }
293 if v4.is_loopback() {
294 return IpClass::Loopback;
295 }
296 if v4.is_link_local() {
297 return IpClass::LinkLocal;
298 }
299 if v4.is_multicast() {
300 return IpClass::Multicast;
301 }
302 if v4.is_private() {
303 return IpClass::Private;
304 }
305
306 match o[0] {
317 0 => return IpClass::Reserved,
318 100 if (o[1] & 0xC0) == 0x40 => return IpClass::Reserved, 192 if o[1] == 0 && o[2] == 0 => return IpClass::Reserved,
320 192 if o[1] == 88 && o[2] == 99 => return IpClass::Reserved,
321 198 if o[1] == 18 || o[1] == 19 => return IpClass::Reserved,
322 240..=255 => return IpClass::Reserved,
323 _ => {}
324 }
325
326 IpClass::Public
327}
328
329fn classify_v6(v6: Ipv6Addr) -> IpClass {
330 let segs = v6.segments();
332 if segs == [0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254] {
333 return IpClass::Reserved;
334 }
335
336 if v6.is_unspecified() {
337 return IpClass::Reserved;
338 }
339 if v6.is_loopback() {
340 return IpClass::Loopback;
341 }
342 if v6.is_multicast() {
343 return IpClass::Multicast;
344 }
345
346 if let Some(v4) = v6.to_ipv4_mapped() {
349 return classify_v4(v4);
350 }
351
352 let first = segs[0];
353
354 if (first & 0xFFC0) == 0xFE80 {
356 return IpClass::LinkLocal;
357 }
358
359 if (first & 0xFE00) == 0xFC00 {
361 return IpClass::Private;
362 }
363
364 if (first & 0xFFC0) == 0xFEC0 {
366 return IpClass::Private;
367 }
368
369 if first == 0x0100 && segs[1] == 0 && segs[2] == 0 && segs[3] == 0 {
376 return IpClass::Reserved;
377 }
378 if first == 0x2001 && segs[1] == 0x0db8 {
379 return IpClass::Reserved;
380 }
381
382 IpClass::Public
383}
384
/// Returns `true` if `host_lower` matches any entry of `list`.
///
/// Entries are trimmed and compared case-insensitively, and an optional
/// `:port` suffix is ignored (`example.com:8443` matches `example.com`).
/// IPv6 literals may be written bare (`::1`), bracketed (`[::1]`), or
/// bracketed with a port (`[::1]:8080`). Brackets are ignored on both sides,
/// so entries match the bracketed form `Url::host_str` yields for IPv6 hosts.
/// When both sides parse as IP addresses they are compared as addresses, so
/// differently spelled forms of the same IP match. Empty entries never match.
fn list_contains_host(list: &[String], host_lower: &str) -> bool {
    // Extracts the host part of an entry, dropping an optional port.
    fn entry_host(entry: &str) -> &str {
        if let Some(rest) = entry.strip_prefix('[') {
            // "[v6]" or "[v6]:port".
            return rest.split(']').next().unwrap_or(rest);
        }
        match entry.matches(':').count() {
            // "name:port" or "v4:port".
            1 => entry.split(':').next().unwrap_or(entry),
            // A plain name / IPv4, or a bare IPv6 literal (two or more colons).
            _ => entry,
        }
    }

    // Compares two hosts; IP literals are compared by value.
    fn same_host(a: &str, b: &str) -> bool {
        match (a.parse::<IpAddr>(), b.parse::<IpAddr>()) {
            (Ok(x), Ok(y)) => x == y,
            _ => a.eq_ignore_ascii_case(b),
        }
    }

    let target = host_lower
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host_lower);

    list.iter().any(|entry| {
        let candidate = entry_host(entry.trim());
        !candidate.is_empty() && same_host(candidate, target)
    })
}
395
/// Reads `key` as a boolean flag.
///
/// Returns `true` only for `1`, `true`, `yes` or `on` (case-insensitive,
/// surrounding whitespace ignored). Returns `false` if the variable is unset,
/// not valid Unicode, or has any other value.
fn env_bool(key: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    match std::env::var(key) {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            TRUTHY.contains(&normalized.as_str())
        }
        Err(_) => false,
    }
}
405
/// Reads `key` as a comma-separated list.
///
/// Each item is trimmed and empty items are dropped. An unset (or
/// non-Unicode) variable yields an empty list.
fn env_csv(key: &str) -> Vec<String> {
    match std::env::var(key) {
        Ok(raw) => raw
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .map(String::from)
            .collect(),
        Err(_) => Vec::new(),
    }
}
417
418pub fn is_safe_url(url: &str) -> Result<(), SsrfError> {
445 let parsed = Url::parse(url).map_err(|_| SsrfError::MissingHost(url.to_string()))?;
446 let host = parsed
447 .host()
448 .ok_or_else(|| SsrfError::MissingHost(url.to_string()))?;
449
450 match host {
451 url::Host::Ipv4(v4) => check_ip_safe(&v4.to_string(), IpAddr::V4(v4)),
452 url::Host::Ipv6(v6) => check_ip_safe(&v6.to_string(), IpAddr::V6(v6)),
453 url::Host::Domain(d) => {
454 if is_known_metadata_hostname(d) {
455 return Err(SsrfError::BlockedClass {
456 host: d.to_string(),
457 ip: IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
458 class: IpClass::Reserved,
459 });
460 }
461 Ok(())
462 }
463 }
464}
465
466pub async fn resolve_and_check(host: &str) -> Result<IpAddr, SsrfError> {
478 if is_known_metadata_hostname(host) {
482 return Err(SsrfError::BlockedClass {
483 host: host.to_string(),
484 ip: IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)),
485 class: IpClass::Reserved,
486 });
487 }
488
489 let lookup_target = if host.contains(':') {
490 host.to_string()
491 } else {
492 format!("{host}:80")
493 };
494 let addrs = tokio::net::lookup_host(&lookup_target)
495 .await
496 .map_err(|e| SsrfError::DnsFailure {
497 host: host.to_string(),
498 source: e,
499 })?;
500
501 let mut first: Option<IpAddr> = None;
502 for sock in addrs {
503 let ip = sock.ip();
504 check_ip_safe(host, ip)?;
505 if first.is_none() {
506 first = Some(ip);
507 }
508 }
509 first.ok_or_else(|| SsrfError::NoAddresses {
510 host: host.to_string(),
511 })
512}
513
514fn check_ip_safe(host: &str, ip: IpAddr) -> Result<(), SsrfError> {
515 let class = SsrfPolicy::classify(ip);
516 match class {
517 IpClass::Public => Ok(()),
518 IpClass::Private
519 | IpClass::Loopback
520 | IpClass::LinkLocal
521 | IpClass::Multicast
522 | IpClass::Reserved => Err(SsrfError::BlockedClass {
523 host: host.to_string(),
524 ip,
525 class,
526 }),
527 }
528}
529
/// Returns `true` if `host` names a well-known cloud metadata service.
///
/// The comparison is case-insensitive. Anything after the first `:` (a port)
/// is ignored, as are trailing dots: `metadata.google.internal.` is the
/// fully-qualified spelling of the same name and must not bypass the check.
fn is_known_metadata_hostname(host: &str) -> bool {
    let host_only = host.split(':').next().unwrap_or(host);
    let name = host_only.trim_end_matches('.').to_ascii_lowercase();
    matches!(
        name.as_str(),
        "metadata.google.internal" | "metadata" | "metadata.goog"
    )
}
542
// Unit tests for IP classification, the synchronous URL check, and the async
// resolution paths (free function and policy method).
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // --- SsrfPolicy::classify ---

    #[test]
    fn classify_rfc1918_private() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))),
            IpClass::Private
        );
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))),
            IpClass::Private
        );
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))),
            IpClass::Private
        );
    }

    #[test]
    fn classify_loopback() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
            IpClass::Loopback
        );
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V6(Ipv6Addr::LOCALHOST)),
            IpClass::Loopback
        );
    }

    #[test]
    fn classify_public() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))),
            IpClass::Public
        );
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))),
            IpClass::Public
        );
    }

    // The metadata address sits inside 169.254/16 but must be Reserved, not LinkLocal.
    #[test]
    fn classify_cloud_metadata() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))),
            IpClass::Reserved
        );
    }

    #[test]
    fn classify_ipv6_link_local() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V6("fe80::1".parse().unwrap())),
            IpClass::LinkLocal
        );
    }

    #[test]
    fn classify_ipv6_ula() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V6("fc00::1".parse().unwrap())),
            IpClass::Private
        );
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V6("fd12:3456:789a::1".parse().unwrap())),
            IpClass::Private
        );
    }

    #[test]
    fn classify_ipv6_public() {
        assert_eq!(
            SsrfPolicy::classify(IpAddr::V6("2001:4860:4860::8888".parse().unwrap())),
            IpClass::Public
        );
    }

    // --- SsrfPolicy defaults ---

    #[test]
    fn default_policy_blocks_private() {
        let p = SsrfPolicy::new();
        assert!(!p.allow_private);
        assert!(!p.allow_loopback);
        assert!(!p.allow_link_local);
    }

    // --- is_safe_url ---

    // Asserts that `is_safe_url` rejects `url` with BlockedClass of exactly
    // `want_class`; any other outcome panics with a descriptive message.
    fn assert_blocked(url: &str, want_class: IpClass) {
        match is_safe_url(url) {
            Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(
                class, want_class,
                "url {url} blocked with {class:?}, wanted {want_class:?}"
            ),
            Err(other) => panic!("url {url} rejected with unexpected error: {other}"),
            Ok(()) => panic!("url {url} accepted but expected block for {want_class:?}"),
        }
    }

    #[test]
    fn blocks_rfc1918_addresses() {
        let cases = [
            "http://10.0.0.1/",
            "http://10.255.255.255/",
            "http://172.16.0.1/",
            "http://172.31.255.255/",
            "http://192.168.0.1/",
            "http://192.168.255.255/",
            "http://[fc00::1]/",
            "http://[fd00::1]/",
        ];
        for url in cases {
            assert_blocked(url, IpClass::Private);
        }
    }

    #[test]
    fn blocks_loopback() {
        assert_blocked("http://127.0.0.1/", IpClass::Loopback);
        assert_blocked("http://127.255.255.254/", IpClass::Loopback);
        assert_blocked("http://[::1]/", IpClass::Loopback);
    }

    #[test]
    fn blocks_link_local() {
        assert_blocked("http://169.254.1.1/", IpClass::LinkLocal);
        assert_blocked("http://169.254.254.254/", IpClass::LinkLocal);
        assert_blocked("http://[fe80::1]/", IpClass::LinkLocal);
    }

    #[test]
    fn blocks_aws_metadata_ip() {
        assert_blocked("http://169.254.169.254/latest/meta-data/", IpClass::Reserved);
        assert_blocked("http://[fd00:ec2::254]/latest/meta-data/", IpClass::Reserved);
    }

    // Metadata hostnames are rejected by name, both synchronously and by the
    // async resolver (before any DNS lookup).
    #[tokio::test]
    async fn blocks_aws_metadata_hostname() {
        assert_blocked(
            "http://metadata.google.internal/computeMetadata/v1/",
            IpClass::Reserved,
        );
        match resolve_and_check("metadata.google.internal").await {
            Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(class, IpClass::Reserved),
            other => panic!("expected BlockedClass for metadata.google.internal, got {other:?}"),
        }
        match resolve_and_check("metadata").await {
            Err(SsrfError::BlockedClass { class, .. }) => assert_eq!(class, IpClass::Reserved),
            other => panic!("expected BlockedClass for bare 'metadata', got {other:?}"),
        }
    }

    #[test]
    fn allows_public_ipv4() {
        assert!(is_safe_url("https://8.8.8.8/").is_ok());
        assert!(is_safe_url("https://1.1.1.1/").is_ok());
        assert!(is_safe_url("https://93.184.216.34/").is_ok());
    }

    #[test]
    fn allows_public_ipv6() {
        assert!(is_safe_url("https://[2001:4860:4860::8888]/").is_ok());
        assert!(is_safe_url("https://[2606:4700:4700::1111]/").is_ok());
    }

    // Unparseable input is reported as MissingHost.
    #[test]
    fn rejects_malformed_url() {
        match is_safe_url("not a url") {
            Err(SsrfError::MissingHost(_)) => {}
            other => panic!("expected MissingHost for malformed url, got {other:?}"),
        }
        match is_safe_url("") {
            Err(SsrfError::MissingHost(_)) => {}
            other => panic!("expected MissingHost for empty url, got {other:?}"),
        }
    }

    #[test]
    fn rejects_http_without_host() {
        match is_safe_url("file:///etc/passwd") {
            Err(SsrfError::MissingHost(_)) => {}
            other => panic!("expected MissingHost for file URL, got {other:?}"),
        }
    }

    // --- DNS failure handling ---
    // These use the reserved `.invalid` TLD; the resolver may report either
    // an error or an empty answer, so both outcomes are accepted.

    #[tokio::test]
    async fn dns_failure_blocks_request() {
        let result = resolve_and_check("this-host-does-not-exist.invalid").await;
        match result {
            Err(SsrfError::DnsFailure { host, .. }) => {
                assert_eq!(host, "this-host-does-not-exist.invalid");
            }
            Err(SsrfError::NoAddresses { host, .. }) => {
                assert_eq!(host, "this-host-does-not-exist.invalid");
            }
            Err(other) => panic!(
                "expected DnsFailure or NoAddresses for unresolvable host, got {other:?}"
            ),
            Ok(ip) => panic!(
                "expected DNS failure for unresolvable host, got Ok({ip})"
            ),
        }
    }

    #[tokio::test]
    async fn policy_dns_failure_blocks_request() {
        let policy = SsrfPolicy::new();
        let url = Url::parse("https://this-host-does-not-exist.invalid/resource")
            .expect("valid URL");
        let result = policy.resolve_and_check(&url).await;
        match result {
            Err(SsrfError::DnsFailure { host, .. }) => {
                assert!(host.contains("this-host-does-not-exist.invalid"));
            }
            Err(SsrfError::NoAddresses { host, .. }) => {
                assert!(host.contains("this-host-does-not-exist.invalid"));
            }
            Err(other) => panic!(
                "expected DnsFailure/NoAddresses through policy, got {other:?}"
            ),
            Ok(ip) => panic!(
                "expected DNS failure through policy, got Ok({ip})"
            ),
        }
    }
}