// uni_stream/addr.rs

//! Domain name resolution for turning host names and address literals into `SocketAddr`s.

use std::cell::RefCell;
use std::future::{self, Future};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::pin::Pin;
use std::sync::RwLock;
use std::task::{ready, Context, Poll};

use once_cell::sync::Lazy;
use tokio::task::JoinHandle;
use trust_dns_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
use trust_dns_resolver::Resolver;

/// Module-wide result type; the error defaults to `std::io::Error`.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
/// An already-completed future, used by the `ToSocketAddrs` impls that need no lookup.
type ReadyFuture<T> = std::future::Ready<Result<T, std::io::Error>>;

/// Builds a `std::io::Error` of kind `InvalidInput` carrying `$msg`
/// (anything accepted by `std::io::Error::new`).
macro_rules! invalid_input {
    ($msg:expr) => {
        ::std::io::Error::new(::std::io::ErrorKind::InvalidInput, $msg)
    };
}

/// Unwraps an `Option`. On `None`, returns early from the enclosing function with an
/// `InvalidInput` error built from `$msg` (converted through `?`).
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        $call.ok_or_else(|| invalid_input!($msg))?
    };
}

/// Unwraps a `Result`. On `Err(e)`, returns early from the enclosing function with an
/// `InvalidInput` error whose message is `$msg` followed by the original error's text.
macro_rules! try_ret {
    ($call:expr, $msg:expr) => {
        $call.map_err(|e| invalid_input!(format!("{} ,detail:{e}", $msg)))?
    };
}

/// Converts or resolves, without blocking the async runtime, to one or more `SocketAddr` values.
///
/// # DNS
///
/// The string-like implementations (`str`, `String`, `(&str, u16)`, `(String, u16)`) resolve
/// host names through this module's custom DNS servers, falling back to the standard library's
/// resolver. The servers can be changed with [`set_custom_dns_server`].
pub trait ToSocketAddrs {
    /// Iterator over the resolved addresses.
    type Iter: Iterator<Item = SocketAddr> + Send + 'static;
    /// Future that completes with [`Self::Iter`] or an I/O error.
    type Future: Future<Output = std::io::Result<Self::Iter>> + Send + 'static;

    /// Starts converting `self` into socket addresses; await the returned future for the result.
    fn to_socket_addrs(&self) -> Self::Future;
}

57impl ToSocketAddrs for SocketAddr {
58    type Future = ReadyFuture<Self::Iter>;
59    type Iter = std::option::IntoIter<SocketAddr>;
60
61    fn to_socket_addrs(&self) -> Self::Future {
62        let iter = Some(*self).into_iter();
63        future::ready(Ok(iter))
64    }
65}
66
67impl ToSocketAddrs for SocketAddrV4 {
68    type Future = ReadyFuture<Self::Iter>;
69    type Iter = std::option::IntoIter<SocketAddr>;
70
71    fn to_socket_addrs(&self) -> Self::Future {
72        SocketAddr::V4(*self).to_socket_addrs()
73    }
74}
75
76impl ToSocketAddrs for SocketAddrV6 {
77    type Future = ReadyFuture<Self::Iter>;
78    type Iter = std::option::IntoIter<SocketAddr>;
79
80    fn to_socket_addrs(&self) -> Self::Future {
81        SocketAddr::V6(*self).to_socket_addrs()
82    }
83}
84
85impl ToSocketAddrs for (IpAddr, u16) {
86    type Future = ReadyFuture<Self::Iter>;
87    type Iter = std::option::IntoIter<SocketAddr>;
88
89    fn to_socket_addrs(&self) -> Self::Future {
90        let iter = Some(SocketAddr::from(*self)).into_iter();
91        future::ready(Ok(iter))
92    }
93}
94
95impl ToSocketAddrs for (Ipv4Addr, u16) {
96    type Future = ReadyFuture<Self::Iter>;
97    type Iter = std::option::IntoIter<SocketAddr>;
98
99    fn to_socket_addrs(&self) -> Self::Future {
100        let (ip, port) = *self;
101        SocketAddrV4::new(ip, port).to_socket_addrs()
102    }
103}
104
105impl ToSocketAddrs for (Ipv6Addr, u16) {
106    type Future = ReadyFuture<Self::Iter>;
107    type Iter = std::option::IntoIter<SocketAddr>;
108
109    fn to_socket_addrs(&self) -> Self::Future {
110        let (ip, port) = *self;
111        SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
112    }
113}
114
115impl ToSocketAddrs for &[SocketAddr] {
116    type Future = ReadyFuture<Self::Iter>;
117    type Iter = std::vec::IntoIter<SocketAddr>;
118
119    fn to_socket_addrs(&self) -> Self::Future {
120        #[inline]
121        fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
122            addrs.to_vec()
123        }
124
125        // This uses a helper method because clippy doesn't like the `to_vec()`
126        // call here (it will allocate, whereas `self.iter().copied()` would
127        // not), but it's actually necessary in order to ensure that the
128        // returned iterator is valid for the `'static` lifetime, which the
129        // borrowed `slice::Iter` iterator would not be.
130        let iter = slice_to_vec(self).into_iter();
131        future::ready(Ok(iter))
132    }
133}
134
/// The addresses produced by a string-based [`ToSocketAddrs`] conversion.
///
/// A string may hold a literal address, which yields exactly one `SocketAddr` without a lookup,
/// or a domain name, which may resolve to several.
#[derive(Debug)]
pub enum OneOrMore {
    /// A single address parsed directly from the input; no lookup was performed.
    One(std::option::IntoIter<SocketAddr>),
    /// The addresses returned by a DNS lookup.
    More(std::vec::IntoIter<SocketAddr>),
}

