Skip to main content

uni_stream/
addr.rs

//! Asynchronous conversion of values into `SocketAddr`s, with domain name resolution through a configurable DNS resolver.
2
3use std::future::{self, Future};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::pin::Pin;
6use std::sync::RwLock;
7use std::task::{ready, Context, Poll};
8
9use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
10use hickory_resolver::name_server::TokioConnectionProvider;
11use hickory_resolver::TokioResolver;
12use once_cell::sync::Lazy;
13use tokio::task::JoinHandle;
14
/// `Result` whose error type defaults to `std::io::Error`.
type Result<T, E = std::io::Error> = std::result::Result<T, E>;
/// An already-completed future holding a `Result<T>` (used by the non-resolving impls).
type ReadyFuture<T> = future::Ready<Result<T>>;
17
// Builds a `std::io::Error` of kind `InvalidInput` carrying `$msg`.
macro_rules! invalid_input {
    ($msg:expr) => {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, $msg)
    };
}

// Unwraps an `Option`; on `None`, returns early from the enclosing function with an
// `InvalidInput` error whose message is `$msg`.
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        match $call {
            Some(v) => v,
            // `.into()` keeps the `From` conversion the former `Err(..)?` form performed.
            None => return Err(invalid_input!($msg).into()),
        }
    };
}

// Unwraps a `Result`; on `Err(e)`, returns early from the enclosing function with an
// `InvalidInput` error whose message is `"<$msg>, detail:<e>"`.
macro_rules! try_ret {
    ($call:expr, $msg:expr) => {
        match $call {
            Ok(v) => v,
            Err(e) => return Err(invalid_input!(format!("{}, detail:{e}", $msg)).into()),
        }
    };
}
41
/// Converts or resolves, without blocking, to one or more `SocketAddr` values.
///
/// # DNS
///
/// The string-based implementations (`str`, `String`, `(&str, u16)`, `(String, u16)`) fall
/// back to a DNS lookup when the input is not a literal address. The DNS servers used for
/// that lookup can be changed via [`set_custom_dns_server`].
pub trait ToSocketAddrs {
    /// Iterator over the resolved `SocketAddr`s.
    type Iter: Iterator<Item = SocketAddr> + Send + 'static;
    /// Future that completes with the iterator, or with an I/O error.
    type Future: Future<Output = Result<Self::Iter>> + Send + 'static;

    /// Starts the conversion and returns a future yielding the addresses.
    fn to_socket_addrs(&self) -> Self::Future;
}
57
58impl ToSocketAddrs for SocketAddr {
59    type Future = ReadyFuture<Self::Iter>;
60    type Iter = std::option::IntoIter<SocketAddr>;
61
62    fn to_socket_addrs(&self) -> Self::Future {
63        let iter = Some(*self).into_iter();
64        future::ready(Ok(iter))
65    }
66}
67
68impl ToSocketAddrs for SocketAddrV4 {
69    type Future = ReadyFuture<Self::Iter>;
70    type Iter = std::option::IntoIter<SocketAddr>;
71
72    fn to_socket_addrs(&self) -> Self::Future {
73        SocketAddr::V4(*self).to_socket_addrs()
74    }
75}
76
77impl ToSocketAddrs for SocketAddrV6 {
78    type Future = ReadyFuture<Self::Iter>;
79    type Iter = std::option::IntoIter<SocketAddr>;
80
81    fn to_socket_addrs(&self) -> Self::Future {
82        SocketAddr::V6(*self).to_socket_addrs()
83    }
84}
85
86impl ToSocketAddrs for (IpAddr, u16) {
87    type Future = ReadyFuture<Self::Iter>;
88    type Iter = std::option::IntoIter<SocketAddr>;
89
90    fn to_socket_addrs(&self) -> Self::Future {
91        let iter = Some(SocketAddr::from(*self)).into_iter();
92        future::ready(Ok(iter))
93    }
94}
95
96impl ToSocketAddrs for (Ipv4Addr, u16) {
97    type Future = ReadyFuture<Self::Iter>;
98    type Iter = std::option::IntoIter<SocketAddr>;
99
100    fn to_socket_addrs(&self) -> Self::Future {
101        let (ip, port) = *self;
102        SocketAddrV4::new(ip, port).to_socket_addrs()
103    }
104}
105
106impl ToSocketAddrs for (Ipv6Addr, u16) {
107    type Future = ReadyFuture<Self::Iter>;
108    type Iter = std::option::IntoIter<SocketAddr>;
109
110    fn to_socket_addrs(&self) -> Self::Future {
111        let (ip, port) = *self;
112        SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
113    }
114}
115
116impl ToSocketAddrs for &[SocketAddr] {
117    type Future = ReadyFuture<Self::Iter>;
118    type Iter = std::vec::IntoIter<SocketAddr>;
119
120    fn to_socket_addrs(&self) -> Self::Future {
121        #[inline]
122        fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
123            addrs.to_vec()
124        }
125
126        // This uses a helper method because clippy doesn't like the `to_vec()`
127        // call here (it will allocate, whereas `self.iter().copied()` would
128        // not), but it's actually necessary in order to ensure that the
129        // returned iterator is valid for the `'static` lifetime, which the
130        // borrowed `slice::Iter` iterator would not be.
131        let iter = slice_to_vec(self).into_iter();
132        future::ready(Ok(iter))
133    }
134}
135
/// Represents one or more `SocketAddr`s: a string may be either a direct address or a domain
/// name that resolves to several addresses.
#[derive(Debug)]
pub enum OneOrMore {
    /// Direct address
    One(std::option::IntoIter<SocketAddr>),
    /// Addresses resolved by dns
    More(std::vec::IntoIter<SocketAddr>),
}

