1use hickory_client::client::{Client, SyncClient};
4use hickory_client::udp::UdpClientConnection;
5use hickory_server::authority::{Catalog, ZoneType};
6use hickory_server::proto::rr::rdata::{A, AAAA};
7use hickory_server::proto::rr::{DNSClass, LowerName, Name, RData, Record, RecordType};
8use hickory_server::resolver::config::NameServerConfigGroup;
9use hickory_server::server::ServerFuture;
10use hickory_server::store::in_memory::InMemoryAuthority;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
14use std::str::FromStr;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::net::{TcpListener, UdpSocket};
18use tokio::sync::RwLock;
19
/// Port the overlay DNS server listens on unless overridden
/// (the value [`DnsConfig::new`] starts from).
pub const DEFAULT_DNS_PORT: u16 = 15353;

/// Standard DNS port; applied to every upstream parsed from resolv.conf and
/// to the public fallback resolvers.
const STANDARD_DNS_PORT: u16 = 53;

/// Public resolvers used for overlay forwarding when no usable upstream can be
/// detected from the host (resolv.conf missing, empty, or loopback/stub-only).
const PUBLIC_FALLBACK_UPSTREAMS: [IpAddr; 2] = [
    IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)),
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
];

/// Host resolver configuration consulted for upstream auto-detection.
pub(crate) const RESOLV_CONF_PATH: &str = "/etc/resolv.conf";
42
/// Serializable configuration for the overlay DNS server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DnsConfig {
    /// Origin of the overlay zone served from memory (e.g. `"overlay.local."`).
    pub zone: String,
    /// Port for both the UDP and TCP listeners.
    pub port: u16,
    /// Address the listeners bind to.
    pub bind_addr: IpAddr,
    /// Explicit upstream resolvers for non-overlay names. A missing field
    /// deserializes as `None`; `None` or an empty list means upstreams are
    /// auto-detected from the host's resolv.conf instead.
    #[serde(default)]
    pub upstreams: Option<Vec<SocketAddr>>,
}
71
72impl DnsConfig {
73 #[must_use]
75 pub fn new(zone: &str, bind_addr: IpAddr) -> Self {
76 Self {
77 zone: zone.to_string(),
78 port: DEFAULT_DNS_PORT,
79 bind_addr,
80 upstreams: None,
81 }
82 }
83
84 #[must_use]
86 pub fn with_port(mut self, port: u16) -> Self {
87 self.port = port;
88 self
89 }
90
91 #[must_use]
97 pub fn with_upstreams(mut self, upstreams: Vec<SocketAddr>) -> Self {
98 self.upstreams = Some(upstreams);
99 self
100 }
101}
102
/// Returns `true` for addresses that can never work as a forwarding upstream
/// for overlay clients: loopback (which includes local stub resolvers such as
/// `127.0.0.53`) and the unspecified address.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are judged by their embedded
/// IPv4 address, so `::ffff:127.0.0.53` is rejected just like `127.0.0.53`.
/// `Ipv6Addr::is_loopback` only matches `::1` and would otherwise let them through.
fn is_unusable_upstream(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => v4.is_loopback() || v4.is_unspecified(),
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(mapped) => mapped.is_loopback() || mapped.is_unspecified(),
            None => v6.is_loopback() || v6.is_unspecified(),
        },
    }
}
121
122fn parse_resolv_conf(contents: &str) -> Vec<SocketAddr> {
133 let mut out: Vec<SocketAddr> = Vec::new();
134 for line in contents.lines() {
135 let line = line.trim();
136 if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
137 continue;
138 }
139 let mut parts = line.split_whitespace();
140 if parts.next() != Some("nameserver") {
141 continue;
142 }
143 let Some(addr_str) = parts.next() else {
144 continue;
145 };
146 let addr_str = addr_str.split('%').next().unwrap_or(addr_str);
149 let Ok(ip) = IpAddr::from_str(addr_str) else {
150 continue;
151 };
152 if is_unusable_upstream(ip) {
153 continue;
154 }
155 let sock = SocketAddr::new(ip, STANDARD_DNS_PORT);
156 if !out.contains(&sock) {
157 out.push(sock);
158 }
159 }
160 out
161}
162
163pub(crate) fn resolve_upstreams(config: &DnsConfig, resolv_conf_path: &str) -> Vec<SocketAddr> {
185 if let Some(explicit) = &config.upstreams {
186 if !explicit.is_empty() {
187 tracing::debug!(
188 count = explicit.len(),
189 "using explicit overlay DNS upstreams from config (host detection skipped)",
190 );
191 return explicit.clone();
192 }
193 }
194
195 let detected = match std::fs::read_to_string(resolv_conf_path) {
196 Ok(contents) => parse_resolv_conf(&contents),
197 Err(e) => {
198 tracing::warn!(
199 path = resolv_conf_path,
200 error = %e,
201 "could not read host resolv.conf for overlay DNS upstream detection",
202 );
203 Vec::new()
204 }
205 };
206
207 if detected.is_empty() {
208 let fallback: Vec<SocketAddr> = PUBLIC_FALLBACK_UPSTREAMS
209 .iter()
210 .map(|ip| SocketAddr::new(*ip, STANDARD_DNS_PORT))
211 .collect();
212 tracing::warn!(
213 fallback = ?fallback,
214 "no usable host DNS upstreams found (resolv.conf empty, missing, or stub-only); \
215 falling back to public resolvers for overlay forwarding",
216 );
217 fallback
218 } else {
219 tracing::info!(
220 upstreams = ?detected,
221 "overlay DNS forwarding to host upstreams (loopback/stub filtered out)",
222 );
223 detected
224 }
225}
226
227pub(crate) fn build_forward_resolver(
249 upstreams: &[SocketAddr],
250) -> Result<hickory_server::resolver::TokioAsyncResolver, DnsError> {
251 use hickory_server::resolver::config::{ResolverConfig, ResolverOpts};
252
253 if upstreams.is_empty() {
254 return Err(DnsError::Server("no upstreams for forward resolver".into()));
255 }
256
257 let mut group = NameServerConfigGroup::new();
258 let mut by_port: std::collections::BTreeMap<u16, Vec<IpAddr>> =
259 std::collections::BTreeMap::new();
260 for addr in upstreams {
261 by_port.entry(addr.port()).or_default().push(addr.ip());
262 }
263 for (port, ips) in by_port {
264 group.merge(NameServerConfigGroup::from_ips_clear(&ips, port, true));
267 }
268
269 let mut options = ResolverOpts::default();
273 options.timeout = Duration::from_secs(2);
274 options.attempts = 2;
275 options.preserve_intermediates = true;
277
278 let config = ResolverConfig::from_parts(None, vec![], group);
279 Ok(hickory_server::resolver::TokioAsyncResolver::tokio(
280 config, options,
281 ))
282}
283
/// Request handler that answers overlay-zone names from a [`Catalog`] and,
/// when a resolver is configured, forwards queries for other names upstream.
struct ForwardingCatalog {
    /// Catalog holding the overlay zone's in-memory authority.
    catalog: Catalog,
    /// Origin of the overlay zone; names at or below it are never forwarded.
    zone_origin: LowerName,
    /// Upstream resolver for non-overlay names; `None` disables forwarding so
    /// every request goes to `catalog`.
    resolver: Option<Arc<hickory_server::resolver::TokioAsyncResolver>>,
}
310
impl ForwardingCatalog {
    /// Builds a `NoError` response carrying the upstream `answers`.
    ///
    /// The header is derived from the request, sets the recursion-available
    /// flag, and is explicitly non-authoritative because the data came from an
    /// upstream resolver. Authority, SOA and additional sections are empty.
    fn forward_answer_response<'a>(
        request: &'a hickory_server::server::Request,
        answers: &'a [Record],
    ) -> hickory_server::authority::MessageResponse<
        'a,
        'a,
        std::slice::Iter<'a, Record>,
        std::iter::Empty<&'a Record>,
        std::iter::Empty<&'a Record>,
        std::iter::Empty<&'a Record>,
    > {
        use hickory_server::authority::MessageResponseBuilder;
        use hickory_server::proto::op::ResponseCode;

        let mut header = hickory_server::proto::op::Header::response_from_request(request.header());
        header.set_recursion_available(true);
        header.set_response_code(ResponseCode::NoError);
        header.set_authoritative(false);

        MessageResponseBuilder::from_message_request(request).build(
            header,
            answers.iter(),
            std::iter::empty(),
            std::iter::empty(),
            std::iter::empty(),
        )
    }

    /// Builds a record-less response with response code `code`, using the
    /// builder's `error_msg` helper.
    ///
    /// Unlike `forward_answer_response`, this does not explicitly set the
    /// recursion-available flag.
    fn forward_code_response(
        request: &hickory_server::server::Request,
        code: hickory_server::proto::op::ResponseCode,
    ) -> hickory_server::authority::MessageResponse<
        '_,
        '_,
        impl Iterator<Item = &Record> + Send,
        impl Iterator<Item = &Record> + Send,
        impl Iterator<Item = &Record> + Send,
        impl Iterator<Item = &Record> + Send,
    > {
        use hickory_server::authority::MessageResponseBuilder;
        MessageResponseBuilder::from_message_request(request).error_msg(request.header(), code)
    }

    /// Resolves the request's query through `resolver` and sends the result.
    ///
    /// Only the query name and type are passed to the resolver.
    /// NOTE(review): the query class is not forwarded — presumably fine for
    /// IN-only traffic; confirm.
    ///
    /// Outcome mapping:
    /// - success: `NoError` with the looked-up records as answers;
    /// - upstream NXDOMAIN: `NXDomain`;
    /// - upstream NODATA (no records, `NoError`): empty `NoError`;
    /// - anything else: `ServFail`, logged at debug level.
    async fn forward<R: hickory_server::server::ResponseHandler>(
        &self,
        resolver: &hickory_server::resolver::TokioAsyncResolver,
        request: &hickory_server::server::Request,
        mut response_handle: R,
    ) -> hickory_server::server::ResponseInfo {
        use hickory_server::proto::op::ResponseCode;
        use hickory_server::resolver::error::ResolveErrorKind;

        let query = request.request_info().query;
        let name = Name::from(query.name());
        let rtype = query.query_type();

        match resolver.lookup(name, rtype).await {
            Ok(lookup) => {
                // Copied into an owned Vec so the response can borrow it for
                // the duration of the send.
                let records: Vec<Record> = lookup.records().to_vec();
                let response = Self::forward_answer_response(request, &records);
                Self::send_or_servfail(&mut response_handle, response).await
            }
            Err(e) => {
                let code = match e.kind() {
                    // Relay a definitive "name does not exist" from upstream.
                    ResolveErrorKind::NoRecordsFound { response_code, .. }
                        if *response_code == ResponseCode::NXDomain =>
                    {
                        ResponseCode::NXDomain
                    }
                    // Name exists but has no records of this type (NODATA).
                    ResolveErrorKind::NoRecordsFound { response_code, .. }
                        if *response_code == ResponseCode::NoError =>
                    {
                        ResponseCode::NoError
                    }
                    // Timeouts, I/O errors and other upstream codes.
                    _ => {
                        tracing::debug!(error = %e, "overlay DNS upstream forward failed; SERVFAIL");
                        ResponseCode::ServFail
                    }
                };
                let response = Self::forward_code_response(request, code);
                Self::send_or_servfail(&mut response_handle, response).await
            }
        }
    }

    /// Sends `response` through `response_handle`.
    ///
    /// On a send failure, the error is logged and a `ResponseInfo` built from a
    /// fresh header with `ServFail` is returned. No further attempt is made to
    /// reach the client.
    async fn send_or_servfail<'a, R, A, N, S, D>(
        response_handle: &mut R,
        response: hickory_server::authority::MessageResponse<'_, 'a, A, N, S, D>,
    ) -> hickory_server::server::ResponseInfo
    where
        R: hickory_server::server::ResponseHandler,
        A: Iterator<Item = &'a Record> + Send + 'a,
        N: Iterator<Item = &'a Record> + Send + 'a,
        S: Iterator<Item = &'a Record> + Send + 'a,
        D: Iterator<Item = &'a Record> + Send + 'a,
    {
        match response_handle.send_response(response).await {
            Ok(info) => info,
            Err(e) => {
                tracing::error!(error = %e, "failed to send overlay DNS forward response");
                let mut header = hickory_server::proto::op::Header::new();
                header.set_response_code(hickory_server::proto::op::ResponseCode::ServFail);
                header.into()
            }
        }
    }
}
432
433#[async_trait::async_trait]
434impl hickory_server::server::RequestHandler for ForwardingCatalog {
435 async fn handle_request<R: hickory_server::server::ResponseHandler>(
436 &self,
437 request: &hickory_server::server::Request,
438 response_handle: R,
439 ) -> hickory_server::server::ResponseInfo {
440 let query_name = request.request_info().query.name().clone();
444 let is_overlay = self.zone_origin.zone_of(&query_name);
445
446 match (&self.resolver, is_overlay) {
447 (Some(resolver), false) => self.forward(resolver, request, response_handle).await,
448 _ => self.catalog.handle_request(request, response_handle).await,
449 }
450 }
451}
452
/// Derives a short, DNS-safe hostname for an overlay peer from its address.
///
/// IPv4 peers become `node-<third octet>-<fourth octet>` (decimal).
/// IPv6 peers become `node-<last segment>`, rendered as 4 zero-padded
/// lowercase hex digits. Uniqueness therefore only holds within a single /16
/// (IPv4) or /112 (IPv6).
#[must_use]
pub fn peer_hostname(ip: IpAddr) -> String {
    let suffix = match ip {
        IpAddr::V4(v4) => {
            let [_, _, third, fourth] = v4.octets();
            format!("{third}-{fourth}")
        }
        IpAddr::V6(v6) => format!("{:04x}", v6.segments()[7]),
    };
    format!("node-{suffix}")
}
471
/// Errors produced by the overlay DNS server, its handles, and the client.
#[derive(Debug, thiserror::Error)]
pub enum DnsError {
    /// A hostname or zone could not be parsed as a DNS name, or could not be
    /// joined with the zone origin.
    #[error("Invalid domain name: {0}")]
    InvalidName(String),