144#[derive(Debug)]
145enum State {
146    Ready(Option<SocketAddr>),
147    Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
148}
149
150/// Implement Future to return asynchronous results
151#[derive(Debug)]
152pub struct MaybeReady(State);
153
154impl Future for MaybeReady {
155    type Output = Result<OneOrMore>;
156
157    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        match self.0 {
159            State::Ready(ref mut i) => {
160                let iter = OneOrMore::One(i.take().into_iter());
161                Poll::Ready(Ok(iter))
162            }
163            State::Blocking(ref mut rx) => {
164                let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
165
166                Poll::Ready(res)
167            }
168        }
169    }
170}
171
172impl Iterator for OneOrMore {
173    type Item = SocketAddr;
174
175    fn next(&mut self) -> Option<Self::Item> {
176        match self {
177            OneOrMore::One(i) => i.next(),
178            OneOrMore::More(i) => i.next(),
179        }
180    }
181
182    fn size_hint(&self) -> (usize, Option<usize>) {
183        match self {
184            OneOrMore::One(i) => i.size_hint(),
185            OneOrMore::More(i) => i.size_hint(),
186        }
187    }
188}
189
190// ===== impl &str =====
191
192impl ToSocketAddrs for str {
193    type Future = MaybeReady;
194    type Iter = OneOrMore;
195
196    fn to_socket_addrs(&self) -> Self::Future {
197        // First check if the input parses as a socket address
198        let res: Result<SocketAddr, _> = self.parse();
199        if let Ok(addr) = res {
200            return MaybeReady(State::Ready(Some(addr)));
201        }
202
203        // Run DNS lookup on the blocking pool
204        let s = self.to_owned();
205
206        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
207            // Customized dns resolvers are preferred, if a custom resolver does not exist then the
208            // standard library's
209            get_socket_addrs_inner(&s).map(|v| v.into_iter())
210        })))
211    }
212}
213
214/// Implement this trait for &T of type !Sized(such as str), since &T of type Sized all implement it
215/// by default.
216impl<T> ToSocketAddrs for &T
217where
218    T: ToSocketAddrs + ?Sized,
219{
220    type Future = T::Future;
221    type Iter = T::Iter;
222
223    fn to_socket_addrs(&self) -> Self::Future {
224        (**self).to_socket_addrs()
225    }
226}
227
228// ===== impl (&str,u16) =====
229
230impl ToSocketAddrs for (&str, u16) {
231    type Future = MaybeReady;
232    type Iter = OneOrMore;
233
234    fn to_socket_addrs(&self) -> Self::Future {
235        let (host, port) = *self;
236
237        // try to parse the host as a regular IP address first
238        if let Ok(addr) = host.parse::<Ipv4Addr>() {
239            let addr = SocketAddrV4::new(addr, port);
240            let addr = SocketAddr::V4(addr);
241
242            return MaybeReady(State::Ready(Some(addr)));
243        }
244
245        if let Ok(addr) = host.parse::<Ipv6Addr>() {
246            let addr = SocketAddrV6::new(addr, port, 0, 0);
247            let addr = SocketAddr::V6(addr);
248
249            return MaybeReady(State::Ready(Some(addr)));
250        }
251
252        let host = host.to_owned();
253
254        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
255            get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
256        })))
257    }
258}
259
260// ===== impl (String,u16) =====
261
262impl ToSocketAddrs for (String, u16) {
263    type Future = MaybeReady;
264    type Iter = OneOrMore;
265
266    fn to_socket_addrs(&self) -> Self::Future {
267        (self.0.as_str(), self.1).to_socket_addrs()
268    }
269}
270
271// ===== impl String =====
272
273impl ToSocketAddrs for String {
274    type Future = <str as ToSocketAddrs>::Future;
275    type Iter = <str as ToSocketAddrs>::Iter;
276
277    fn to_socket_addrs(&self) -> Self::Future {
278        self[..].to_socket_addrs()
279    }
280}
281
/// Built-in DNS servers, used until [`set_custom_dns_server`] replaces them.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    // Alibaba public DNS
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)),
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)),
    // Tencent public DNS
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)),
    // Google public DNS (IPv4, then IPv6)
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
];

