1use std::future::{self, Future};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::pin::Pin;
6use std::sync::RwLock;
7use std::task::{ready, Context, Poll};
8
9use once_cell::sync::Lazy;
10use tokio::task::JoinHandle;
11use trust_dns_resolver::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts};
12use trust_dns_resolver::Resolver;
13
/// Module-wide result type; the error type defaults to [`std::io::Error`].
type Result<T, E = std::io::Error> = ::std::result::Result<T, E>;
/// A future that is already complete and yields a [`Result<T>`].
type ReadyFuture<T> = std::future::Ready<Result<T>>;
16
/// Builds a `std::io::Error` of kind `InvalidInput` whose payload is `$msg`.
macro_rules! invalid_input {
    ($msg:expr) => {
        ::std::io::Error::new(::std::io::ErrorKind::InvalidInput, $msg)
    };
}

/// Evaluates `$call` (an `Option`). On `Some(v)` it yields `v`; on `None` it
/// propagates an `InvalidInput` error carrying `$msg` through `?`, so it converts
/// into whatever error type the enclosing function returns.
macro_rules! try_opt {
    ($call:expr, $msg:expr) => {
        match $call {
            ::core::option::Option::None => {
                ::core::result::Result::Err(invalid_input!($msg))?
            }
            ::core::option::Option::Some(value) => value,
        }
    };
}

