1use hickory_resolver::config::{NameServerConfig, ResolverConfig};
2use hickory_resolver::net::runtime::TokioRuntimeProvider;
3use hickory_resolver::proto::rr::domain::Name;
4use hickory_resolver::proto::rr::rdata::{A, AAAA};
5use hickory_resolver::proto::rr::{RData, Record, RecordType};
6use hickory_resolver::TokioResolver;
7use std::collections::{BTreeMap, BTreeSet};
8use std::error::Error;
9use std::fmt;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::sync::Arc;
12
/// DNS settings consulted by [`resolve_dns`] and [`resolve_dns_records`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DnsConfig {
    /// Name servers forwarded to the resolver with every lookup request.
    /// When empty, [`HickoryDnsResolver`] initializes itself from the system
    /// configuration instead.
    pub name_servers: Vec<SocketAddr>,
    /// Static hostname-to-address mappings checked before the resolver is
    /// asked. Keys are looked up with the *normalized* hostname (trimmed,
    /// trailing dots removed, ASCII lower-cased), so they only match when
    /// stored in that form.
    pub overrides: BTreeMap<String, Vec<IpAddr>>,
}
18
/// Policy flag describing whether a DNS lookup is subject to permission checks.
///
/// NOTE(review): nothing in this file reads this enum; presumably callers use
/// it to decide whether to gate a lookup — confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsLookupPolicy {
    /// Permission checks apply to the lookup.
    CheckPermissions,
    /// Permission checks are bypassed for the lookup.
    SkipPermissions,
}
24
/// Input of [`DnsResolver::lookup_ip`]: the hostname to resolve together with
/// the name servers the resolver should query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsLookupRequest {
    hostname: String,
    name_servers: Vec<SocketAddr>,
}

impl DnsLookupRequest {
    /// Creates a request, taking ownership of the hostname and server list.
    pub fn new(hostname: impl Into<String>, name_servers: Vec<SocketAddr>) -> Self {
        let hostname = hostname.into();
        Self {
            hostname,
            name_servers,
        }
    }

    /// Hostname exactly as it was passed to [`DnsLookupRequest::new`].
    pub fn hostname(&self) -> &str {
        self.hostname.as_str()
    }

    /// Name servers to query; may be empty.
    pub fn name_servers(&self) -> &[SocketAddr] {
        self.name_servers.as_slice()
    }
}
47
48#[derive(Debug, Clone, PartialEq, Eq)]
49pub struct DnsRecordLookupRequest {
50 hostname: String,
51 name_servers: Vec<SocketAddr>,
52 record_type: RecordType,
53}
54
55impl DnsRecordLookupRequest {
56 pub fn new(
57 hostname: impl Into<String>,
58 name_servers: Vec<SocketAddr>,
59 record_type: RecordType,
60 ) -> Self {
61 Self {
62 hostname: hostname.into(),
63 name_servers,
64 record_type,
65 }
66 }
67
68 pub fn hostname(&self) -> &str {
69 &self.hostname
70 }
71
72 pub fn name_servers(&self) -> &[SocketAddr] {
73 &self.name_servers
74 }
75
76 pub const fn record_type(&self) -> RecordType {
77 self.record_type
78 }
79}
80
/// Where the answer of a resolution came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsResolutionSource {
    /// The input was itself an IP address literal.
    Literal,
    /// The answer was taken from `DnsConfig::overrides`.
    Override,
    /// The answer was produced by the configured `DnsResolver`.
    Resolver,
}