291/// Customized dns resolution server
292static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
293    Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));
294
295const DNS_QUERY_PORT: u16 = 53;
296
297#[inline]
298fn get_custom_resolver() -> Result<Resolver> {
299    let dns_group = try_ret!(DNS_SERVER_GROUP.read(), "read dns server");
300    Resolver::new(
301        ResolverConfig::from_parts(
302            None,
303            vec![],
304            NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
305        ),
306        ResolverOpts::default(),
307    )
308    .map_err(|e| invalid_input!(format!("create custom resolver error:{e}")))
309}
310
311/// Set up DNS servers, use `DEFAULT_DNS_SERVER_GROUP` by default
312/// Note: must be called before the first network connection to be effective
313#[inline]
314pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
315    let mut writer = DNS_SERVER_GROUP
316        .write()
317        .map_err(|e| invalid_input!(format!("get dns server writer, detail:{e}")))?;
318    let servers: &mut Vec<IpAddr> = writer.as_mut();
319    servers.clear();
320    dns_addrs.iter().for_each(|&a| servers.push(a));
321    Ok(())
322}
323
324/// Resolving domain to get `IpAddr`
325/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
326pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
327    let s = s.to_owned();
328    tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
329        .await
330        .map_err(|_| invalid_input!("get ip addrs"))?
331}
332
333/// Resolving domain to get `IpAddr`
334/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
335fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
336    thread_local! {
337        static RESOLVER:Option<Resolver> = {
338            match get_custom_resolver(){
339                Ok(v) => Some(v),
340                Err(e) => {
341                    tracing::error!("create resolver error:{e}");
342                    None
343                },
344            }
345        };
346    }
347    let result = RESOLVER.with(|r| r.as_ref().map(|r| r.lookup_ip(s)));
348    try_opt!(result, "custom resolver not exist")
349        .map(|v| v.into_iter().collect())
350        .map_err(|e| invalid_input!(e))
351}
352
353/// Resolving domain and port to get `SocketAddr`
354#[inline]
355pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
356    let s = s.to_owned();
357    tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
358        .await
359        .map_err(|_| invalid_input!("get socket addrs from host port"))?
360}
361
362/// Resolving domain and port to get `SocketAddr`
363/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
364#[inline]
365fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
366    match get_ip_addrs_inner(host) {
367        Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
368        // Resolve dns properly with the standard library
369        Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
370    }
371}
372
373/// Resolving `domain:port` forms,such as bilibili.com:1080
374#[inline]
375pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
376    let s = s.to_owned();
377    tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
378        .await
379        .map_err(|_| invalid_input!("get socket addrs"))?
380}
381
382/// Resolving `domain:port` forms,such as bilibili.com:1080
383/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
384#[inline]
385fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
386    let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
387    let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
388    get_socket_addrs_from_host_port_inner(host, port)
389}
390
391/// Look up all the socket addr's and pass in the method to get the result
392pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
393where
394    F: Fn(SocketAddr) -> R,
395    R: std::future::Future<Output = Result<T>>,
396{
397    let addrs = match addr.to_socket_addrs().await {
398        Ok(addrs) => addrs,
399        Err(e) => return Err(e),
400    };
401    let mut last_err = None;
402    for addr in addrs {
403        match f(addr).await {
404            Ok(l) => return Ok(l),
405            Err(e) => last_err = Some(e),
406        }
407    }
408    Err(last_err.unwrap_or_else(|| {
409        std::io::Error::new(
410            std::io::ErrorKind::InvalidInput,
411            "could not resolve to any addresses",
412        )
413    }))
414}