/// Internal state of [`MaybeReady`].
#[derive(Debug)]
enum State {
    /// The input parsed directly as an address; the value is taken on the first poll.
    Ready(Option<SocketAddr>),
    /// A lookup running on Tokio's blocking thread pool (via `spawn_blocking`).
    Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
}

/// Future returned by the string-based `ToSocketAddrs` impls. It completes immediately when
/// the input was a literal address; otherwise it waits for the blocking lookup task.
#[derive(Debug)]
pub struct MaybeReady(State);
154
155impl Future for MaybeReady {
156    type Output = Result<OneOrMore>;
157
158    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        match self.0 {
160            State::Ready(ref mut i) => {
161                let iter = OneOrMore::One(i.take().into_iter());
162                Poll::Ready(Ok(iter))
163            }
164            State::Blocking(ref mut rx) => {
165                let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
166
167                Poll::Ready(res)
168            }
169        }
170    }
171}
172
173impl Iterator for OneOrMore {
174    type Item = SocketAddr;
175
176    fn next(&mut self) -> Option<Self::Item> {
177        match self {
178            OneOrMore::One(i) => i.next(),
179            OneOrMore::More(i) => i.next(),
180        }
181    }
182
183    fn size_hint(&self) -> (usize, Option<usize>) {
184        match self {
185            OneOrMore::One(i) => i.size_hint(),
186            OneOrMore::More(i) => i.size_hint(),
187        }
188    }
189}
190
191// ===== impl &str =====
192
193impl ToSocketAddrs for str {
194    type Future = MaybeReady;
195    type Iter = OneOrMore;
196
197    fn to_socket_addrs(&self) -> Self::Future {
198        // First check if the input parses as a socket address
199        let res: Result<SocketAddr, _> = self.parse();
200        if let Ok(addr) = res {
201            return MaybeReady(State::Ready(Some(addr)));
202        }
203
204        // Run DNS lookup on the blocking pool
205        let s = self.to_owned();
206
207        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
208            // Customized dns resolvers are preferred, if a custom resolver does not exist then the
209            // standard library's
210            get_socket_addrs_inner(&s).map(|v| v.into_iter())
211        })))
212    }
213}
214
215/// Implement this trait for &T of type !Sized(such as str), since &T of type Sized all implement it
216/// by default.
217impl<T> ToSocketAddrs for &T
218where
219    T: ToSocketAddrs + ?Sized,
220{
221    type Future = T::Future;
222    type Iter = T::Iter;
223
224    fn to_socket_addrs(&self) -> Self::Future {
225        (**self).to_socket_addrs()
226    }
227}
228
229// ===== impl (&str,u16) =====
230
231impl ToSocketAddrs for (&str, u16) {
232    type Future = MaybeReady;
233    type Iter = OneOrMore;
234
235    fn to_socket_addrs(&self) -> Self::Future {
236        let (host, port) = *self;
237
238        // try to parse the host as a regular IP address first
239        if let Ok(addr) = host.parse::<Ipv4Addr>() {
240            let addr = SocketAddrV4::new(addr, port);
241            let addr = SocketAddr::V4(addr);
242
243            return MaybeReady(State::Ready(Some(addr)));
244        }
245
246        if let Ok(addr) = host.parse::<Ipv6Addr>() {
247            let addr = SocketAddrV6::new(addr, port, 0, 0);
248            let addr = SocketAddr::V6(addr);
249
250            return MaybeReady(State::Ready(Some(addr)));
251        }
252
253        let host = host.to_owned();
254
255        MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
256            get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
257        })))
258    }
259}
260
261// ===== impl (String,u16) =====
262
263impl ToSocketAddrs for (String, u16) {
264    type Future = MaybeReady;
265    type Iter = OneOrMore;
266
267    fn to_socket_addrs(&self) -> Self::Future {
268        (self.0.as_str(), self.1).to_socket_addrs()
269    }
270}
271
272// ===== impl String =====
273
274impl ToSocketAddrs for String {
275    type Future = <str as ToSocketAddrs>::Future;
276    type Iter = <str as ToSocketAddrs>::Iter;
277
278    fn to_socket_addrs(&self) -> Self::Future {
279        self[..].to_socket_addrs()
280    }
281}
282
/// Default DNS servers for the custom resolver; the initial contents of `DNS_SERVER_GROUP`.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), // alibaba
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)),
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)), // tencent
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),      // google
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)), // google
];

