1use std::collections::HashMap;
2use std::net::IpAddr;
3use std::sync::RwLock;
4use std::time::Duration;
5
6use once_cell::sync::Lazy;
7use reqwest::Client;
8use serde::Deserialize;
9use tracing::{debug, instrument};
10
11use super::types::RdapResponse;
12use crate::error::{Result, SeerError};
13use crate::validation::normalize_domain;
14
/// IANA bootstrap file mapping domain-name labels to RDAP base URLs.
const IANA_BOOTSTRAP_DNS: &str = "https://data.iana.org/rdap/dns.json";
/// IANA bootstrap file mapping IPv4 prefixes to RDAP base URLs.
const IANA_BOOTSTRAP_IPV4: &str = "https://data.iana.org/rdap/ipv4.json";
/// IANA bootstrap file mapping IPv6 prefixes to RDAP base URLs.
const IANA_BOOTSTRAP_IPV6: &str = "https://data.iana.org/rdap/ipv6.json";
/// IANA bootstrap file mapping autonomous-system number ranges to RDAP base URLs.
const IANA_BOOTSTRAP_ASN: &str = "https://data.iana.org/rdap/asn.json";

/// HTTP timeout that `RdapClient::new` builds its client with.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
21
/// Process-wide cache of the parsed IANA bootstrap data.
///
/// Shared by every `RdapClient`; it starts empty and is filled by the client's
/// bootstrap loader the first time a lookup runs.
static BOOTSTRAP_CACHE: Lazy<RwLock<BootstrapCache>> =
    Lazy::new(|| RwLock::new(BootstrapCache::default()));