impl DnsResolutionSource {
    /// Lower-case label for the source: `"literal"`, `"override"` or `"resolver"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            DnsResolutionSource::Resolver => "resolver",
            DnsResolutionSource::Override => "override",
            DnsResolutionSource::Literal => "literal",
        }
    }
}
97
98#[derive(Debug, Clone, PartialEq, Eq)]
99pub struct DnsResolution {
100 hostname: String,
101 source: DnsResolutionSource,
102 addresses: Vec<IpAddr>,
103}
104
105impl DnsResolution {
106 pub fn new(
107 hostname: impl Into<String>,
108 source: DnsResolutionSource,
109 addresses: Vec<IpAddr>,
110 ) -> Self {
111 Self {
112 hostname: hostname.into(),
113 source,
114 addresses,
115 }
116 }
117
118 pub fn hostname(&self) -> &str {
119 &self.hostname
120 }
121
122 pub const fn source(&self) -> DnsResolutionSource {
123 self.source
124 }
125
126 pub fn addresses(&self) -> &[IpAddr] {
127 &self.addresses
128 }
129}
130
131#[derive(Debug, Clone, PartialEq, Eq)]
132pub struct DnsRecordResolution {
133 hostname: String,
134 source: DnsResolutionSource,
135 records: Vec<Record>,
136}
137
138impl DnsRecordResolution {
139 pub fn new(
140 hostname: impl Into<String>,
141 source: DnsResolutionSource,
142 records: Vec<Record>,
143 ) -> Self {
144 Self {
145 hostname: hostname.into(),
146 source,
147 records,
148 }
149 }
150
151 pub fn hostname(&self) -> &str {
152 &self.hostname
153 }
154
155 pub const fn source(&self) -> DnsResolutionSource {
156 self.source
157 }
158
159 pub fn records(&self) -> &[Record] {
160 &self.records
161 }
162}
163
/// Coarse classification of a [`DnsResolverError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsResolverErrorKind {
    /// The caller supplied an unusable hostname (e.g. empty or unparsable).
    InvalidInput,
    /// The lookup itself failed or produced no usable answer.
    LookupFailed,
}
169
170#[derive(Debug, Clone, PartialEq, Eq)]
171pub struct DnsResolverError {
172 kind: DnsResolverErrorKind,
173 message: String,
174}
175
176impl DnsResolverError {
177 pub fn invalid_input(message: impl Into<String>) -> Self {
178 Self {
179 kind: DnsResolverErrorKind::InvalidInput,
180 message: message.into(),
181 }
182 }
183
184 pub fn lookup_failed(message: impl Into<String>) -> Self {
185 Self {
186 kind: DnsResolverErrorKind::LookupFailed,
187 message: message.into(),
188 }
189 }
190
191 pub const fn kind(&self) -> DnsResolverErrorKind {
192 self.kind
193 }
194}
195
196impl fmt::Display for DnsResolverError {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 write!(f, "{}", self.message)
199 }
200}
201
202impl Error for DnsResolverError {}
203
/// Backend that performs the actual DNS queries on behalf of [`resolve_dns`]
/// and [`resolve_dns_records`].
pub trait DnsResolver {
    /// Resolves the request's hostname to IP addresses, querying the
    /// request's name servers.
    fn lookup_ip(&self, request: &DnsLookupRequest) -> Result<Vec<IpAddr>, DnsResolverError>;
    /// Resolves records of the request's record type for its hostname,
    /// querying the request's name servers.
    fn lookup_records(
        &self,
        request: &DnsRecordLookupRequest,
    ) -> Result<Vec<Record>, DnsResolverError>;
}
211
/// Reference-counted handle to a [`DnsResolver`] trait object that is `Send + Sync`.
pub type SharedDnsResolver = Arc<dyn DnsResolver + Send + Sync>;
213
214#[derive(Debug, Default)]
215pub struct HickoryDnsResolver;
216
217impl DnsResolver for HickoryDnsResolver {
218 fn lookup_ip(&self, request: &DnsLookupRequest) -> Result<Vec<IpAddr>, DnsResolverError> {
219 let resolver_config = resolver_config_from_name_servers(request.name_servers());
220 let hostname = request.hostname().to_owned();
221 std::thread::spawn(move || -> Result<Vec<IpAddr>, DnsResolverError> {
222 let runtime = tokio::runtime::Runtime::new().map_err(|error| {
223 DnsResolverError::lookup_failed(format!("failed to create DNS runtime: {error}"))
224 })?;
225
226 runtime.block_on(async move {
227 let builder = if let Some(config) = resolver_config {
228 TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
229 } else {
230 TokioResolver::builder_tokio().map_err(|error| {
231 DnsResolverError::lookup_failed(format!(
232 "failed to initialize DNS resolver from system configuration: {error}"
233 ))
234 })?
235 };
236
237 let resolver = builder.build().map_err(|error| {
238 DnsResolverError::lookup_failed(format!(
239 "failed to build DNS resolver: {error}"
240 ))
241 })?;
242 let lookup = resolver.lookup_ip(&hostname).await.map_err(|error| {
243 DnsResolverError::lookup_failed(format!(
244 "failed to resolve DNS address {hostname}: {error}"
245 ))
246 })?;
247
248 let mut addresses = Vec::new();
249 let mut seen = BTreeSet::new();
250 for ip in lookup.iter() {
251 if seen.insert(ip) {
252 addresses.push(ip);
253 }
254 }
255
256 if addresses.is_empty() {
257 return Err(DnsResolverError::lookup_failed(format!(
258 "failed to resolve DNS address {hostname}"
259 )));
260 }
261
262 Ok(addresses)
263 })
264 })
265 .join()
266 .map_err(|_| DnsResolverError::lookup_failed("dns resolver thread panicked"))?
267 }
268
269 fn lookup_records(
270 &self,
271 request: &DnsRecordLookupRequest,
272 ) -> Result<Vec<Record>, DnsResolverError> {
273 let resolver_config = resolver_config_from_name_servers(request.name_servers());
274 let hostname = request.hostname().to_owned();
275 let record_type = request.record_type();
276 std::thread::spawn(move || -> Result<Vec<Record>, DnsResolverError> {
277 let runtime = tokio::runtime::Runtime::new().map_err(|error| {
278 DnsResolverError::lookup_failed(format!("failed to create DNS runtime: {error}"))
279 })?;
280
281 runtime.block_on(async move {
282 let builder = if let Some(config) = resolver_config {
283 TokioResolver::builder_with_config(config, TokioRuntimeProvider::default())
284 } else {
285 TokioResolver::builder_tokio().map_err(|error| {
286 DnsResolverError::lookup_failed(format!(
287 "failed to initialize DNS resolver from system configuration: {error}"
288 ))
289 })?
290 };
291
292 let resolver = builder.build().map_err(|error| {
293 DnsResolverError::lookup_failed(format!(
294 "failed to build DNS resolver: {error}"
295 ))
296 })?;
297 let lookup = resolver
298 .lookup(&hostname, record_type)
299 .await
300 .map_err(|error| {
301 DnsResolverError::lookup_failed(format!(
302 "failed to resolve DNS {record_type} record {hostname}: {error}"
303 ))
304 })?;
305 let records = lookup.answers().to_vec();
306 if records.is_empty() {
307 return Err(DnsResolverError::lookup_failed(format!(
308 "failed to resolve DNS {record_type} record {hostname}"
309 )));
310 }
311 Ok(records)
312 })
313 })
314 .join()
315 .map_err(|_| DnsResolverError::lookup_failed("dns resolver thread panicked"))?
316 }
317}
318
319pub fn normalize_dns_hostname(hostname: &str) -> Result<String, DnsResolverError> {
320 let normalized = hostname.trim().trim_end_matches('.').to_ascii_lowercase();
321 if normalized.is_empty() {
322 return Err(DnsResolverError::invalid_input(
323 "DNS hostname must not be empty",
324 ));
325 }
326 Ok(normalized)
327}
328
329pub fn format_dns_resource(hostname: &str) -> Result<String, DnsResolverError> {
330 Ok(format!("dns://{}", canonical_dns_subject(hostname)?))
331}
332
333pub fn resolve_dns(
334 config: &DnsConfig,
335 resolver: &dyn DnsResolver,
336 hostname: &str,
337) -> Result<DnsResolution, DnsResolverError> {
338 let trimmed = hostname.trim();
339 if let Ok(ip_addr) = trimmed.parse::<IpAddr>() {
340 return Ok(DnsResolution::new(
341 ip_addr.to_string(),
342 DnsResolutionSource::Literal,
343 vec![ip_addr],
344 ));
345 }
346
347 let normalized_hostname = normalize_dns_hostname(trimmed)?;
348 if let Some(addresses) = config.overrides.get(&normalized_hostname) {
349 return Ok(DnsResolution::new(
350 normalized_hostname,
351 DnsResolutionSource::Override,
352 addresses.clone(),
353 ));
354 }
355
356 let request = DnsLookupRequest::new(normalized_hostname.clone(), config.name_servers.clone());
357 let addresses = resolver.lookup_ip(&request)?;
358 if addresses.is_empty() {
359 return Err(DnsResolverError::lookup_failed(format!(
360 "failed to resolve DNS address {normalized_hostname}"
361 )));
362 }
363
364 Ok(DnsResolution::new(
365 normalized_hostname,
366 DnsResolutionSource::Resolver,
367 dedupe_addresses(addresses),
368 ))
369}
370
371pub fn resolve_dns_records(
372 config: &DnsConfig,
373 resolver: &dyn DnsResolver,
374 hostname: &str,
375 record_type: RecordType,
376) -> Result<DnsRecordResolution, DnsResolverError> {
377 let trimmed = hostname.trim();
378 let normalized_hostname = normalize_dns_hostname(trimmed)?;
379 let owner_name = normalized_hostname.parse::<Name>().map_err(|error| {
380 DnsResolverError::invalid_input(format!("invalid DNS hostname: {error}"))
381 })?;
382
383 if let Some(records) = records_from_literal(trimmed, owner_name.clone(), record_type) {
384 return Ok(DnsRecordResolution::new(
385 normalized_hostname,
386 DnsResolutionSource::Literal,
387 records,
388 ));
389 }
390
391 if let Some(addresses) = config.overrides.get(&normalized_hostname) {
392 let records = records_from_addresses(owner_name.clone(), addresses, record_type);
393 if !records.is_empty() {
394 return Ok(DnsRecordResolution::new(
395 normalized_hostname,
396 DnsResolutionSource::Override,
397 records,
398 ));
399 }
400 }
401
402 let request = DnsRecordLookupRequest::new(
403 normalized_hostname.clone(),
404 config.name_servers.clone(),
405 record_type,
406 );
407 let records = resolver.lookup_records(&request)?;
408 if records.is_empty() {
409 return Err(DnsResolverError::lookup_failed(format!(
410 "failed to resolve DNS {record_type} record {normalized_hostname}"
411 )));
412 }
413
414 Ok(DnsRecordResolution::new(
415 normalized_hostname,
416 DnsResolutionSource::Resolver,
417 records,
418 ))
419}
420
421fn canonical_dns_subject(hostname: &str) -> Result<String, DnsResolverError> {
422 let trimmed = hostname.trim();
423 if let Ok(ip_addr) = trimmed.parse::<IpAddr>() {
424 return Ok(ip_addr.to_string());
425 }
426
427 normalize_dns_hostname(trimmed)
428}
429
430fn resolver_config_from_name_servers(name_servers: &[SocketAddr]) -> Option<ResolverConfig> {
431 if name_servers.is_empty() {
432 return None;
433 }
434
435 let name_servers = name_servers
436 .iter()
437 .map(|server| {
438 let mut config = NameServerConfig::udp_and_tcp(server.ip());
439 for connection in &mut config.connections {
440 connection.port = server.port();
441 connection.bind_addr = Some(SocketAddr::new(
442 if server.is_ipv6() {
443 IpAddr::V6(Ipv6Addr::UNSPECIFIED)
444 } else {
445 IpAddr::V4(Ipv4Addr::UNSPECIFIED)
446 },
447 0,
448 ));
449 }
450 config
451 })
452 .collect();
453
454 Some(ResolverConfig::from_parts(None, vec![], name_servers))
455}
456
/// Removes repeated addresses, keeping the first occurrence of each and the
/// original relative order.
fn dedupe_addresses(addresses: Vec<IpAddr>) -> Vec<IpAddr> {
    let mut already_seen = BTreeSet::new();
    addresses
        .into_iter()
        .filter(|address| already_seen.insert(*address))
        .collect()
}
467
468fn records_from_literal(
469 hostname: &str,
470 owner_name: Name,
471 record_type: RecordType,
472) -> Option<Vec<Record>> {
473 let ip_addr = hostname.parse::<IpAddr>().ok()?;
474 let records = records_from_addresses(owner_name, &[ip_addr], record_type);
475 if records.is_empty() {
476 return None;
477 }
478 Some(records)
479}
480
481fn records_from_addresses(
482 owner_name: Name,
483 addresses: &[IpAddr],
484 record_type: RecordType,
485) -> Vec<Record> {
486 addresses
487 .iter()
488 .filter_map(|ip| match (record_type, ip) {
489 (RecordType::A, IpAddr::V4(ipv4)) | (RecordType::ANY, IpAddr::V4(ipv4)) => Some(
490 Record::from_rdata(owner_name.clone(), 60, RData::A(A::from(*ipv4))),
491 ),
492 (RecordType::AAAA, IpAddr::V6(ipv6)) | (RecordType::ANY, IpAddr::V6(ipv6)) => Some(
493 Record::from_rdata(owner_name.clone(), 60, RData::AAAA(AAAA::from(*ipv6))),
494 ),
495 _ => None,
496 })
497 .collect()
498}