Skip to main content

seer_core/
validation.rs

1//! Domain validation and SSRF protection utilities
2
3use std::collections::HashSet;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5
6use once_cell::sync::Lazy;
7
8use crate::error::{Result, SeerError};
9
10/// TLD allowlist loaded from the `SEER_DOMAIN_ALLOWLIST` environment variable.
11/// When set (e.g., `SEER_DOMAIN_ALLOWLIST="com,org,net"`), only domains with
12/// matching TLDs are permitted. When unset, all TLDs are allowed.
13static DOMAIN_ALLOWLIST: Lazy<Option<HashSet<String>>> = Lazy::new(|| {
14    let set: HashSet<String> = std::env::var("SEER_DOMAIN_ALLOWLIST")
15        .ok()?
16        .split(',')
17        .map(|s| s.trim().to_lowercase())
18        .filter(|s| !s.is_empty())
19        .collect();
20
21    if set.is_empty() {
22        None
23    } else {
24        Some(set)
25    }
26});
27
28/// Normalizes and validates a domain name.
29///
30/// This function:
31/// - Removes http:// and https:// prefixes
32/// - Removes www. prefix
33/// - Removes trailing slashes and paths
34/// - Converts to lowercase
35/// - Converts internationalized domain names (IDN) to Punycode (ASCII)
36/// - Validates format (must contain dots, only alphanumeric/hyphens/dots)
37/// - Does NOT perform SSRF checks (use `validate_domain_safe` for network operations)
38pub fn normalize_domain(domain: &str) -> Result<String> {
39    let domain = domain.trim().to_lowercase();
40
41    // Remove protocol
42    let domain = domain
43        .strip_prefix("http://")
44        .or_else(|| domain.strip_prefix("https://"))
45        .unwrap_or(&domain);
46
47    // Remove trailing slash, path, query parameters, and fragments
48    let domain = domain.split('/').next().unwrap_or(domain);
49    let domain = domain.split('?').next().unwrap_or(domain);
50    let domain = domain.split('#').next().unwrap_or(domain);
51
52    // Remove www. prefix
53    let domain = domain.strip_prefix("www.").unwrap_or(domain);
54
55    // Validate domain format
56    if domain.is_empty() || !domain.contains('.') {
57        return Err(SeerError::InvalidDomain(domain.to_string()));
58    }
59
60    // Convert internationalized domain names (IDN) to ASCII/Punycode
61    let domain = if !domain.is_ascii() {
62        domain_to_ascii(domain)?
63    } else {
64        domain.to_string()
65    };
66
67    // Basic validation - alphanumeric, hyphens, dots, and underscores
68    // Underscores are valid in DNS names (RFC 8552) and required for service
69    // records like _dmarc., _domainkey., _sip._tcp., etc.
70    let valid = domain
71        .chars()
72        .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');
73    if !valid {
74        return Err(SeerError::InvalidDomain(domain.to_string()));
75    }
76
77    // Check for consecutive dots or dots at start/end
78    if domain.contains("..") || domain.starts_with('.') || domain.ends_with('.') {
79        return Err(SeerError::InvalidDomain(domain.to_string()));
80    }
81
82    // RFC 1035: total domain name length ≤ 253 characters
83    if domain.len() > 253 {
84        return Err(SeerError::InvalidDomain(domain.to_string()));
85    }
86
87    // Check label constraints
88    for label in domain.split('.') {
89        // Labels must be non-empty and not start/end with hyphens
90        if label.is_empty() || label.starts_with('-') || label.ends_with('-') {
91            return Err(SeerError::InvalidDomain(domain.to_string()));
92        }
93        // RFC 1035: each label ≤ 63 characters
94        if label.len() > 63 {
95            return Err(SeerError::InvalidDomain(domain.to_string()));
96        }
97    }
98
99    // Check TLD against allowlist (if configured)
100    if let Some(ref allowlist) = *DOMAIN_ALLOWLIST {
101        if let Some(tld) = domain.rsplit('.').next() {
102            if !allowlist.contains(tld) {
103                return Err(SeerError::DomainNotAllowed {
104                    domain: domain.to_string(),
105                    tld: tld.to_string(),
106                });
107            }
108        }
109    }
110
111    Ok(domain.to_string())
112}
113
114/// Converts an internationalized domain name to ASCII (Punycode).
115fn domain_to_ascii(domain: &str) -> Result<String> {
116    idna::domain_to_ascii(domain).map_err(|_| {
117        SeerError::InvalidDomain(format!("invalid internationalized domain: {}", domain))
118    })
119}
120
121/// Checks if an IP address is in a private or reserved range.
122///
123/// This includes:
124/// - Private networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
125/// - Loopback (127.0.0.0/8, ::1/128)
126/// - Link-local (169.254.0.0/16, fe80::/10)
127/// - Cloud metadata (169.254.169.254)
128/// - Unique local addresses (fc00::/7)
129/// - Documentation ranges (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
130/// - Multicast and broadcast
131pub fn is_private_or_reserved_ip(ip: &IpAddr) -> bool {
132    match ip {
133        IpAddr::V4(ipv4) => is_private_or_reserved_ipv4(ipv4),
134        IpAddr::V6(ipv6) => is_private_or_reserved_ipv6(ipv6),
135    }
136}
137
/// Returns `true` when an IPv4 address must not be contacted on behalf of a
/// user (SSRF protection).
///
/// Rejected ranges:
/// - "This network" (0.0.0.0/8), which includes the unspecified address
/// - Private networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
/// - Shared address space / CGNAT (100.64.0.0/10, RFC 6598)
/// - Loopback (127.0.0.0/8)
/// - Link-local (169.254.0.0/16), which contains the cloud metadata
///   address 169.254.169.254
/// - Documentation ranges (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
/// - Benchmarking (198.18.0.0/15, RFC 2544)
/// - Multicast (224.0.0.0/4), reserved (240.0.0.0/4) and broadcast
fn is_private_or_reserved_ipv4(ip: &Ipv4Addr) -> bool {
    // Ranges the standard library can classify directly. `is_link_local()`
    // covers all of 169.254.0.0/16, so the metadata address needs no
    // separate check here.
    if ip.is_private() || ip.is_loopback() || ip.is_link_local() || ip.is_documentation() {
        return true;
    }

    match ip.octets() {
        // 0.0.0.0/8: "this network"; 0.0.0.0 itself is the unspecified address.
        [0, ..] => true,
        // 100.64.0.0/10: carrier-grade NAT space, used internally by some
        // providers (e.g. the Alibaba Cloud metadata service at
        // 100.100.100.200).
        [100, 64..=127, ..] => true,
        // 198.18.0.0/15: network benchmarking.
        [198, 18..=19, ..] => true,
        // 224.0.0.0/4 multicast and 240.0.0.0/4 reserved, which includes the
        // limited broadcast address 255.255.255.255.
        [224..=255, ..] => true,
        _ => false,
    }
}
194
195/// Checks if an IPv6 address is private or reserved.
196fn is_private_or_reserved_ipv6(ip: &Ipv6Addr) -> bool {
197    // Loopback (::1)
198    if ip.is_loopback() {
199        return true;
200    }
201
202    // Unspecified (::)
203    if ip.is_unspecified() {
204        return true;
205    }
206
207    let segments = ip.segments();
208
209    // Unique local addresses (fc00::/7)
210    if (segments[0] & 0xfe00) == 0xfc00 {
211        return true;
212    }
213
214    // Link-local (fe80::/10)
215    if (segments[0] & 0xffc0) == 0xfe80 {
216        return true;
217    }
218
219    // Multicast (ff00::/8)
220    if segments[0] >> 8 == 0xff {
221        return true;
222    }
223
224    // IPv4-mapped IPv6 addresses (::ffff:0:0/96)
225    // Check if it maps to a private IPv4
226    if ip
227        .to_ipv4_mapped()
228        .is_some_and(|ipv4| is_private_or_reserved_ipv4(&ipv4))
229    {
230        return true;
231    }
232
233    false
234}
235
/// Returns a human-readable reason why an IP is blocked, or `None` if it is
/// safe.  Intended for error messages — callers should still use
/// [`is_private_or_reserved_ip`] for the fast boolean check.
///
/// IPv4 checks run from most to least specific so that the most helpful
/// reason wins; in particular the cloud metadata address is tested before the
/// link-local range that contains it. Besides the private, loopback,
/// link-local, documentation, multicast, reserved and broadcast ranges, the
/// IPv4 branch also rejects 0.0.0.0/8, the shared address space
/// 100.64.0.0/10 (RFC 6598) and the benchmarking range 198.18.0.0/15.
///
/// An IPv4-mapped IPv6 address is blocked exactly when the IPv4 address it
/// embeds is blocked by this function.
pub fn describe_reserved_ip(ip: &IpAddr) -> Option<&'static str> {
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();

            if v4.is_unspecified() {
                return Some("unspecified address (0.0.0.0) — domain has no routable IP");
            }
            if o[0] == 0 {
                return Some("\"this network\" address (0.0.0.0/8)");
            }
            if v4.is_loopback() {
                return Some("loopback address (127.0.0.0/8)");
            }
            if v4.is_private() {
                return Some("private network (RFC 1918)");
            }
            // Must precede the link-local test: 169.254.169.254 is itself a
            // link-local address and would otherwise never get this reason.
            if o == [169, 254, 169, 254] {
                return Some("cloud metadata endpoint (169.254.169.254)");
            }
            if v4.is_link_local() {
                return Some("link-local address (169.254.0.0/16)");
            }
            // 100.64.0.0/10: second octet in 64..=127.
            if o[0] == 100 && (o[1] & 0xc0) == 0x40 {
                return Some("shared address space / carrier-grade NAT (100.64.0.0/10)");
            }
            if v4.is_documentation() {
                return Some("documentation/test range (RFC 5737)");
            }
            // 198.18.0.0/15: second octet 18 or 19.
            if o[0] == 198 && (o[1] & 0xfe) == 18 {
                return Some("benchmarking range (198.18.0.0/15)");
            }
            if v4.is_broadcast() {
                return Some("broadcast address (255.255.255.255)");
            }
            if v4.is_multicast() {
                return Some("multicast address (224.0.0.0/4)");
            }
            if o[0] >= 240 {
                return Some("reserved address (240.0.0.0/4)");
            }
            None
        }
        IpAddr::V6(v6) => {
            if v6.is_loopback() {
                return Some("IPv6 loopback (::1)");
            }
            if v6.is_unspecified() {
                return Some("IPv6 unspecified address (::) — domain has no routable IP");
            }
            let seg = v6.segments();
            if (seg[0] & 0xfe00) == 0xfc00 {
                return Some("IPv6 unique local address (fc00::/7)");
            }
            if (seg[0] & 0xffc0) == 0xfe80 {
                return Some("IPv6 link-local address (fe80::/10)");
            }
            if seg[0] >> 8 == 0xff {
                return Some("IPv6 multicast (ff00::/8)");
            }
            // Reuse the IPv4 branch above so the mapped and plain forms of an
            // address can never be classified differently.
            if let Some(v4) = v6.to_ipv4_mapped() {
                if describe_reserved_ip(&IpAddr::V4(v4)).is_some() {
                    return Some("IPv4-mapped IPv6 address in private/reserved range");
                }
            }
            None
        }
    }
}
304
305/// Validates that a domain is safe to query (SSRF protection).
306///
307/// This function:
308/// 1. Normalizes the domain
309/// 2. Resolves it to IP addresses
310/// 3. Checks that none of the IPs are in private/reserved ranges
311///
312/// Use this before making HTTP/TLS connections to user-supplied domains.
313pub async fn validate_domain_safe(domain: &str) -> Result<String> {
314    // First normalize the domain
315    let normalized = normalize_domain(domain)?;
316
317    // Resolve the domain to IP addresses
318    let addr = format!("{}:443", normalized);
319    let socket_addrs = tokio::net::lookup_host(&addr)
320        .await
321        .map_err(|e| SeerError::InvalidDomain(format!("failed to resolve domain: {}", e)))?;
322
323    // Check all resolved IPs
324    for socket_addr in socket_addrs {
325        let ip = socket_addr.ip();
326        if let Some(reason) = describe_reserved_ip(&ip) {
327            return Err(SeerError::InvalidDomain(format!(
328                "cannot connect to '{}': {} — {}",
329                normalized, ip, reason
330            )));
331        }
332    }
333
334    Ok(normalized)
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `IpAddr::V4` from four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Builds an `IpAddr::V6` from eight 16-bit segments.
    fn v6(segments: [u16; 8]) -> IpAddr {
        IpAddr::V6(Ipv6Addr::from(segments))
    }

    /// Asserts that each `(input, expected)` pair normalizes successfully to
    /// the expected value.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases.iter() {
            assert_eq!(
                normalize_domain(input).unwrap(),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_normalize_domain() {
        assert_normalizes(&[
            // Case, scheme, `www.` and whitespace handling
            ("example.com", "example.com"),
            ("EXAMPLE.COM", "example.com"),
            ("https://www.example.com/path", "example.com"),
            ("http://example.com/", "example.com"),
            ("  WWW.EXAMPLE.COM  ", "example.com"),
            // Query parameters and fragments
            ("example.com?query=1", "example.com"),
            ("example.com#section", "example.com"),
            ("https://example.com/path?q=1#frag", "example.com"),
            // Underscore domains (DNS service records)
            ("_dmarc.example.com", "_dmarc.example.com"),
            (
                "selector1._domainkey.example.com",
                "selector1._domainkey.example.com",
            ),
            ("_sip._tcp.example.com", "_sip._tcp.example.com"),
        ]);

        // Invalid domains
        let rejected = [
            "",
            "nodots",
            "example..com",
            ".example.com",
            "example.com.",
            "-example.com",
            "example-.com",
        ];
        for input in rejected.iter() {
            assert!(normalize_domain(input).is_err(), "input: {:?}", input);
        }
    }

    #[test]
    fn test_normalize_idn_domain() {
        assert_normalizes(&[
            // German
            ("münchen.de", "xn--mnchen-3ya.de"),
            // Japanese
            ("例え.jp", "xn--r8jz45g.jp"),
            // Chinese
            ("中文.com", "xn--fiq228c.com"),
            // With protocol prefix
            ("https://münchen.de/path", "xn--mnchen-3ya.de"),
        ]);
    }

    #[test]
    fn test_allowlist_not_set_allows_all() {
        // When SEER_DOMAIN_ALLOWLIST is not set, all domains pass.
        // This test verifies the default behavior (no env var).
        for input in ["example.com", "example.xyz", "example.co.uk"].iter() {
            assert!(normalize_domain(input).is_ok(), "input: {:?}", input);
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv4() {
        let blocked = [
            // Private networks
            v4(10, 0, 0, 1),
            v4(172, 16, 0, 1),
            v4(192, 168, 1, 1),
            // Loopback
            v4(127, 0, 0, 1),
            // Link-local
            v4(169, 254, 1, 1),
            // Cloud metadata
            v4(169, 254, 169, 254),
        ];
        for ip in blocked.iter() {
            assert!(is_private_or_reserved_ip(ip), "{} should be blocked", ip);
        }

        // Public IPs (should not be blocked)
        let public = [v4(8, 8, 8, 8), v4(1, 1, 1, 1)];
        for ip in public.iter() {
            assert!(!is_private_or_reserved_ip(ip), "{} should be allowed", ip);
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv6() {
        let blocked = [
            // Loopback
            v6([0, 0, 0, 0, 0, 0, 0, 1]),
            // Unique local
            v6([0xfc00, 0, 0, 0, 0, 0, 0, 1]),
            // Link-local
            v6([0xfe80, 0, 0, 0, 0, 0, 0, 1]),
        ];
        for ip in blocked.iter() {
            assert!(is_private_or_reserved_ip(ip), "{} should be blocked", ip);
        }

        // Public IPv6 (should not be blocked)
        let public = v6([0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888]);
        assert!(
            !is_private_or_reserved_ip(&public),
            "{} should be allowed",
            public
        );
    }
}