1use std::collections::HashSet;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5
6use once_cell::sync::Lazy;
7
8use crate::error::{Result, SeerError};
9
10static DOMAIN_ALLOWLIST: Lazy<Option<HashSet<String>>> = Lazy::new(|| {
14 let set: HashSet<String> = std::env::var("SEER_DOMAIN_ALLOWLIST")
15 .ok()?
16 .split(',')
17 .map(|s| s.trim().to_lowercase())
18 .filter(|s| !s.is_empty())
19 .collect();
20
21 if set.is_empty() {
22 None
23 } else {
24 Some(set)
25 }
26});
27
28pub fn normalize_domain(domain: &str) -> Result<String> {
39 let domain = domain.trim().to_lowercase();
40
41 let domain = domain
43 .strip_prefix("http://")
44 .or_else(|| domain.strip_prefix("https://"))
45 .unwrap_or(&domain);
46
47 let domain = domain.split('/').next().unwrap_or(domain);
49 let domain = domain.split('?').next().unwrap_or(domain);
50 let domain = domain.split('#').next().unwrap_or(domain);
51
52 let domain = domain.strip_prefix("www.").unwrap_or(domain);
54
55 if domain.is_empty() || !domain.contains('.') {
57 return Err(SeerError::InvalidDomain(domain.to_string()));
58 }
59
60 let domain = if !domain.is_ascii() {
62 domain_to_ascii(domain)?
63 } else {
64 domain.to_string()
65 };
66
67 let valid = domain
71 .chars()
72 .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');
73 if !valid {
74 return Err(SeerError::InvalidDomain(domain.to_string()));
75 }
76
77 if domain.contains("..") || domain.starts_with('.') || domain.ends_with('.') {
79 return Err(SeerError::InvalidDomain(domain.to_string()));
80 }
81
82 if domain.len() > 253 {
84 return Err(SeerError::InvalidDomain(domain.to_string()));
85 }
86
87 for label in domain.split('.') {
89 if label.is_empty() || label.starts_with('-') || label.ends_with('-') {
91 return Err(SeerError::InvalidDomain(domain.to_string()));
92 }
93 if label.len() > 63 {
95 return Err(SeerError::InvalidDomain(domain.to_string()));
96 }
97 }
98
99 if let Some(ref allowlist) = *DOMAIN_ALLOWLIST {
101 if let Some(tld) = domain.rsplit('.').next() {
102 if !allowlist.contains(tld) {
103 return Err(SeerError::DomainNotAllowed {
104 domain: domain.to_string(),
105 tld: tld.to_string(),
106 });
107 }
108 }
109 }
110
111 Ok(domain.to_string())
112}
113
114fn domain_to_ascii(domain: &str) -> Result<String> {
116 idna::domain_to_ascii(domain).map_err(|_| {
117 SeerError::InvalidDomain(format!("invalid internationalized domain: {}", domain))
118 })
119}
120
121pub fn is_private_or_reserved_ip(ip: &IpAddr) -> bool {
132 match ip {
133 IpAddr::V4(ipv4) => is_private_or_reserved_ipv4(ipv4),
134 IpAddr::V6(ipv6) => is_private_or_reserved_ipv6(ipv6),
135 }
136}
137
/// Returns `true` when `ip` is in any IPv4 range that
/// [`describe_reserved_ipv4`] classifies as private or reserved.
///
/// Delegating keeps the boolean and the descriptive checks from drifting
/// apart.
fn is_private_or_reserved_ipv4(ip: &Ipv4Addr) -> bool {
    describe_reserved_ipv4(ip).is_some()
}

/// Returns `true` when `ip` is in any IPv6 range that
/// [`describe_reserved_ipv6`] classifies as private or reserved, including
/// IPv6 forms that embed a private or reserved IPv4 address.
fn is_private_or_reserved_ipv6(ip: &Ipv6Addr) -> bool {
    describe_reserved_ipv6(ip).is_some()
}

/// Explains why `ip` must not be used as a connection target.
///
/// Returns `Some(reason)` with a short human-readable description when the
/// address is private, loopback, link-local, documentation-only, multicast,
/// or otherwise reserved, and `None` when none of the checked ranges match.
pub fn describe_reserved_ip(ip: &IpAddr) -> Option<&'static str> {
    match ip {
        IpAddr::V4(v4) => describe_reserved_ipv4(v4),
        IpAddr::V6(v6) => describe_reserved_ipv6(v6),
    }
}

