stratum_apps/network_helpers/
resolve_hostname.rs1use std::{
2 net::{IpAddr, SocketAddr},
3 time::Duration,
4};
5
6use tracing::{debug, info};
7
/// Upper bound on how long a single DNS lookup may run before it is abandoned.
const DNS_TIMEOUT: Duration = Duration::from_secs(5);

/// Errors returned by `resolve_host` / `resolve_host_port` when a name cannot be
/// turned into a [`SocketAddr`](std::net::SocketAddr).
#[derive(Debug)]
pub enum ResolveError {
    /// The lookup completed but produced no addresses; carries the string that was
    /// being resolved.
    NoResults(String),
    /// The resolver reported an I/O error.
    LookupFailed(std::io::Error),
    /// The lookup did not finish within `DNS_TIMEOUT`; carries the string that was
    /// being resolved.
    Timeout(String),
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ResolveError::NoResults(host) => {
                write!(f, "DNS resolution returned no results for '{host}'")
            }
            ResolveError::LookupFailed(e) => write!(f, "DNS resolution failed: {e}"),
            // `Duration`'s `Debug` output keeps sub-second precision ("5s", "1.5s",
            // "500ms"); `as_secs()` would truncate e.g. a 500ms timeout to "0s".
            ResolveError::Timeout(host) => write!(
                f,
                "DNS resolution for '{host}' timed out after {:?}",
                DNS_TIMEOUT
            ),
        }
    }
}

/// The `std::io::Error` inside `LookupFailed` is already rendered by `Display`, so
/// `source()` keeps its default (`None`); reporters that walk the source chain
/// therefore do not print that message twice.
impl std::error::Error for ResolveError {}
44
45pub async fn resolve_host(host: &str, port: u16) -> Result<SocketAddr, ResolveError> {
63 if let Ok(ip) = host.parse::<IpAddr>() {
65 return Ok(SocketAddr::new(ip, port));
66 }
67
68 info!("Resolving hostname '{host}' via DNS...");
70 let lookup = format!("{host}:{port}");
71 let addr = tokio::time::timeout(DNS_TIMEOUT, tokio::net::lookup_host(&lookup))
72 .await
73 .map_err(|_| ResolveError::Timeout(host.to_string()))?
74 .map_err(ResolveError::LookupFailed)?
75 .next()
77 .ok_or_else(|| ResolveError::NoResults(host.to_string()))?;
78
79 debug!("Resolved '{host}' -> {addr}");
80 Ok(addr)
81}
82
83pub async fn resolve_host_port(addr: &str) -> Result<SocketAddr, ResolveError> {
98 if let Ok(socket) = addr.parse::<SocketAddr>() {
100 return Ok(socket);
101 }
102
103 info!("Resolving address '{addr}' via DNS...");
105 let resolved = tokio::time::timeout(DNS_TIMEOUT, tokio::net::lookup_host(addr))
106 .await
107 .map_err(|_| ResolveError::Timeout(addr.to_string()))?
108 .map_err(ResolveError::LookupFailed)?
109 .next()
111 .ok_or_else(|| ResolveError::NoResults(addr.to_string()))?;
112
113 debug!("Resolved '{addr}' -> {resolved}");
114 Ok(resolved)
115}
116
// Tests for the resolver helpers. The IP-literal cases never touch the resolver;
// the `localhost` and `.invalid` cases go through the system resolver.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn resolve_ipv4_address() {
        let addr = resolve_host("127.0.0.1", 3333).await.unwrap();
        assert_eq!(addr, SocketAddr::new("127.0.0.1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_ipv6_address() {
        let addr = resolve_host("::1", 3333).await.unwrap();
        assert_eq!(addr, SocketAddr::new("::1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_localhost_hostname() {
        let addr = resolve_host("localhost", 3333).await.unwrap();
        assert_eq!(addr.port(), 3333);
        // `localhost` may resolve to 127.0.0.1 or ::1 depending on the host, so only
        // loopback-ness is asserted.
        assert!(addr.ip().is_loopback());
    }

    #[tokio::test]
    async fn resolve_invalid_hostname_fails() {
        let result = resolve_host("this.hostname.definitely.does.not.exist.invalid", 3333).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn resolve_host_port_ipv4() {
        let addr = resolve_host_port("127.0.0.1:3333").await.unwrap();
        assert_eq!(addr, SocketAddr::new("127.0.0.1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_host_port_ipv6() {
        // Bracketed IPv6 socket literals parse directly as `SocketAddr`.
        let addr = resolve_host_port("[::1]:3333").await.unwrap();
        assert_eq!(addr, SocketAddr::new("::1".parse().unwrap(), 3333));
    }

    #[tokio::test]
    async fn resolve_host_port_localhost() {
        let addr = resolve_host_port("localhost:3333").await.unwrap();
        assert_eq!(addr.port(), 3333);
        assert!(addr.ip().is_loopback());
    }

    #[tokio::test]
    async fn resolve_host_port_invalid_fails() {
        let result =
            resolve_host_port("this.hostname.definitely.does.not.exist.invalid:3333").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn resolve_host_port_missing_port_fails() {
        // String socket-address resolution requires a ":port" suffix, so this is
        // reported by the lookup as an I/O error.
        let result = resolve_host_port("localhost").await;
        assert!(matches!(result, Err(ResolveError::LookupFailed(_))));
    }

    #[test]
    fn timeout_error_mentions_target() {
        let msg = ResolveError::Timeout("pool.example:3333".to_string()).to_string();
        assert!(msg.contains("'pool.example:3333'"));
        assert!(msg.contains("timed out after"));
    }
}