24
/// Parsed contents of the four IANA bootstrap files.
#[derive(Default)]
struct BootstrapCache {
    // Lowercased domain label (e.g. a TLD) -> RDAP base URL.
    dns: HashMap<String, String>,
    // (IPv4 prefix, RDAP base URL) pairs in bootstrap-file order.
    ipv4: Vec<(IpRange, String)>,
    // (IPv6 prefix, RDAP base URL) pairs in bootstrap-file order.
    ipv6: Vec<(IpRange, String)>,
    // (ASN range, RDAP base URL) pairs in bootstrap-file order.
    asn: Vec<(AsnRange, String)>,
    // Set once the bootstrap files have been loaded successfully.
    initialized: bool,
}
33
/// An IP prefix exactly as written in a bootstrap file.
///
/// The string is kept unparsed; it is split on `/` into network and length
/// when an address is matched against it.
#[derive(Clone)]
struct IpRange {
    prefix: String,
}
38
/// An inclusive range of autonomous-system numbers (`start..=end`).
#[derive(Clone)]
struct AsnRange {
    start: u32,
    end: u32,
}
44
/// The only part of an IANA bootstrap file this client reads.
///
/// Each service is treated as `[keys, base_urls]`; every other top-level
/// field of the file is ignored during deserialization.
#[derive(Deserialize)]
struct BootstrapResponse {
    services: Vec<Vec<serde_json::Value>>,
}
49
/// Client for RDAP domain, IP and ASN lookups.
///
/// RDAP servers are located through the IANA bootstrap registries, which are
/// cached process-wide in `BOOTSTRAP_CACHE`.
#[derive(Debug, Clone)]
pub struct RdapClient {
    // HTTP client; `new` builds it with `DEFAULT_TIMEOUT` and a Seer user agent.
    http: Client,
    // Timeout recorded by `new` and replaced by `with_timeout`.
    timeout: Duration,
}
55
56impl Default for RdapClient {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl RdapClient {
63 pub fn new() -> Self {
64 let http = Client::builder()
65 .timeout(DEFAULT_TIMEOUT)
66 .user_agent("Seer/1.0 (RDAP Client)")
67 .build()
68 .expect("Failed to build HTTP client");
69
70 Self {
71 http,
72 timeout: DEFAULT_TIMEOUT,
73 }
74 }
75
76 pub fn with_timeout(mut self, timeout: Duration) -> Self {
77 self.timeout = timeout;
78 self
79 }
80
81 async fn ensure_bootstrap(&self) -> Result<()> {
82 {
83 let cache = BOOTSTRAP_CACHE
84 .read()
85 .map_err(|_| SeerError::RdapError("Bootstrap cache lock poisoned".to_string()))?;
86 if cache.initialized {
87 return Ok(());
88 }
89 }
90
91 self.load_bootstrap().await
92 }
93
94 async fn load_bootstrap(&self) -> Result<()> {
95 debug!("Loading RDAP bootstrap data from IANA");
96
97 let dns_future = self.http.get(IANA_BOOTSTRAP_DNS).send();
98 let ipv4_future = self.http.get(IANA_BOOTSTRAP_IPV4).send();
99 let ipv6_future = self.http.get(IANA_BOOTSTRAP_IPV6).send();
100 let asn_future = self.http.get(IANA_BOOTSTRAP_ASN).send();
101
102 let (dns_resp, ipv4_resp, ipv6_resp, asn_resp) =
103 tokio::try_join!(dns_future, ipv4_future, ipv6_future, asn_future)?;
104
105 let dns_data: BootstrapResponse = dns_resp.json().await?;
106 let ipv4_data: BootstrapResponse = ipv4_resp.json().await?;
107 let ipv6_data: BootstrapResponse = ipv6_resp.json().await?;
108 let asn_data: BootstrapResponse = asn_resp.json().await?;
109
110 let mut cache = BOOTSTRAP_CACHE
111 .write()
112 .map_err(|_| SeerError::RdapError("Bootstrap cache lock poisoned".to_string()))?;
113
114 for service in dns_data.services {
116 if service.len() >= 2 {
117 if let (Some(tlds), Some(urls)) = (service[0].as_array(), service[1].as_array()) {
118 if let Some(url) = urls.first().and_then(|u| u.as_str()) {
119 for tld in tlds {
120 if let Some(tld_str) = tld.as_str() {
121 cache.dns.insert(tld_str.to_lowercase(), url.to_string());
122 }
123 }
124 }
125 }
126 }
127 }
128
129 for service in ipv4_data.services {
131 if service.len() >= 2 {
132 if let (Some(prefixes), Some(urls)) = (service[0].as_array(), service[1].as_array())
133 {
134 if let Some(url) = urls.first().and_then(|u| u.as_str()) {
135 for prefix in prefixes {
136 if let Some(prefix_str) = prefix.as_str() {
137 cache.ipv4.push((
138 IpRange {
139 prefix: prefix_str.to_string(),
140 },
141 url.to_string(),
142 ));
143 }
144 }
145 }
146 }
147 }
148 }
149
150 for service in ipv6_data.services {
152 if service.len() >= 2 {
153 if let (Some(prefixes), Some(urls)) = (service[0].as_array(), service[1].as_array())
154 {
155 if let Some(url) = urls.first().and_then(|u| u.as_str()) {
156 for prefix in prefixes {
157 if let Some(prefix_str) = prefix.as_str() {
158 cache.ipv6.push((
159 IpRange {
160 prefix: prefix_str.to_string(),
161 },
162 url.to_string(),
163 ));
164 }
165 }
166 }
167 }
168 }
169 }
170
171 for service in asn_data.services {
173 if service.len() >= 2 {
174 if let (Some(ranges), Some(urls)) = (service[0].as_array(), service[1].as_array()) {
175 if let Some(url) = urls.first().and_then(|u| u.as_str()) {
176 for range in ranges {
177 if let Some(range_str) = range.as_str() {
178 if let Some((start, end)) = parse_asn_range(range_str) {
179 cache.asn.push((AsnRange { start, end }, url.to_string()));
180 }
181 }
182 }
183 }
184 }
185 }
186 }
187
188 cache.initialized = true;
189 debug!(
190 dns_entries = cache.dns.len(),
191 ipv4_entries = cache.ipv4.len(),
192 ipv6_entries = cache.ipv6.len(),
193 asn_entries = cache.asn.len(),
194 "RDAP bootstrap loaded"
195 );
196
197 Ok(())
198 }
199
200 fn get_rdap_url_for_domain(&self, domain: &str) -> Option<String> {
201 let cache = BOOTSTRAP_CACHE.read().ok()?;
202 let tld = domain.rsplit('.').next()?;
203 cache.dns.get(&tld.to_lowercase()).cloned()
204 }
205
206 fn get_rdap_url_for_ip(&self, ip: &IpAddr) -> Option<String> {
207 let cache = BOOTSTRAP_CACHE.read().ok()?;
208
209 match ip {
210 IpAddr::V4(addr) => {
211 let octets = addr.octets();
212 for (range, url) in &cache.ipv4 {
213 if ip_matches_prefix(&range.prefix, &octets) {
214 return Some(url.clone());
215 }
216 }
217 }
218 IpAddr::V6(addr) => {
219 let segments = addr.segments();
220 for (range, url) in &cache.ipv6 {
221 if ipv6_matches_prefix(&range.prefix, &segments) {
222 return Some(url.clone());
223 }
224 }
225 }
226 }
227
228 None
229 }
230
231 fn get_rdap_url_for_asn(&self, asn: u32) -> Option<String> {
232 let cache = BOOTSTRAP_CACHE.read().ok()?;
233
234 for (range, url) in &cache.asn {
235 if asn >= range.start && asn <= range.end {
236 return Some(url.clone());
237 }
238 }
239
240 None
241 }
242
243 #[instrument(skip(self), fields(domain = %domain))]
244 pub async fn lookup_domain(&self, domain: &str) -> Result<RdapResponse> {
245 self.ensure_bootstrap().await?;
246
247 let domain = normalize_domain(domain)?;
248 let base_url = self
249 .get_rdap_url_for_domain(&domain)
250 .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for {}", domain)))?;
251
252 let url = format!("{}domain/{}", ensure_trailing_slash(&base_url), domain);
253 debug!(url = %url, "Querying RDAP");
254
255 let response = self
256 .http
257 .get(&url)
258 .header("Accept", "application/rdap+json")
259 .send()
260 .await?;
261
262 if !response.status().is_success() {
263 return Err(SeerError::RdapError(format!(
264 "RDAP query failed with status {}",
265 response.status()
266 )));
267 }
268
269 let rdap: RdapResponse = response.json().await?;
270 Ok(rdap)
271 }
272
273 #[instrument(skip(self), fields(ip = %ip))]
274 pub async fn lookup_ip(&self, ip: &str) -> Result<RdapResponse> {
275 self.ensure_bootstrap().await?;
276
277 let ip_addr: IpAddr = ip
278 .parse()
279 .map_err(|_| SeerError::InvalidIpAddress(ip.to_string()))?;
280
281 let base_url = self
282 .get_rdap_url_for_ip(&ip_addr)
283 .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for {}", ip)))?;
284
285 let url = format!("{}ip/{}", ensure_trailing_slash(&base_url), ip);
286 debug!(url = %url, "Querying RDAP");
287
288 let response = self
289 .http
290 .get(&url)
291 .header("Accept", "application/rdap+json")
292 .send()
293 .await?;
294
295 if !response.status().is_success() {
296 return Err(SeerError::RdapError(format!(
297 "RDAP query failed with status {}",
298 response.status()
299 )));
300 }
301
302 let rdap: RdapResponse = response.json().await?;
303 Ok(rdap)
304 }
305
306 #[instrument(skip(self), fields(asn = %asn))]
307 pub async fn lookup_asn(&self, asn: u32) -> Result<RdapResponse> {
308 self.ensure_bootstrap().await?;
309
310 let base_url = self
311 .get_rdap_url_for_asn(asn)
312 .ok_or_else(|| SeerError::RdapBootstrapError(format!("No RDAP server for AS{}", asn)))?;
313
314 let url = format!("{}autnum/{}", ensure_trailing_slash(&base_url), asn);
315 debug!(url = %url, "Querying RDAP");
316
317 let response = self
318 .http
319 .get(&url)
320 .header("Accept", "application/rdap+json")
321 .send()
322 .await?;
323
324 if !response.status().is_success() {
325 return Err(SeerError::RdapError(format!(
326 "RDAP query failed with status {}",
327 response.status()
328 )));
329 }
330
331 let rdap: RdapResponse = response.json().await?;
332 Ok(rdap)
333 }
334}
335
/// Returns an owned copy of `url` that is guaranteed to end in `/`.
///
/// RDAP paths such as `domain/...` are appended straight after the base URL,
/// so the separator must be present exactly once.
fn ensure_trailing_slash(url: &str) -> String {
    let mut owned = String::with_capacity(url.len() + 1);
    owned.push_str(url);
    if !owned.ends_with('/') {
        owned.push('/');
    }
    owned
}
345
/// Parses a bootstrap ASN entry into an inclusive `(start, end)` pair.
///
/// Accepts either `"start-end"` (split at the first `-`) or a single number
/// `"n"`, which becomes `(n, n)`. Returns `None` if any part is not a `u32`.
fn parse_asn_range(range: &str) -> Option<(u32, u32)> {
    match range.split_once('-') {
        Some((low, high)) => {
            let start: u32 = low.parse().ok()?;
            let end: u32 = high.parse().ok()?;
            Some((start, end))
        }
        None => range.parse::<u32>().ok().map(|single| (single, single)),
    }
}
356
/// Returns `true` when the IPv4 address `octets` lies inside the CIDR `prefix`.
///
/// `prefix` is expected in the form `a.b.c.d/len`. Details:
/// - Abbreviated network parts such as `"41/8"` are zero-padded on the right.
/// - When `/len` is omitted, the prefix covers exactly the octets written
///   (8 bits per octet).
/// - The comparison is done bit by bit, so non-octet-aligned lengths such as
///   `/10` are honoured.
/// - Malformed prefixes never match: non-numeric octets, more than four
///   octets, or a length above 32.
fn ip_matches_prefix(prefix: &str, octets: &[u8; 4]) -> bool {
    let (network_part, len_part) = match prefix.split_once('/') {
        Some((network, len)) => (network.trim(), Some(len.trim())),
        None => (prefix.trim(), None),
    };

    // Every octet must parse; one bad octet rejects the whole prefix instead
    // of being silently skipped.
    let parsed: Option<Vec<u8>> = network_part
        .split('.')
        .map(|part| part.parse::<u8>().ok())
        .collect();
    let parsed = match parsed {
        Some(parsed) if parsed.len() <= 4 => parsed,
        _ => return false,
    };
    let mut network = [0u8; 4];
    network[..parsed.len()].copy_from_slice(&parsed);

    let len: u32 = match len_part {
        Some(len) => match len.parse::<u32>() {
            Ok(len) if len <= 32 => len,
            _ => return false,
        },
        // At most 4 octets, so this cannot overflow or truncate.
        None => 8 * parsed.len() as u32,
    };

    // A shift by 32 (len == 0) is out of range for u32; that case is the
    // empty mask, which matches everything.
    let mask = u32::MAX.checked_shl(32 - len).unwrap_or(0);
    ((u32::from_be_bytes(network) ^ u32::from_be_bytes(*octets)) & mask) == 0
}
383
/// Returns `true` when the IPv6 address `segments` lies inside the CIDR `prefix`.
///
/// Details:
/// - The comparison is done bit by bit, so lengths that are not multiples
///   of 16 (such as `/12` or `/23`) are honoured.
/// - A prefix without `/len` is treated as a single host (`/128`).
/// - Unparseable addresses and lengths above 128 never match.
fn ipv6_matches_prefix(prefix: &str, segments: &[u16; 8]) -> bool {
    let (network_part, len_part) = match prefix.split_once('/') {
        Some((network, len)) => (network.trim(), Some(len.trim())),
        None => (prefix.trim(), None),
    };

    let network = match network_part.parse::<std::net::Ipv6Addr>() {
        Ok(addr) => u128::from(addr),
        Err(_) => return false,
    };

    let len: u32 = match len_part {
        Some(len) => match len.parse::<u32>() {
            Ok(len) if len <= 128 => len,
            _ => return false,
        },
        None => 128,
    };

    let address = u128::from(std::net::Ipv6Addr::from(*segments));
    // A shift by 128 (len == 0) is out of range for u128; that case is the
    // empty mask, which matches everything.
    let mask = u128::MAX.checked_shl(128 - len).unwrap_or(0);
    ((network ^ address) & mask) == 0
}