/// Classifies an IPv4 address; `None` means no checked range matched.
fn describe_reserved_ipv4(ip: &Ipv4Addr) -> Option<&'static str> {
    if ip.is_unspecified() {
        return Some("unspecified address (0.0.0.0) — domain has no routable IP");
    }
    if ip.is_loopback() {
        return Some("loopback address (127.0.0.0/8)");
    }
    if ip.is_private() {
        return Some("private network (RFC 1918)");
    }
    // Must precede the link-local test: 169.254.169.254 is itself link-local,
    // so checking it afterwards would make this branch unreachable.
    if *ip == Ipv4Addr::new(169, 254, 169, 254) {
        return Some("cloud metadata endpoint (169.254.169.254)");
    }
    if ip.is_link_local() {
        return Some("link-local address (169.254.0.0/16)");
    }
    // 192.0.2.0/24, 198.51.100.0/24 and 203.0.113.0/24.
    if ip.is_documentation() {
        return Some("documentation/test range (RFC 5737)");
    }
    // Checked before the 240.0.0.0/4 arm so broadcast keeps its own message.
    if ip.is_broadcast() {
        return Some("broadcast address (255.255.255.255)");
    }

    match ip.octets() {
        [0, ..] => Some("current-network address (0.0.0.0/8)"),
        // 100.64.0.0/10: the top two bits of the second octet are `01`.
        [100, second, ..] if second & 0xc0 == 0x40 => {
            Some("shared address space (100.64.0.0/10, RFC 6598)")
        }
        [192, 0, 0, _] => Some("IETF protocol assignments (192.0.0.0/24, RFC 6890)"),
        [198, 18, ..] | [198, 19, ..] => Some("benchmarking range (198.18.0.0/15, RFC 2544)"),
        [224..=239, ..] => Some("multicast address (224.0.0.0/4)"),
        [240..=255, ..] => Some("reserved address (240.0.0.0/4)"),
        _ => None,
    }
}

/// Classifies an IPv6 address; `None` means no checked range matched.
fn describe_reserved_ipv6(ip: &Ipv6Addr) -> Option<&'static str> {
    if ip.is_loopback() {
        return Some("IPv6 loopback (::1)");
    }
    if ip.is_unspecified() {
        return Some("IPv6 unspecified address (::) — domain has no routable IP");
    }

    let seg = ip.segments();
    if seg[0] & 0xfe00 == 0xfc00 {
        return Some("IPv6 unique local address (fc00::/7)");
    }
    if seg[0] & 0xffc0 == 0xfe80 {
        return Some("IPv6 link-local address (fe80::/10)");
    }
    if seg[0] >> 8 == 0xff {
        return Some("IPv6 multicast (ff00::/8)");
    }
    // Mirrors the IPv4 documentation ranges rejected above.
    if seg[0] == 0x2001 && seg[1] == 0x0db8 {
        return Some("IPv6 documentation range (2001:db8::/32)");
    }

    if let Some(v4) = ip.to_ipv4_mapped() {
        if describe_reserved_ipv4(&v4).is_some() {
            return Some("IPv4-mapped IPv6 address in private/reserved range");
        }
    }

    // Other IPv6 forms can also carry an IPv4 address; judge them by the
    // embedded address so they cannot be used to reach an internal IPv4 host.
    if let Some(v4) = embedded_transition_ipv4(seg) {
        if describe_reserved_ipv4(&v4).is_some() {
            return Some("IPv6 transition address embedding a private/reserved IPv4 address");
        }
    }

    None
}