/// Evaluates `$call` (a `Result`). On `Ok(v)` it yields `v`; on `Err(e)` it
/// propagates an `InvalidInput` error whose message is `$msg` followed by
/// the `Display` form of `e`.
macro_rules! try_ret {
    ($call:expr, $msg:expr) => {
        match $call {
            ::core::result::Result::Err(e) => {
                ::core::result::Result::Err(invalid_input!(format!("{} ,detail:{e}", $msg)))?
            }
            ::core::result::Result::Ok(value) => value,
        }
    };
}
40
41pub trait ToSocketAddrs {
48 type Iter: Iterator<Item = SocketAddr> + Send + 'static;
50 type Future: Future<Output = Result<Self::Iter>> + Send + 'static;
52
53 fn to_socket_addrs(&self) -> Self::Future;
55}
56
57impl ToSocketAddrs for SocketAddr {
58 type Future = ReadyFuture<Self::Iter>;
59 type Iter = std::option::IntoIter<SocketAddr>;
60
61 fn to_socket_addrs(&self) -> Self::Future {
62 let iter = Some(*self).into_iter();
63 future::ready(Ok(iter))
64 }
65}
66
67impl ToSocketAddrs for SocketAddrV4 {
68 type Future = ReadyFuture<Self::Iter>;
69 type Iter = std::option::IntoIter<SocketAddr>;
70
71 fn to_socket_addrs(&self) -> Self::Future {
72 SocketAddr::V4(*self).to_socket_addrs()
73 }
74}
75
76impl ToSocketAddrs for SocketAddrV6 {
77 type Future = ReadyFuture<Self::Iter>;
78 type Iter = std::option::IntoIter<SocketAddr>;
79
80 fn to_socket_addrs(&self) -> Self::Future {
81 SocketAddr::V6(*self).to_socket_addrs()
82 }
83}
84
85impl ToSocketAddrs for (IpAddr, u16) {
86 type Future = ReadyFuture<Self::Iter>;
87 type Iter = std::option::IntoIter<SocketAddr>;
88
89 fn to_socket_addrs(&self) -> Self::Future {
90 let iter = Some(SocketAddr::from(*self)).into_iter();
91 future::ready(Ok(iter))
92 }
93}
94
95impl ToSocketAddrs for (Ipv4Addr, u16) {
96 type Future = ReadyFuture<Self::Iter>;
97 type Iter = std::option::IntoIter<SocketAddr>;
98
99 fn to_socket_addrs(&self) -> Self::Future {
100 let (ip, port) = *self;
101 SocketAddrV4::new(ip, port).to_socket_addrs()
102 }
103}
104
105impl ToSocketAddrs for (Ipv6Addr, u16) {
106 type Future = ReadyFuture<Self::Iter>;
107 type Iter = std::option::IntoIter<SocketAddr>;
108
109 fn to_socket_addrs(&self) -> Self::Future {
110 let (ip, port) = *self;
111 SocketAddrV6::new(ip, port, 0, 0).to_socket_addrs()
112 }
113}
114
115impl ToSocketAddrs for &[SocketAddr] {
116 type Future = ReadyFuture<Self::Iter>;
117 type Iter = std::vec::IntoIter<SocketAddr>;
118
119 fn to_socket_addrs(&self) -> Self::Future {
120 #[inline]
121 fn slice_to_vec(addrs: &[SocketAddr]) -> Vec<SocketAddr> {
122 addrs.to_vec()
123 }
124
125 let iter = slice_to_vec(self).into_iter();
131 future::ready(Ok(iter))
132 }
133}
134
/// Iterator returned by the `MaybeReady` future.
///
/// `One` is used when the input was already a literal address. `More` holds the
/// list produced by a blocking resolution.
#[derive(Debug)]
pub enum OneOrMore {
    /// At most one address, parsed directly from the input.
    One(::std::option::IntoIter<SocketAddr>),
    /// Addresses produced by a blocking lookup.
    More(::std::vec::IntoIter<SocketAddr>),
}
143
144#[derive(Debug)]
145enum State {
146 Ready(Option<SocketAddr>),
147 Blocking(JoinHandle<Result<std::vec::IntoIter<SocketAddr>>>),
148}
149
150#[derive(Debug)]
152pub struct MaybeReady(State);
153
154impl Future for MaybeReady {
155 type Output = Result<OneOrMore>;
156
157 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158 match self.0 {
159 State::Ready(ref mut i) => {
160 let iter = OneOrMore::One(i.take().into_iter());
161 Poll::Ready(Ok(iter))
162 }
163 State::Blocking(ref mut rx) => {
164 let res = ready!(Pin::new(rx).poll(cx))?.map(OneOrMore::More);
165
166 Poll::Ready(res)
167 }
168 }
169 }
170}
171
172impl Iterator for OneOrMore {
173 type Item = SocketAddr;
174
175 fn next(&mut self) -> Option<Self::Item> {
176 match self {
177 OneOrMore::One(i) => i.next(),
178 OneOrMore::More(i) => i.next(),
179 }
180 }
181
182 fn size_hint(&self) -> (usize, Option<usize>) {
183 match self {
184 OneOrMore::One(i) => i.size_hint(),
185 OneOrMore::More(i) => i.size_hint(),
186 }
187 }
188}
189
190impl ToSocketAddrs for str {
193 type Future = MaybeReady;
194 type Iter = OneOrMore;
195
196 fn to_socket_addrs(&self) -> Self::Future {
197 let res: Result<SocketAddr, _> = self.parse();
199 if let Ok(addr) = res {
200 return MaybeReady(State::Ready(Some(addr)));
201 }
202
203 let s = self.to_owned();
205
206 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
207 get_socket_addrs_inner(&s).map(|v| v.into_iter())
210 })))
211 }
212}
213
214impl<T> ToSocketAddrs for &T
217where
218 T: ToSocketAddrs + ?Sized,
219{
220 type Future = T::Future;
221 type Iter = T::Iter;
222
223 fn to_socket_addrs(&self) -> Self::Future {
224 (**self).to_socket_addrs()
225 }
226}
227
228impl ToSocketAddrs for (&str, u16) {
231 type Future = MaybeReady;
232 type Iter = OneOrMore;
233
234 fn to_socket_addrs(&self) -> Self::Future {
235 let (host, port) = *self;
236
237 if let Ok(addr) = host.parse::<Ipv4Addr>() {
239 let addr = SocketAddrV4::new(addr, port);
240 let addr = SocketAddr::V4(addr);
241
242 return MaybeReady(State::Ready(Some(addr)));
243 }
244
245 if let Ok(addr) = host.parse::<Ipv6Addr>() {
246 let addr = SocketAddrV6::new(addr, port, 0, 0);
247 let addr = SocketAddr::V6(addr);
248
249 return MaybeReady(State::Ready(Some(addr)));
250 }
251
252 let host = host.to_owned();
253
254 MaybeReady(State::Blocking(tokio::task::spawn_blocking(move || {
255 get_socket_addrs_from_host_port_inner(&host, port).map(|v| v.into_iter())
256 })))
257 }
258}
259
260impl ToSocketAddrs for (String, u16) {
263 type Future = MaybeReady;
264 type Iter = OneOrMore;
265
266 fn to_socket_addrs(&self) -> Self::Future {
267 (self.0.as_str(), self.1).to_socket_addrs()
268 }
269}
270
271impl ToSocketAddrs for String {
274 type Future = <str as ToSocketAddrs>::Future;
275 type Iter = <str as ToSocketAddrs>::Iter;
276
277 fn to_socket_addrs(&self) -> Self::Future {
278 self[..].to_socket_addrs()
279 }
280}
281
/// Built-in DNS servers; the initial contents of the runtime-configurable
/// server list.
///
/// NOTE(review): these look like well-known public resolvers (AliDNS, DNSPod,
/// Google); confirm that this choice is intentional.
const DEFAULT_DNS_SERVER_GROUP: &[IpAddr] = &[
    IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)),
    IpAddr::V4(Ipv4Addr::new(223, 6, 6, 6)),
    IpAddr::V4(Ipv4Addr::new(119, 29, 29, 29)),
    IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)),
    IpAddr::V6(Ipv6Addr::new(0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888)),
];
290
291static DNS_SERVER_GROUP: Lazy<RwLock<Vec<IpAddr>>> =
293 Lazy::new(|| RwLock::new(DEFAULT_DNS_SERVER_GROUP.to_vec()));
294
295const DNS_QUERY_PORT: u16 = 53;
296
297#[inline]
298fn get_custom_resolver() -> Result<Resolver> {
299 let dns_group = try_ret!(DNS_SERVER_GROUP.read(), "read dns server");
300 Resolver::new(
301 ResolverConfig::from_parts(
302 None,
303 vec![],
304 NameServerConfigGroup::from_ips_clear(&dns_group, DNS_QUERY_PORT, true),
305 ),
306 ResolverOpts::default(),
307 )
308 .map_err(|e| invalid_input!(format!("create custom resolver error:{e}")))
309}
310
311#[inline]
314pub fn set_custom_dns_server(dns_addrs: &[IpAddr]) -> Result<()> {
315 let mut writer = DNS_SERVER_GROUP
316 .write()
317 .map_err(|e| invalid_input!(format!("get dns server writer, detail:{e}")))?;
318 let servers: &mut Vec<IpAddr> = writer.as_mut();
319 servers.clear();
320 dns_addrs.iter().for_each(|&a| servers.push(a));
321 Ok(())
322}
323
324pub async fn get_ip_addrs(s: &str) -> Result<Vec<IpAddr>> {
327 let s = s.to_owned();
328 tokio::task::spawn_blocking(move || get_ip_addrs_inner(&s))
329 .await
330 .map_err(|_| invalid_input!("get ip addrs"))?
331}
332
333fn get_ip_addrs_inner(s: &str) -> Result<Vec<IpAddr>> {
336 thread_local! {
337 static RESOLVER:Option<Resolver> = {
338 match get_custom_resolver(){
339 Ok(v) => Some(v),
340 Err(e) => {
341 tracing::error!("create resolver error:{e}");
342 None
343 },
344 }
345 };
346 }
347 let result = RESOLVER.with(|r| r.as_ref().map(|r| r.lookup_ip(s)));
348 try_opt!(result, "custom resolver not exist")
349 .map(|v| v.into_iter().collect())
350 .map_err(|e| invalid_input!(e))
351}
352
353#[inline]
355pub async fn get_socket_addrs_from_host_port(s: &str, port: u16) -> Result<Vec<SocketAddr>> {
356 let s = s.to_owned();
357 tokio::task::spawn_blocking(move || get_socket_addrs_from_host_port_inner(&s, port))
358 .await
359 .map_err(|_| invalid_input!("get socket addrs from host port"))?
360}
361
362#[inline]
365fn get_socket_addrs_from_host_port_inner(host: &str, port: u16) -> Result<Vec<SocketAddr>> {
366 match get_ip_addrs_inner(host) {
367 Ok(r) => Ok(r.into_iter().map(|ip| SocketAddr::new(ip, port)).collect()),
368 Err(_) => std::net::ToSocketAddrs::to_socket_addrs(&(host, port)).map(|v| v.collect()),
370 }
371}
372
373#[inline]
375pub async fn get_socket_addrs(s: &str) -> Result<Vec<SocketAddr>> {
376 let s = s.to_owned();
377 tokio::task::spawn_blocking(move || get_socket_addrs_inner(&s))
378 .await
379 .map_err(|_| invalid_input!("get socket addrs"))?
380}
381
382#[inline]
385fn get_socket_addrs_inner(s: &str) -> Result<Vec<SocketAddr>> {
386 let (host, port_str) = try_opt!(s.rsplit_once(':'), "invalid socket address");
387 let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
388 get_socket_addrs_from_host_port_inner(host, port)
389}
390
391pub async fn each_addr<A: ToSocketAddrs, F, T, R>(addr: A, f: F) -> Result<T>
393where
394 F: Fn(SocketAddr) -> R,
395 R: std::future::Future<Output = Result<T>>,
396{
397 let addrs = match addr.to_socket_addrs().await {
398 Ok(addrs) => addrs,
399 Err(e) => return Err(e),
400 };
401 let mut last_err = None;
402 for addr in addrs {
403 match f(addr).await {
404 Ok(l) => return Ok(l),
405 Err(e) => last_err = Some(e),
406 }
407 }
408 Err(last_err.unwrap_or_else(|| {
409 std::io::Error::new(
410 std::io::ErrorKind::InvalidInput,
411 "could not resolve to any addresses",
412 )
413 }))
414}