    /// Server-side failure, e.g. building the forward resolver without any
    /// upstreams, or the listener future failing.
    #[error("DNS server error: {0}")]
    Server(String),

    /// Failure while connecting or sending a query from `DnsClient`.
    #[error("DNS client error: {0}")]
    Client(String),

    /// Underlying I/O failure, such as binding a listener socket.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A requested record was not found.
    #[error("Record not found: {0}")]
    NotFound(String),
}
490
/// Cloneable handle for adding, removing and looking up records in the
/// overlay zone of a [`DnsServer`].
///
/// Clones share the same authority and serial counter (both behind `Arc`), so
/// changes made through any handle are visible to every listener.
#[derive(Clone)]
pub struct DnsHandle {
    /// In-memory zone shared with the server's listeners.
    authority: Arc<InMemoryAuthority>,
    /// Origin appended to relative hostnames.
    zone_origin: Name,
    /// Serial counter used to stamp record upserts.
    serial: Arc<RwLock<u32>>,
}
500
501impl DnsHandle {
502 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
510 let fqdn = if hostname.ends_with('.') {
512 Name::from_str(hostname)
513 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
514 } else {
515 let name = Name::from_str(hostname)
517 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
518 name.append_domain(&self.zone_origin)
519 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
520 };
521
522 let rdata = match ip {
524 IpAddr::V4(v4) => RData::A(A::from(v4)),
525 IpAddr::V6(v6) => RData::AAAA(AAAA::from(v6)),
526 };
527 let record = Record::from_rdata(fqdn, 300, rdata); let serial = {
531 let mut s = self.serial.write().await;
532 let current = *s;
533 *s = s.wrapping_add(1);
534 current
535 };
536
537 self.authority.upsert(record, serial).await;
539
540 Ok(())
541 }
542
543 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
551 let fqdn = if hostname.ends_with('.') {
552 Name::from_str(hostname)
553 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?
554 } else {
555 let name = Name::from_str(hostname)
556 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
557 name.append_domain(&self.zone_origin)
558 .map_err(|e| DnsError::InvalidName(format!("Failed to append zone: {e}")))?
559 };
560
561 let serial = {
562 let mut s = self.serial.write().await;
563 let current = *s;
564 *s = s.wrapping_add(1);
565 current
566 };
567
568 let a_record = Record::with(fqdn.clone(), RecordType::A, 0);
572 self.authority.upsert(a_record, serial).await;
573
574 let aaaa_record = Record::with(fqdn.clone(), RecordType::AAAA, 0);
575 self.authority.upsert(aaaa_record, serial).await;
576
577 Ok(true)
578 }
579
580 #[must_use]
582 pub fn zone_origin(&self) -> &Name {
583 &self.zone_origin
584 }
585
586 pub async fn lookup_a(&self, fqdn: &str) -> Option<IpAddr> {
598 use hickory_server::authority::{Authority, LookupOptions};
599
600 let name = Name::from_str(fqdn).ok()?;
601 let lower = LowerName::from(name);
602 let lookup = self
603 .authority
604 .lookup(&lower, RecordType::A, LookupOptions::default())
605 .await
606 .ok()?;
607 lookup.iter().find_map(|record| match record.data() {
608 Some(RData::A(a)) => Some(IpAddr::V4((*a).into())),
609 _ => None,
610 })
611 }
612}
613
/// DNS server that serves the overlay zone from memory and forwards queries
/// for other names to upstream resolvers.
pub struct DnsServer {
    /// Primary address for the UDP and TCP listeners.
    listen_addr: SocketAddr,
    /// In-memory zone shared with every [`DnsHandle`] and listener.
    authority: Arc<InMemoryAuthority>,
    /// Origin of the overlay zone.
    zone_origin: Name,
    /// Serial counter shared with handles and used to stamp record upserts.
    serial: Arc<RwLock<u32>>,
    /// Resolvers that non-overlay queries are forwarded to (empty disables forwarding).
    upstreams: Vec<SocketAddr>,
}
633
634impl DnsServer {
635 pub fn new(listen_addr: SocketAddr, zone: &str) -> Result<Self, DnsError> {
645 let upstreams =
646 resolve_upstreams(&DnsConfig::new(zone, listen_addr.ip()), RESOLV_CONF_PATH);
647 Self::new_with_upstreams(listen_addr, zone, upstreams)
648 }
649
650 pub fn new_with_upstreams(
660 listen_addr: SocketAddr,
661 zone: &str,
662 upstreams: Vec<SocketAddr>,
663 ) -> Result<Self, DnsError> {
664 let zone_origin =
665 Name::from_str(zone).map_err(|e| DnsError::InvalidName(format!("{zone}: {e}")))?;
666
667 let authority = Arc::new(InMemoryAuthority::empty(
670 zone_origin.clone(),
671 ZoneType::Primary,
672 false,
673 ));
674
675 Ok(Self {
676 listen_addr,
677 authority,
678 zone_origin,
679 serial: Arc::new(RwLock::new(1)),
680 upstreams,
681 })
682 }
683
684 pub fn from_config(config: &DnsConfig) -> Result<Self, DnsError> {
693 let listen_addr = SocketAddr::new(config.bind_addr, config.port);
694 let upstreams = resolve_upstreams(config, RESOLV_CONF_PATH);
695 Self::new_with_upstreams(listen_addr, &config.zone, upstreams)
696 }
697
698 #[must_use]
700 pub fn upstreams(&self) -> &[SocketAddr] {
701 &self.upstreams
702 }
703
704 fn build_catalog(
715 zone_origin: Name,
716 authority: Arc<InMemoryAuthority>,
717 upstreams: &[SocketAddr],
718 ) -> ForwardingCatalog {
719 let lower_origin = LowerName::from(zone_origin.clone());
720
721 let mut catalog = Catalog::new();
722 catalog.upsert(zone_origin.into(), Box::new(authority));
724
725 let resolver = if upstreams.is_empty() {
726 None
727 } else {
728 match build_forward_resolver(upstreams) {
729 Ok(r) => {
730 tracing::debug!(
731 upstreams = ?upstreams,
732 "overlay DNS forwarder ready for non-overlay queries",
733 );
734 Some(Arc::new(r))
735 }
736 Err(e) => {
737 tracing::error!(
738 error = %e,
739 "failed to build overlay DNS forwarder; non-overlay queries \
740 will be refused (overlay zone still served)",
741 );
742 None
743 }
744 }
745 };
746
747 ForwardingCatalog {
748 catalog,
749 zone_origin: lower_origin,
750 resolver,
751 }
752 }
753
754 #[must_use]
759 pub fn handle(&self) -> DnsHandle {
760 DnsHandle {
761 authority: Arc::clone(&self.authority),
762 zone_origin: self.zone_origin.clone(),
763 serial: Arc::clone(&self.serial),
764 }
765 }
766
767 pub async fn add_record(&self, hostname: &str, ip: IpAddr) -> Result<(), DnsError> {
775 self.handle().add_record(hostname, ip).await
776 }
777
778 pub async fn remove_record(&self, hostname: &str) -> Result<bool, DnsError> {
784 self.handle().remove_record(hostname).await
785 }
786
787 #[allow(clippy::unused_async)]
796 pub async fn start(self) -> Result<DnsHandle, DnsError> {
797 let handle = self.handle();
798 let listen_addr = self.listen_addr;
799 let zone_origin = self.zone_origin.clone();
800 let authority = Arc::clone(&self.authority);
801 let upstreams = self.upstreams.clone();
802
803 tokio::spawn(async move {
805 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority, upstreams).await {
806 tracing::error!("DNS server error: {}", e);
807 }
808 });
809
810 Ok(handle)
811 }
812
813 #[allow(clippy::unused_async)]
823 pub async fn start_background(&self) -> Result<DnsHandle, DnsError> {
824 let handle = self.handle();
825 let listen_addr = self.listen_addr;
826 let zone_origin = self.zone_origin.clone();
827 let authority = Arc::clone(&self.authority);
828 let upstreams = self.upstreams.clone();
829
830 tokio::spawn(async move {
831 if let Err(e) = Self::run_server(listen_addr, zone_origin, authority, upstreams).await {
832 tracing::error!("DNS server error: {}", e);
833 }
834 });
835
836 Ok(handle)
837 }
838
839 #[allow(clippy::unused_async)]
868 pub async fn bind_windows_fallback(&self, bind_ip: IpAddr) -> Result<DnsHandle, DnsError> {
869 self.bind_secondary(SocketAddr::new(bind_ip, 53)).await
870 }
871
872 #[allow(clippy::unused_async)]
887 pub async fn bind_secondary(&self, listen_addr: SocketAddr) -> Result<DnsHandle, DnsError> {
888 let handle = self.handle();
889 let zone_origin = self.zone_origin.clone();
890 let authority = Arc::clone(&self.authority);
891 let upstreams = self.upstreams.clone();
892
893 let udp_socket = UdpSocket::bind(listen_addr).await?;
897 let tcp_listener = TcpListener::bind(listen_addr).await?;
898
899 tokio::spawn(async move {
900 let catalog = Self::build_catalog(zone_origin, authority, &upstreams);
901 let mut server = ServerFuture::new(catalog);
902 server.register_socket(udp_socket);
903 server.register_listener(tcp_listener, Duration::from_secs(30));
904 tracing::info!(
905 addr = %listen_addr,
906 "secondary DNS listener started",
907 );
908 if let Err(e) = server.block_until_done().await {
909 tracing::error!("secondary DNS listener error: {}", e);
910 }
911 });
912
913 Ok(handle)
914 }
915
916 async fn run_server(
918 listen_addr: SocketAddr,
919 zone_origin: Name,
920 authority: Arc<InMemoryAuthority>,
921 upstreams: Vec<SocketAddr>,
922 ) -> Result<(), DnsError> {
923 let catalog = Self::build_catalog(zone_origin, authority, &upstreams);
926
927 let mut server = ServerFuture::new(catalog);
929
930 let udp_socket = UdpSocket::bind(listen_addr).await?;
932 server.register_socket(udp_socket);
933
934 let tcp_listener = TcpListener::bind(listen_addr).await?;
936 server.register_listener(tcp_listener, Duration::from_secs(30));
937
938 tracing::info!(addr = %listen_addr, "DNS server listening");
939
940 server
942 .block_until_done()
943 .await
944 .map_err(|e| DnsError::Server(e.to_string()))?;
945
946 Ok(())
947 }
948
949 #[must_use]
951 pub fn listen_addr(&self) -> SocketAddr {
952 self.listen_addr
953 }
954
955 #[must_use]
957 pub fn zone_origin(&self) -> &Name {
958 &self.zone_origin
959 }
960}
961
/// Minimal synchronous DNS client that sends queries over UDP to one server.
pub struct DnsClient {
    /// Server every query is sent to.
    server_addr: SocketAddr,
}
966
967impl DnsClient {
968 #[must_use]
970 pub fn new(server_addr: SocketAddr) -> Self {
971 Self { server_addr }
972 }
973
974 pub fn query_a(&self, hostname: &str) -> Result<Option<Ipv4Addr>, DnsError> {
980 let name = Name::from_str(hostname)
981 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
982
983 let conn = UdpClientConnection::new(self.server_addr)
984 .map_err(|e| DnsError::Client(e.to_string()))?;
985
986 let client = SyncClient::new(conn);
987
988 let response = client
989 .query(&name, DNSClass::IN, RecordType::A)
990 .map_err(|e| DnsError::Client(e.to_string()))?;
991
992 for answer in response.answers() {
994 if let Some(RData::A(a_record)) = answer.data() {
995 return Ok(Some((*a_record).into()));
996 }
997 }
998
999 Ok(None)
1000 }
1001
1002 pub fn query_aaaa(&self, hostname: &str) -> Result<Option<Ipv6Addr>, DnsError> {
1008 let name = Name::from_str(hostname)
1009 .map_err(|e| DnsError::InvalidName(format!("{hostname}: {e}")))?;
1010
1011 let conn = UdpClientConnection::new(self.server_addr)
1012 .map_err(|e| DnsError::Client(e.to_string()))?;
1013
1014 let client = SyncClient::new(conn);
1015
1016 let response = client
1017 .query(&name, DNSClass::IN, RecordType::AAAA)
1018 .map_err(|e| DnsError::Client(e.to_string()))?;
1019
1020 for answer in response.answers() {
1022 if let Some(RData::AAAA(aaaa_record)) = answer.data() {
1023 return Ok(Some((*aaaa_record).into()));
1024 }
1025 }
1026
1027 Ok(None)
1028 }
1029
1030 pub fn query_addr(&self, hostname: &str) -> Result<Option<IpAddr>, DnsError> {
1038 if let Ok(Some(v4)) = self.query_a(hostname) {
1040 return Ok(Some(IpAddr::V4(v4)));
1041 }
1042
1043 if let Ok(Some(v6)) = self.query_aaaa(hostname) {
1045 return Ok(Some(IpAddr::V6(v6)));
1046 }
1047
1048 Ok(None)
1049 }
1050}
1051
/// Service-name registry: a local name-to-address map that is consulted
/// before falling back to DNS queries against `dns_server`.
pub struct ServiceDiscovery {
    /// DNS server queried for names not present in `records`.
    dns_server: SocketAddr,
    /// Locally registered service addresses, keyed by service name.
    records: RwLock<HashMap<String, IpAddr>>,
}
1057
1058impl ServiceDiscovery {
1059 #[must_use]
1061 pub fn new(dns_server_addr: SocketAddr) -> Self {
1062 Self {
1063 dns_server: dns_server_addr,
1064 records: RwLock::new(HashMap::new()),
1065 }
1066 }
1067
1068 pub async fn register(&self, name: &str, ip: IpAddr) {
1070 let mut records = self.records.write().await;
1071 records.insert(name.to_string(), ip);
1072 }
1073
1074 pub async fn resolve(&self, name: &str) -> Option<IpAddr> {
1079 {
1081 let records = self.records.read().await;
1082 if let Some(ip) = records.get(name) {
1083 return Some(*ip);
1084 }
1085 }
1086
1087 let client = DnsClient::new(self.dns_server);
1089 if let Ok(Some(addr)) = client.query_addr(name) {
1090 return Some(addr);
1091 }
1092
1093 None
1094 }
1095
1096 pub async fn unregister(&self, name: &str) {
1098 let mut records = self.records.write().await;
1099 records.remove(name);
1100 }
1101
1102 pub async fn list_services(&self) -> Vec<String> {
1104 let records = self.records.read().await;
1105 records.keys().cloned().collect()
1106 }
1107
1108 pub fn dns_server(&self) -> SocketAddr {
1110 self.dns_server
1111 }
1112}
1113
1114#[cfg(test)]
1115mod tests {
1116 use super::*;
1117
1118 #[test]
1119 fn test_peer_hostname_v4() {
1120 assert_eq!(
1122 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1))),
1123 "node-0-1"
1124 );
1125 assert_eq!(
1126 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 0, 5))),
1127 "node-0-5"
1128 );
1129 assert_eq!(
1130 peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 200, 1, 100))),
1131 "node-1-100"
1132 );
1133 assert_eq!(
1134 peer_hostname(IpAddr::V4(Ipv4Addr::new(192, 168, 255, 254))),
1135 "node-255-254"
1136 );
1137 }
1138
1139 #[test]
1140 fn test_peer_hostname_v6() {
1141 assert_eq!(
1143 peer_hostname(IpAddr::V6("fd00::1".parse().unwrap())),
1144 "node-0001"
1145 );
1146 assert_eq!(
1147 peer_hostname(IpAddr::V6("fd00::abcd".parse().unwrap())),
1148 "node-abcd"
1149 );
1150 assert_eq!(
1151 peer_hostname(IpAddr::V6("fd00:200::ffff".parse().unwrap())),
1152 "node-ffff"
1153 );
1154 assert_eq!(
1156 peer_hostname(IpAddr::V6("fd00::1:0".parse().unwrap())),
1157 "node-0000"
1158 );
1159 }
1160
1161 #[test]
1162 fn test_dns_config() {
1163 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1164 assert_eq!(config.zone, "overlay.local.");
1165 assert_eq!(config.port, DEFAULT_DNS_PORT);
1166 assert_eq!(config.bind_addr, IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)));
1167
1168 let config = config.with_port(5353);
1170 assert_eq!(config.port, 5353);
1171 }
1172
1173 #[test]
1174 fn test_dns_config_serialization() {
1175 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 1)))
1176 .with_port(15353);
1177
1178 let json = serde_json::to_string(&config).unwrap();
1179 let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
1180
1181 assert_eq!(deserialized.zone, config.zone);
1182 assert_eq!(deserialized.port, config.port);
1183 assert_eq!(deserialized.bind_addr, config.bind_addr);
1184 }
1185
1186 #[tokio::test]
1187 async fn test_service_discovery_local_cache() {
1188 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1190 let discovery = ServiceDiscovery::new(addr);
1191
1192 let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
1193 discovery.register("test-service", ip).await;
1194
1195 let resolved = discovery.resolve("test-service").await;
1196 assert_eq!(resolved, Some(ip));
1197
1198 discovery.unregister("test-service").await;
1200 let services = discovery.list_services().await;
1201 assert!(services.is_empty());
1202 }
1203
1204 #[test]
1205 fn test_dns_server_creation() {
1206 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1207 let server = DnsServer::new(addr, "overlay.local.");
1208
1209 assert!(server.is_ok());
1210 let server = server.unwrap();
1211 assert_eq!(server.listen_addr(), addr);
1212 assert_eq!(server.zone_origin().to_string(), "overlay.local.");
1213 }
1214
1215 #[test]
1216 fn test_dns_server_from_config() {
1217 let config =
1218 DnsConfig::new("test.local.", IpAddr::V4(Ipv4Addr::LOCALHOST)).with_port(15353);
1219 let server = DnsServer::from_config(&config);
1220
1221 assert!(server.is_ok());
1222 let server = server.unwrap();
1223 assert_eq!(server.listen_addr().port(), 15353);
1224 assert_eq!(server.zone_origin().to_string(), "test.local.");
1225 }
1226
1227 #[test]
1228 fn test_dns_server_invalid_zone() {
1229 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1230 let server = DnsServer::new(addr, "overlay.local.");
1232 assert!(server.is_ok());
1233 }
1234
1235 #[tokio::test]
1236 async fn test_dns_server_add_record() {
1237 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1238 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1239
1240 let result = server
1241 .add_record("myservice", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)))
1242 .await;
1243 assert!(result.is_ok());
1244 }
1245
1246 #[tokio::test]
1247 async fn test_dns_handle_add_record() {
1248 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1249 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1250
1251 let handle = server.handle();
1253
1254 let result = handle
1255 .add_record("service1", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))
1256 .await;
1257 assert!(result.is_ok());
1258
1259 let result = handle
1260 .add_record("service2", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)))
1261 .await;
1262 assert!(result.is_ok());
1263
1264 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1266 }
1267
1268 #[test]
1269 fn test_dns_client_creation() {
1270 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);
1271 let client = DnsClient::new(addr);
1272 assert_eq!(client.server_addr, addr);
1273 }
1274
1275 #[tokio::test]
1276 async fn test_dns_handle_add_aaaa_record() {
1277 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1278 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1279 let handle = server.handle();
1280
1281 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1283 let result = handle.add_record("service-v6", ipv6).await;
1284 assert!(result.is_ok());
1285
1286 let ipv6_2: IpAddr = "fd00::abcd".parse().unwrap();
1288 let result = handle.add_record("service-v6-2", ipv6_2).await;
1289 assert!(result.is_ok());
1290 }
1291
1292 #[tokio::test]
1293 async fn test_dns_server_add_aaaa_record() {
1294 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1295 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1296
1297 let ipv6: IpAddr = "fd00::42".parse().unwrap();
1299 let result = server.add_record("myservice-v6", ipv6).await;
1300 assert!(result.is_ok());
1301 }
1302
1303 #[tokio::test]
1304 async fn test_dns_handle_remove_record_covers_both_types() {
1305 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1306 let server = DnsServer::new(addr, "overlay.local.").unwrap();
1307 let handle = server.handle();
1308
1309 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1311 handle.add_record("dual-service", ipv4).await.unwrap();
1312
1313 let removed = handle.remove_record("dual-service").await.unwrap();
1315 assert!(removed);
1316
1317 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1319 handle.add_record("v6-service", ipv6).await.unwrap();
1320
1321 let removed = handle.remove_record("v6-service").await.unwrap();
1323 assert!(removed);
1324 }
1325
1326 #[tokio::test]
1327 async fn test_service_discovery_local_cache_ipv6() {
1328 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1329 let discovery = ServiceDiscovery::new(addr);
1330
1331 let ipv6: IpAddr = "fd00::beef".parse().unwrap();
1333 discovery.register("v6-service", ipv6).await;
1334
1335 let resolved = discovery.resolve("v6-service").await;
1337 assert_eq!(resolved, Some(ipv6));
1338
1339 discovery.unregister("v6-service").await;
1341 let services = discovery.list_services().await;
1342 assert!(services.is_empty());
1343 }
1344
1345 #[tokio::test]
1346 async fn test_service_discovery_mixed_v4_v6_cache() {
1347 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 15353);
1348 let discovery = ServiceDiscovery::new(addr);
1349
1350 let ipv4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
1351 let ipv6: IpAddr = "fd00::1".parse().unwrap();
1352
1353 discovery.register("svc-v4", ipv4).await;
1354 discovery.register("svc-v6", ipv6).await;
1355
1356 assert_eq!(discovery.resolve("svc-v4").await, Some(ipv4));
1357 assert_eq!(discovery.resolve("svc-v6").await, Some(ipv6));
1358
1359 let mut services = discovery.list_services().await;
1360 services.sort();
1361 assert_eq!(services, vec!["svc-v4", "svc-v6"]);
1362 }
1363
1364 #[test]
1365 fn test_dns_config_with_ipv6_bind_addr() {
1366 let ipv6_bind: IpAddr = "fd00::1".parse().unwrap();
1367 let config = DnsConfig::new("overlay.local.", ipv6_bind);
1368 assert_eq!(config.bind_addr, ipv6_bind);
1369 assert_eq!(config.port, DEFAULT_DNS_PORT);
1370
1371 let json = serde_json::to_string(&config).unwrap();
1373 let deserialized: DnsConfig = serde_json::from_str(&json).unwrap();
1374 assert_eq!(deserialized.bind_addr, ipv6_bind);
1375 }
1376
1377 #[test]
1378 fn test_dns_server_creation_ipv6_bind() {
1379 let ipv6_addr: IpAddr = "::1".parse().unwrap();
1380 let addr = SocketAddr::new(ipv6_addr, 15353);
1381 let server = DnsServer::new(addr, "overlay.local.");
1382
1383 assert!(server.is_ok());
1384 let server = server.unwrap();
1385 assert_eq!(server.listen_addr(), addr);
1386 }
1387
1388 #[tokio::test]
1395 async fn test_bind_windows_fallback_errors_or_shares_authority() {
1396 let primary = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
1397 let server = DnsServer::new(primary, "overlay.local.").unwrap();
1398 let bind_ip: IpAddr = "127.0.0.2".parse().unwrap();
1399
1400 match server.bind_windows_fallback(bind_ip).await {
1401 Ok(handle) => {
1402 assert_eq!(handle.zone_origin().to_string(), "overlay.local.");
1406 handle
1407 .add_record("dual", IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)))
1408 .await
1409 .expect("add_record via fallback handle");
1410 }
1411 Err(DnsError::Io(_)) => {
1412 }
1416 Err(other) => panic!("unexpected error from bind_windows_fallback: {other}"),
1417 }
1418 }
1419
1420 #[test]
1421 fn test_peer_hostname_uniqueness() {
1422 let v4_a = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
1424 let v4_b = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)));
1425 assert_ne!(v4_a, v4_b);
1426
1427 let v6_a = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
1428 let v6_b = peer_hostname(IpAddr::V6("fd00::2".parse().unwrap()));
1429 assert_ne!(v6_a, v6_b);
1430
1431 let v4 = peer_hostname(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
1433 let v6 = peer_hostname(IpAddr::V6("fd00::1".parse().unwrap()));
1434 assert_ne!(v4, v6);
1435 }
1436
1437 #[test]
1440 fn test_parse_resolv_conf_filters_stub_and_loopback() {
1441 let contents = "\
1444 # generated by netbird\n\
1445 nameserver 127.0.0.53\n\
1446 nameserver 127.0.0.1\n\
1447 nameserver 192.168.1.1\n\
1448 search example.com\n\
1449 options edns0\n";
1450 let parsed = parse_resolv_conf(contents);
1451 assert_eq!(
1452 parsed,
1453 vec![SocketAddr::new(
1454 IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
1455 53
1456 )],
1457 "127.0.0.53 stub and 127.0.0.1 loopback must be filtered out",
1458 );
1459 }
1460
1461 #[test]
1462 fn test_parse_resolv_conf_dedup_and_comments() {
1463 let contents = "\
1464 ; a comment\n\
1465 nameserver 8.8.8.8\n\
1466 nameserver 8.8.8.8\n\
1467 nameserver fe80::1%eth0\n\
1468 nameserver 0.0.0.0\n";
1469 let parsed = parse_resolv_conf(contents);
1470 assert_eq!(parsed.len(), 2);
1473 assert_eq!(
1474 parsed[0],
1475 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)
1476 );
1477 assert_eq!(parsed[1].ip(), "fe80::1".parse::<IpAddr>().unwrap());
1478 }
1479
1480 #[test]
1481 fn test_resolve_upstreams_config_override_wins() {
1482 let explicit = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 9, 9, 9)), 5300);
1485 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST))
1486 .with_upstreams(vec![explicit]);
1487 let resolved = resolve_upstreams(&config, "/nonexistent/resolv.conf");
1488 assert_eq!(resolved, vec![explicit]);
1489 }
1490
1491 #[test]
1492 fn test_resolve_upstreams_falls_back_to_public_when_missing() {
1493 let config = DnsConfig::new("overlay.local.", IpAddr::V4(Ipv4Addr::LOCALHOST));
1495 let resolved = resolve_upstreams(&config, "/definitely/not/a/real/resolv.conf");
1496 assert_eq!(
1497 resolved,
1498 vec![
1499 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53),
1500 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
1501 ],
1502 );
1503 }
1504
1505 async fn spawn_stub_upstream(answer_ip: Ipv4Addr) -> SocketAddr {
1513 use hickory_server::proto::op::{Message, MessageType, ResponseCode};
1514
1515 let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1516 .await
1517 .expect("bind stub upstream");
1518 let addr = sock.local_addr().expect("stub local_addr");
1519
1520 tokio::spawn(async move {
1521 let mut buf = vec![0u8; 1500];
1522 loop {
1523 let Ok((len, from)) = sock.recv_from(&mut buf).await else {
1524 break;
1525 };
1526 let Ok(request) = Message::from_vec(&buf[..len]) else {
1527 continue;
1528 };
1529 let mut resp = Message::new();
1530 resp.set_id(request.id());
1531 resp.set_message_type(MessageType::Response);
1532 resp.set_recursion_available(true);
1533 resp.set_response_code(ResponseCode::NoError);
1534 for q in request.queries() {
1535 resp.add_query(q.clone());
1536 if q.query_type() == RecordType::A {
1537 let rec =
1538 Record::from_rdata(q.name().clone(), 60, RData::A(A::from(answer_ip)));
1539 resp.add_answer(rec);
1540 }
1541 }
1542 if let Ok(bytes) = resp.to_vec() {
1543 let _ = sock.send_to(&bytes, from).await;
1544 }
1545 }
1546 });
1547
1548 addr
1549 }
1550
1551 async fn raw_query_a(
1555 server: SocketAddr,
1556 name: &str,
1557 ) -> Result<Option<Ipv4Addr>, hickory_server::proto::op::ResponseCode> {
1558 use hickory_server::proto::op::{Message, MessageType, Query, ResponseCode};
1559
1560 let client = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1561 .await
1562 .expect("bind client");
1563
1564 let qname = Name::from_str(name).expect("query name");
1565 let mut msg = Message::new();
1566 msg.set_id(0x1234);
1567 msg.set_message_type(MessageType::Query);
1568 msg.set_recursion_desired(true);
1569 msg.add_query(Query::query(qname, RecordType::A));
1570 let bytes = msg.to_vec().expect("encode query");
1571
1572 client.send_to(&bytes, server).await.expect("send query");
1573
1574 let mut buf = vec![0u8; 1500];
1575 let len = tokio::time::timeout(Duration::from_secs(12), client.recv(&mut buf))
1580 .await
1581 .expect("query timed out")
1582 .expect("recv response");
1583 let resp = Message::from_vec(&buf[..len]).expect("decode response");
1584
1585 if resp.response_code() != ResponseCode::NoError {
1586 return Err(resp.response_code());
1587 }
1588 for ans in resp.answers() {
1589 if let Some(RData::A(a)) = ans.data() {
1590 return Ok(Some((*a).into()));
1591 }
1592 }
1593 Ok(None)
1594 }
1595
1596 #[tokio::test]
1597 async fn test_forwarding_overlay_answered_and_nonoverlay_forwarded() {
1598 let upstream_answer = Ipv4Addr::new(203, 0, 113, 7);
1600 let upstream = spawn_stub_upstream(upstream_answer).await;
1601
1602 let bound = {
1606 let probe = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1607 .await
1608 .unwrap();
1609 let a = probe.local_addr().unwrap();
1610 drop(probe);
1611 a
1612 };
1613
1614 let overlay_ip = Ipv4Addr::new(10, 200, 0, 5);
1617 let server =
1618 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![upstream]).unwrap();
1619 let handle = server.handle();
1620 handle
1621 .add_record("svc", IpAddr::V4(overlay_ip))
1622 .await
1623 .unwrap();
1624 let _running = server.start().await.unwrap();
1625
1626 tokio::time::sleep(Duration::from_millis(150)).await;
1628
1629 let overlay = raw_query_a(bound, "svc.overlay.local.")
1631 .await
1632 .expect("overlay query should not SERVFAIL");
1633 assert_eq!(
1634 overlay,
1635 Some(overlay_ip),
1636 "overlay name must be answered from InMemoryAuthority",
1637 );
1638
1639 let forwarded = raw_query_a(bound, "example.com.")
1641 .await
1642 .expect("forwarded query should not SERVFAIL");
1643 assert_eq!(
1644 forwarded,
1645 Some(upstream_answer),
1646 "non-overlay name must be forwarded to the upstream stub",
1647 );
1648 }
1649
1650 #[tokio::test]
1651 async fn test_forwarding_total_upstream_failure_is_servfail_not_panic() {
1652 use hickory_server::proto::op::ResponseCode;
1653
1654 let dead_upstream = {
1658 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1660 .await
1661 .unwrap();
1662 let a = s.local_addr().unwrap();
1663 drop(s);
1664 a
1665 };
1666
1667 let bound = {
1668 let s = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1669 .await
1670 .unwrap();
1671 let a = s.local_addr().unwrap();
1672 drop(s);
1673 a
1674 };
1675
1676 let server =
1677 DnsServer::new_with_upstreams(bound, "overlay.local.", vec![dead_upstream]).unwrap();
1678 let handle = server.handle();
1679 handle
1680 .add_record("svc", IpAddr::V4(Ipv4Addr::new(10, 200, 0, 9)))
1681 .await
1682 .unwrap();
1683 let _running = server.start().await.unwrap();
1684 tokio::time::sleep(Duration::from_millis(150)).await;
1685
1686 let overlay = raw_query_a(bound, "svc.overlay.local.")
1688 .await
1689 .expect("overlay query should still succeed");
1690 assert_eq!(overlay, Some(Ipv4Addr::new(10, 200, 0, 9)));
1691
1692 match raw_query_a(bound, "example.com.").await {
1695 Err(ResponseCode::ServFail) => {} Err(other) => panic!("expected SERVFAIL, got {other:?}"),
1697 Ok(answer) => panic!("expected SERVFAIL, got answer {answer:?}"),
1698 }
1699 }
1700}