seer_core/rdap/
client.rs

1use std::collections::HashMap;
2use std::net::IpAddr;
3use std::sync::RwLock;
4use std::time::Duration;
5
6use once_cell::sync::Lazy;
7use reqwest::Client;
8use serde::Deserialize;
9use tracing::{debug, instrument};
10
11use super::types::RdapResponse;
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// IANA bootstrap registry consulted for domain-name (TLD) lookups.
const IANA_BOOTSTRAP_DNS: &str = "https://data.iana.org/rdap/dns.json";
/// IANA bootstrap registry consulted for IPv4 prefix lookups.
const IANA_BOOTSTRAP_IPV4: &str = "https://data.iana.org/rdap/ipv4.json";
/// IANA bootstrap registry consulted for IPv6 prefix lookups.
const IANA_BOOTSTRAP_IPV6: &str = "https://data.iana.org/rdap/ipv6.json";
/// IANA bootstrap registry consulted for autonomous-system-number lookups.
const IANA_BOOTSTRAP_ASN: &str = "https://data.iana.org/rdap/asn.json";

/// Timeout used when `RdapClient::new` builds its HTTP client.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Process-wide bootstrap data, lazily created empty and filled in by
/// `RdapClient::load_bootstrap`; shared by every `RdapClient` instance.
static BOOTSTRAP_CACHE: Lazy<RwLock<BootstrapCache>> =
    Lazy::new(|| RwLock::new(BootstrapCache::default()));