/// DNS servers the custom resolver is built from. It starts as `DEFAULT_DNS_SERVER_GROUP`,
/// is replaced by `set_custom_dns_server`, and is read whenever a resolver is created.
static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
    Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));

/// Port on which every configured DNS server is queried (the standard DNS port).
const DNS_QUERY_PORT: u16 = 53;
297
298#[inline]
299fn get_custom_resolver() -> Result<TokioResolver> {
300    let dns_group = try_ret!(DNS_SERVER_GROUP.read(), "read dns server");
301    let config = ResolverConfig::from_parts(
302        None,
303        vec![],
304        NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
305    );
306    let mut builder =
307        TokioResolver::builder_with_config(config, TokioConnectionProvider::default());
308    *builder.options_mut() = ResolverOpts::default();
309    Ok(builder.build())
310}
311
312/// Set up DNS servers, use `DEFAULT_DNS_SERVER_GROUP` by default
313/// Note: must be called before the first network connection to be effective
314#[inline]
315pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
316    let mut writer = DNS_SERVER_GROUP
317        .write()
318        .map_err(|e| invalid_input!(format!("get dns server writer, detail:{e}")))?;
319    let servers: &mut Vec<IpAddr> = writer.as_mut();
320    servers.clear();
321    dns_addrs.iter().for_each(|&a| servers.push(a));
322    Ok(())
323}
324
325/// Resolving domain to get `IpAddr`
326/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
327pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
328    let s = s.to_owned();
329    tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
330        .await
331        .map_err(|_| invalid_input!("get ip addrs"))?
332}
333
334/// Resolving domain to get `IpAddr`
335/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
336fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
337    thread_local! {
338        static RESOLVER:Option<TokioResolver> = {
339            match get_custom_resolver(){
340                Ok(v) => Some(v),
341                Err(e) => {
342                    tracing::error!("create resolver error:{e}");
343                    None
344                },
345            }
346        };
347    }
348    let resolver = RESOLVER.with(|r| r.clone());
349    let resolver = try_opt!(resolver, "custom resolver not exist");
350    let handle = tokio::runtime::Handle::try_current()
351        .map_err(|_| invalid_input!("tokio runtime not found"))?;
352    let lookup = handle
353        .block_on(resolver.lookup_ip(s))
354        .map_err(|e| invalid_input!(e))?;
355    Ok(lookup.into_iter().collect())
356}
357
358/// Resolving domain and port to get `SocketAddr`
359#[inline]
360pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
361    let s = s.to_owned();
362    tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
363        .await
364        .map_err(|_| invalid_input!("get socket addrs from host port"))?
365}
366
367/// Resolving domain and port to get `SocketAddr`
368/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
369#[inline]
370fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
371    match get_ip_addrs_inner(host) {
372        Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
373        // Resolve dns properly with the standard library
374        Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
375    }
376}
377
378/// Resolving `domain:port` forms,such as bilibili.com:1080
379#[inline]
380pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
381    let s = s.to_owned();
382    tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
383        .await
384        .map_err(|_| invalid_input!("get socket addrs"))?
385}
386
387/// Resolving `domain:port` forms,such as bilibili.com:1080
388/// Note: must run as async runtime,such as [`tokio::task::spawn_blocking`]
389#[inline]
390fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
391    let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
392    let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
393    get_socket_addrs_from_host_port_inner(host, port)
394}
395
396/// Look up all the socket addr's and pass in the method to get the result
397pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
398where
399    F: Fn(SocketAddr) -> R,
400    R: std::future::Future<Output = Result<T>>,
401{
402    let addrs = match addr.to_socket_addrs().await {
403        Ok(addrs) => addrs,
404        Err(e) => return Err(e),
405    };
406    let mut last_err = None;
407    for addr in addrs {
408        match f(addr).await {
409            Ok(l) => return Ok(l),
410            Err(e) => last_err = Some(e),
411        }
412    }
413    Err(last_err.unwrap_or_else(|| {
414        std::io::Error::new(
415            std::io::ErrorKind::InvalidInput,
416            "could not resolve to any addresses",
417        )
418    }))
419}