/// Extracts the IPv4 address embedded in an IPv4-compatible (`::a.b.c.d`),
/// NAT64 well-known-prefix (`64:ff9b::/96`, RFC 6052) or 6to4 (`2002::/16`,
/// RFC 3056) IPv6 address. Returns `None` for every other address.
fn embedded_transition_ipv4(seg: [u16; 8]) -> Option<Ipv4Addr> {
    let join = |hi: u16, lo: u16| Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo));
    match seg {
        [0, 0, 0, 0, 0, 0, hi, lo] => Some(join(hi, lo)),
        [0x0064, 0xff9b, 0, 0, 0, 0, hi, lo] => Some(join(hi, lo)),
        [0x2002, hi, lo, ..] => Some(join(hi, lo)),
        _ => None,
    }
}
304
305pub async fn validate_domain_safe(domain: &str) -> Result<String> {
314 let normalized = normalize_domain(domain)?;
316
317 let addr = format!("{}:443", normalized);
319 let socket_addrs = tokio::net::lookup_host(&addr)
320 .await
321 .map_err(|e| SeerError::InvalidDomain(format!("failed to resolve domain: {}", e)))?;
322
323 for socket_addr in socket_addrs {
325 let ip = socket_addr.ip();
326 if let Some(reason) = describe_reserved_ip(&ip) {
327 return Err(SeerError::InvalidDomain(format!(
328 "cannot connect to '{}': {} — {}",
329 normalized, ip, reason
330 )));
331 }
332 }
333
334 Ok(normalized)
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `IpAddr::V4` from four octets.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Builds an `IpAddr::V6` from eight 16-bit segments.
    fn v6(segments: [u16; 8]) -> IpAddr {
        IpAddr::V6(Ipv6Addr::from(segments))
    }

    #[test]
    fn test_normalize_domain() {
        // (input, expected normalized form)
        let accepted = [
            // Case, scheme, path, and `www.` handling.
            ("example.com", "example.com"),
            ("EXAMPLE.COM", "example.com"),
            ("https://www.example.com/path", "example.com"),
            ("http://example.com/", "example.com"),
            (" WWW.EXAMPLE.COM ", "example.com"),
            // Query strings and fragments are dropped.
            ("example.com?query=1", "example.com"),
            ("example.com#section", "example.com"),
            ("https://example.com/path?q=1#frag", "example.com"),
            // Underscore-prefixed labels are kept as-is.
            ("_dmarc.example.com", "_dmarc.example.com"),
            (
                "selector1._domainkey.example.com",
                "selector1._domainkey.example.com",
            ),
            ("_sip._tcp.example.com", "_sip._tcp.example.com"),
        ];
        for &(input, expected) in accepted.iter() {
            assert_eq!(normalize_domain(input).unwrap(), expected);
        }

        let rejected = [
            "",
            "nodots",
            "example..com",
            ".example.com",
            "example.com.",
            "-example.com",
            "example-.com",
        ];
        for &input in rejected.iter() {
            assert!(normalize_domain(input).is_err());
        }
    }

    #[test]
    fn test_normalize_idn_domain() {
        // (input, expected punycode form)
        let cases = [
            ("münchen.de", "xn--mnchen-3ya.de"),
            ("例え.jp", "xn--r8jz45g.jp"),
            ("中文.com", "xn--fiq228c.com"),
            ("https://münchen.de/path", "xn--mnchen-3ya.de"),
        ];
        for &(input, expected) in cases.iter() {
            let result = normalize_domain(input).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_allowlist_not_set_allows_all() {
        for &input in ["example.com", "example.xyz", "example.co.uk"].iter() {
            assert!(normalize_domain(input).is_ok());
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv4() {
        let reserved = [
            // RFC 1918 private ranges.
            v4(10, 0, 0, 1),
            v4(172, 16, 0, 1),
            v4(192, 168, 1, 1),
            // Loopback.
            v4(127, 0, 0, 1),
            // Link-local.
            v4(169, 254, 1, 1),
            // Cloud metadata endpoint.
            v4(169, 254, 169, 254),
        ];
        for ip in reserved.iter() {
            assert!(is_private_or_reserved_ip(ip));
        }

        let public = [v4(8, 8, 8, 8), v4(1, 1, 1, 1)];
        for ip in public.iter() {
            assert!(!is_private_or_reserved_ip(ip));
        }
    }

    #[test]
    fn test_is_private_or_reserved_ipv6() {
        let reserved = [
            // Loopback.
            v6([0, 0, 0, 0, 0, 0, 0, 1]),
            // Unique local.
            v6([0xfc00, 0, 0, 0, 0, 0, 0, 1]),
            // Link-local.
            v6([0xfe80, 0, 0, 0, 0, 0, 0, 1]),
        ];
        for ip in reserved.iter() {
            assert!(is_private_or_reserved_ip(ip));
        }

        // A public address must not be flagged.
        assert!(!is_private_or_reserved_ip(&v6([
            0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888
        ])));
    }
}