24
/// In-memory copy of the four IANA bootstrap registries.
#[derive(Default)]
struct BootstrapCache {
    /// Lower-cased TLD mapped to the base URL of its RDAP service.
    dns: HashMap<String, String>,
    /// IPv4 prefixes paired with the base URL of their RDAP service.
    ipv4: Vec<(IpRange, String)>,
    /// IPv6 prefixes paired with the base URL of their RDAP service.
    ipv6: Vec<(IpRange, String)>,
    /// AS-number ranges paired with the base URL of their RDAP service.
    asn: Vec<(AsnRange, String)>,
    /// `true` once the bootstrap data has been loaded into this cache.
    initialized: bool,
}
33
/// An IP prefix stored exactly as the string found in the bootstrap data.
#[derive(Clone)]
struct IpRange {
    /// Textual prefix; the matchers in this file split it on `/` into
    /// address and mask length.
    prefix: String,
}
38
/// A range of autonomous system numbers; lookups treat both ends as inclusive.
#[derive(Clone)]
struct AsnRange {
    /// First AS number in the range.
    start: u32,
    /// Last AS number in the range.
    end: u32,
}
44
/// The part of a bootstrap JSON document that this client reads.
#[derive(Deserialize)]
struct BootstrapResponse {
    /// One array per service. The loader reads element 0 as the list of keys
    /// (TLDs, prefixes or ASN ranges) and element 1 as the list of base URLs.
    services: Vec<Vec<serde_json::Value>>,
}
49
/// Client for RDAP lookups of domains, IP addresses and AS numbers, using the
/// IANA bootstrap registries to locate the responsible RDAP server.
#[derive(Debug, Clone)]
pub struct RdapClient {
    /// HTTP client used for bootstrap downloads and RDAP queries.
    http: Client,
    /// Timeout configured for this client (set by `new` and `with_timeout`).
    timeout: Duration,
}
55
56impl Default for RdapClient {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl RdapClient {
63    pub fn new() -> Self {
64        let http = Client::builder()
65            .timeout(DEFAULT_TIMEOUT)
66            .user_agent("Seer/1.0 (RDAP Client)")
67            .build()
68            .expect("Failed to build HTTP client");
69
70        Self {
71            http,
72            timeout: DEFAULT_TIMEOUT,
73        }
74    }
75
76    pub fn with_timeout(mut self, timeout: Duration) -> Self {
77        self.timeout = timeout;
78        self
79    }
80
81    async fn ensure_bootstrap(&self) -> Result<()> {
82        {
83            let cache = BOOTSTRAP_CACHE
84                .read()
85                .map_err(|_| SeerError::RdapError("Bootstrap cache lock poisoned".to_string()))?;
86            if cache.initialized {
87                return Ok(());
88            }
89        }
90
91        self.load_bootstrap().await
92    }
93
94    async fn load_bootstrap(&self) -> Result<()> {
95        debug!("Loading RDAP bootstrap data from IANA");
96
97        let dns_future = self.http.get(IANA_BOOTSTRAP_DNS).send();
98        let ipv4_future = self.http.get(IANA_BOOTSTRAP_IPV4).send();
99        let ipv6_future = self.http.get(IANA_BOOTSTRAP_IPV6).send();
100        let asn_future = self.http.get(IANA_BOOTSTRAP_ASN).send();
101
102        let (dns_resp, ipv4_resp, ipv6_resp, asn_resp) =
103            tokio::try_join!(dns_future, ipv4_future, ipv6_future, asn_future)?;
104
105        let dns_data: BootstrapResponse = dns_resp.json().await?;
106        let ipv4_data: BootstrapResponse = ipv4_resp.json().await?;
107        let ipv6_data: BootstrapResponse = ipv6_resp.json().await?;
108        let asn_data: BootstrapResponse = asn_resp.json().await?;
109
110        let mut cache = BOOTSTRAP_CACHE
111            .write()
112            .map_err(|_| SeerError::RdapError("Bootstrap cache lock poisoned".to_string()))?;
113
114        // Parse DNS bootstrap
115        for service in dns_data.services {
116            if service.len() >= 2 {
117                if let (Some(tlds), Some(urls)) = (service[0].as_array(), service[1].as_array()) {
118                    if let Some(url) = urls.first().and_then(|u| u.as_str()) {
119                        for tld in tlds {
120                            if let Some(tld_str) = tld.as_str() {
121                                cache.dns.insert(tld_str.to_lowercase(), url.to_string());
122                            }
123                        }
124                    }
125                }
126            }
127        }
128
129        // Parse IPv4 bootstrap
130        for service in ipv4_data.services {
131            if service.len() >= 2 {
132                if let (Some(prefixes), Some(urls)) = (service[0].as_array(), service[1].as_array())
133                {
134                    if let Some(url) = urls.first().and_then(|u| u.as_str()) {
135                        for prefix in prefixes {
136                            if let Some(prefix_str) = prefix.as_str() {
137                                cache.ipv4.push((
138                                    IpRange {
139                                        prefix: prefix_str.to_string(),
140                                    },
141                                    url.to_string(),
142                                ));
143                            }
144                        }
145                    }
146                }
147            }
148        }
149
150        // Parse IPv6 bootstrap
151        for service in ipv6_data.services {
152            if service.len() >= 2 {
153                if let (Some(prefixes), Some(urls)) = (service[0].as_array(), service[1].as_array())
154                {
155                    if let Some(url) = urls.first().and_then(|u| u.as_str()) {
156                        for prefix in prefixes {
157                            if let Some(prefix_str) = prefix.as_str() {
158                                cache.ipv6.push((
159                                    IpRange {
160                                        prefix: prefix_str.to_string(),
161                                    },
162                                    url.to_string(),
163                                ));
164                            }
165                        }
166                    }
167                }
168            }
169        }
170
171        // Parse ASN bootstrap
172        for service in asn_data.services {
173            if service.len() >= 2 {
174                if let (Some(ranges), Some(urls)) = (service[0].as_array(), service[1].as_array()) {
175                    if let Some(url) = urls.first().and_then(|u| u.as_str()) {
176                        for range in ranges {
177                            if let Some(range_str) = range.as_str() {
178                                if let Some((start, end)) = parse_asn_range(range_str) {
179                                    cache.asn.push((AsnRange { start, end }, url.to_string()));
180                                }
181                            }
182                        }
183                    }
184                }
185            }
186        }
187
188        cache.initialized = true;
189        debug!(
190            dns_entries = cache.dns.len(),
191            ipv4_entries = cache.ipv4.len(),
192            ipv6_entries = cache.ipv6.len(),
193            asn_entries = cache.asn.len(),
194            "RDAP bootstrap loaded"
195        );
196
197        Ok(())
198    }
199
200    fn get_rdap_url_for_domain(&self, domain: &str) -> Option<String> {
201        let cache = BOOTSTRAP_CACHE.read().ok()?;
202        let tld = domain.rsplit('.').next()?;
203        cache.dns.get(&tld.to_lowercase()).cloned()
204    }
205
206    fn get_rdap_url_for_ip(&self, ip: &IpAddr) -> Option<String> {
207        let cache = BOOTSTRAP_CACHE.read().ok()?;
208
209        match ip {
210            IpAddr::V4(addr) => {
211                let octets = addr.octets();
212                for (range, url) in &cache.ipv4 {
213                    if ip_matches_prefix(&range.prefix, &octets) {
214                        return Some(url.clone());
215                    }
216                }
217            }
218            IpAddr::V6(addr) => {
219                let segments = addr.segments();
220                for (range, url) in &cache.ipv6 {
221                    if ipv6_matches_prefix(&range.prefix, &segments) {
222                        return Some(url.clone());
223                    }
224                }
225            }
226        }
227
228        None
229    }
230
231    fn get_rdap_url_for_asn(&self, asn: u32) -> Option<String> {
232        let cache = BOOTSTRAP_CACHE.read().ok()?;
233
234        for (range, url) in &cache.asn {
235            if asn >= range.start && asn <= range.end {
236                return Some(url.clone());
237            }
238        }
239
240        None
241    }
242
243    #[instrument(skip(self), fields(domain = %domain))]
244    pub async fn lookup_domain(&self, domain: &str) -> Result<RdapResponse> {
245        self.ensure_bootstrap().await?;
246
247        let domain = normalize_domain(domain)?;
248        let base_url = self
249            .get_rdap_url_for_domain(&domain)
250            .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for {}", domain)))?;
251
252        let url = format!("{}domain/{}", ensure_trailing_slash(&base_url), domain);
253        debug!(url = %url, "Querying RDAP");
254
255        let response = self
256            .http
257            .get(&url)
258            .header("Accept", "application/rdap+json")
259            .send()
260            .await?;
261
262        if !response.status().is_success() {
263            return Err(SeerError::RdapError(format!(
264                "RDAP query failed with status {}",
265                response.status()
266            )));
267        }
268
269        let rdap: RdapResponse = response.json().await?;
270        Ok(rdap)
271    }
272
273    #[instrument(skip(self), fields(ip = %ip))]
274    pub async fn lookup_ip(&self, ip: &str) -> Result<RdapResponse> {
275        self.ensure_bootstrap().await?;
276
277        let ip_addr: IpAddr = ip
278            .parse()
279            .map_err(|_| SeerError::InvalidIpAddress(ip.to_string()))?;
280
281        let base_url = self
282            .get_rdap_url_for_ip(&ip_addr)
283            .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for {}", ip)))?;
284
285        let url = format!("{}ip/{}", ensure_trailing_slash(&base_url), ip);
286        debug!(url = %url, "Querying RDAP");
287
288        let response = self
289            .http
290            .get(&url)
291            .header("Accept", "application/rdap+json")
292            .send()
293            .await?;
294
295        if !response.status().is_success() {
296            return Err(SeerError::RdapError(format!(
297                "RDAP query failed with status {}",
298                response.status()
299            )));
300        }
301
302        let rdap: RdapResponse = response.json().await?;
303        Ok(rdap)
304    }
305
306    #[instrument(skip(self), fields(asn = %asn))]
307    pub async fn lookup_asn(&self, asn: u32) -> Result<RdapResponse> {
308        self.ensure_bootstrap().await?;
309
310        let base_url = self
311            .get_rdap_url_for_asn(asn)
312            .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for AS{}", asn)))?;
313
314        let url = format!("{}autnum/{}", ensure_trailing_slash(&base_url), asn);
315        debug!(url = %url, "Querying RDAP");
316
317        let response = self
318            .http
319            .get(&url)
320            .header("Accept", "application/rdap+json")
321            .send()
322            .await?;
323
324        if !response.status().is_success() {
325            return Err(SeerError::RdapError(format!(
326                "RDAP query failed with status {}",
327                response.status()
328            )));
329        }
330
331        let rdap: RdapResponse = response.json().await?;
332        Ok(rdap)
333    }
334}
335
336// Domain normalization is now handled by the shared validation module
337
/// Returns `url` as an owned string that is guaranteed to end with `/`.
fn ensure_trailing_slash(url: &str) -> String {
    let mut normalized = String::with_capacity(url.len() + 1);
    normalized.push_str(url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
345
/// Parses an ASN range of the form `"start-end"`, or a single number `"n"`
/// (returned as `(n, n)`). Returns `None` if any part is not a valid `u32`.
fn parse_asn_range(range: &str) -> Option<(u32, u32)> {
    match range.split_once('-') {
        Some((low, high)) => {
            let first: u32 = low.parse().ok()?;
            let last: u32 = high.parse().ok()?;
            Some((first, last))
        }
        None => range.parse::<u32>().ok().map(|single| (single, single)),
    }
}
356
/// Reports whether the IPv4 address `octets` lies inside the CIDR `prefix`.
///
/// `prefix` has the form `"a.b.c.d/len"`. The comparison is bit-exact, so
/// lengths that are not a multiple of 8 (e.g. `/10`) are honoured.
///
/// * A prefix without `/len` is treated as `/8`.
/// * A network part with fewer than four octets (e.g. `"10/8"`) is padded
///   with zero octets.
/// * Returns `false` if any octet is not a valid `u8`, there are more than
///   four octets, or the length is not a number in `0..=32`.
fn ip_matches_prefix(prefix: &str, octets: &[u8; 4]) -> bool {
    let (network, length) = match prefix.split_once('/') {
        Some((network, length)) => (network, Some(length)),
        None => (prefix, None),
    };

    let bits: u32 = match length {
        None => 8,
        Some(text) => match text.parse::<u32>() {
            Ok(bits) if bits <= 32 => bits,
            _ => return false,
        },
    };

    // Parse strictly: a bad octet rejects the prefix instead of being dropped
    // (which would shift the remaining octets into the wrong positions).
    let mut network_octets = [0u8; 4];
    let mut count = 0;
    for part in network.split('.') {
        if count == network_octets.len() {
            return false;
        }
        match part.parse::<u8>() {
            Ok(value) => network_octets[count] = value,
            Err(_) => return false,
        }
        count += 1;
    }

    // A shift by the full width would overflow, so /0 is handled explicitly.
    let mask = if bits == 0 { 0 } else { u32::MAX << (32 - bits) };
    let target = u32::from_be_bytes(*octets);
    let network = u32::from_be_bytes(network_octets);

    (target & mask) == (network & mask)
}
383
/// Reports whether the IPv6 address `segments` lies inside the CIDR `prefix`.
///
/// `prefix` has the form `"addr/len"`. The comparison is bit-exact, so
/// lengths that are not a multiple of 16 (e.g. the `/12` and `/23`
/// allocations in the IANA registry) are honoured.
///
/// * A prefix without `/len` is treated as `/48`.
/// * Returns `false` if the address part is not a valid IPv6 address or the
///   length is not a number in `0..=128`.
fn ipv6_matches_prefix(prefix: &str, segments: &[u16; 8]) -> bool {
    let (network, length) = match prefix.split_once('/') {
        Some((network, length)) => (network, Some(length)),
        None => (prefix, None),
    };

    let network: std::net::Ipv6Addr = match network.parse() {
        Ok(addr) => addr,
        Err(_) => return false,
    };

    let bits: u32 = match length {
        None => 48,
        Some(text) => match text.parse::<u32>() {
            Ok(bits) if bits <= 128 => bits,
            _ => return false,
        },
    };

    // A shift by the full width would overflow, so /0 is handled explicitly.
    let mask = if bits == 0 { 0 } else { u128::MAX << (128 - bits) };
    let target = u128::from(std::net::Ipv6Addr::from(*segments));

    (target & mask) == (u128::from(network) & mask)
}