rust_rcs_core/dns/
mod.rs

1// Copyright 2023 宋昊文
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15extern crate hickory_client;
16extern crate tokio;
17extern crate tokio_stream;
18
19use std::collections::HashMap;
20use std::net::{IpAddr, SocketAddr};
21use std::str::FromStr;
22use std::sync::{Arc, Mutex};
23
24use tokio::net::UdpSocket;
25use tokio::runtime::Runtime;
26use tokio::sync::mpsc;
27use tokio::time;
28use tokio::time::{Duration, Instant};
29
30use tokio_stream::wrappers::ReceiverStream;
31
32use hickory_client::client::{AsyncClient, ClientHandle as _};
33use hickory_client::rr::{DNSClass, Name, RData, RecordType};
34use hickory_client::udp::UdpClientStream;
35
36use crate::ffi::log::platform_log;
37use crate::util::raw_string::StrEq;
38
/// Tag attached to every line this module writes through `platform_log`.
const LOG_TAG: &str = "dns";

/// Capacity of the request queue feeding the dispatcher task created by `DnsClient::new`.
const REQUEST_BUFFER_SIZE: usize = 16;
42
/// Resolver settings supplied with each lookup request.
#[derive(Clone, Debug)]
pub struct DnsConfig {
    /// Name servers to query. A lookup tries them in order and stops at the
    /// first server that produced at least one usable answer.
    pub server_addrs: Vec<SocketAddr>,
}
54
55pub enum DnsRequest {
56    Default(DnsConfig, String, mpsc::Sender<IpAddr>),
57    SipNaptr(DnsConfig, String, String, mpsc::Sender<(String, u16)>),
58}
59
60pub struct DnsClient {
61    tx: mpsc::Sender<DnsRequest>,
62    cache_a_aaaa: Arc<Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>>,
63    cache_naptr_srv: Arc<Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>>,
64}
65
66impl DnsClient {
67    pub fn new(rt: Arc<Runtime>) -> DnsClient {
68        let (tx, mut rx) = mpsc::channel::<DnsRequest>(REQUEST_BUFFER_SIZE);
69
70        let cache_a_aaaa: Arc<Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>> =
71            Arc::new(Mutex::new(HashMap::new()));
72        let cache_a_aaaa_ = Arc::clone(&cache_a_aaaa);
73
74        let cache_naptr_srv: Arc<Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>> =
75            Arc::new(Mutex::new(HashMap::new()));
76        let cache_naptr_srv_ = Arc::clone(&cache_naptr_srv);
77
78        rt.spawn(async move {
79            let cache_a_aaaa = cache_a_aaaa_;
80            let cache_naptr_srv = cache_naptr_srv_;
81
82            'next: loop {
83                match rx.recv().await {
84                    Some(dr) => {
85                        match dr {
86                            DnsRequest::Default(config, host, tx) => {
87
88                                platform_log(LOG_TAG, "getting dns request");
89
90                                let now = Instant::now();
91                                let mut cached = Vec::new();
92
93                                {
94                                    let guard = cache_a_aaaa.lock().unwrap();
95
96                                    if let Some(cached_addresses) = guard.get(&host) {
97                                        for (expire, ip) in cached_addresses {
98                                            if expire > &now {
99                                                cached.push(*ip);
100                                            }
101                                        }
102                                    }
103                                }
104
105                                if cached.is_empty() {
106                                    let cache = Arc::clone(&cache_a_aaaa);
107
108                                    platform_log(LOG_TAG, "start dns");
109
110                                    tokio::spawn(async move {
111                                        for server_addr in config.server_addrs {
112                                            let stream = UdpClientStream::<UdpSocket>::new(server_addr);
113
114                                            let mut successful = false;
115
116                                            if let Ok((mut client, bg)) = AsyncClient::connect(stream).await
117                                            {
118                                                platform_log(LOG_TAG, "dns server connected");
119
120                                                tokio::spawn(async move {
121                                                    bg.await.unwrap();
122                                                });
123
124                                                if let Ok(name) = Name::from_str(&host) {
125                                                    platform_log(LOG_TAG, "start AAAA query");
126
127                                                    match time::timeout_at(
128                                                        Instant::now() + Duration::from_secs(15),
129                                                        client.query(name, DNSClass::IN, RecordType::AAAA),
130                                                    )
131                                                    .await
132                                                    {
133                                                        Ok(r) => {
134                                                            if let Ok(resp) = r {
135                                                                for r in resp.answers() {
136                                                                    if let Some(&RData::AAAA(addr)) =
137                                                                        r.data()
138                                                                    {
139                                                                        successful = true;
140
141                                                                        let addr = IpAddr::V6(addr.0);
142                                                                        let ttl = r.ttl();
143
144                                                                        if ttl > 0 {
145                                                                            update_a_aaaa_cache(
146                                                                                &cache, addr, ttl, &host,
147                                                                            );
148                                                                        }
149
150                                                                        match tx.send(addr).await {
151                                                                            Ok(()) => {}
152
153                                                                            Err(_) => {
154                                                                                return;
155                                                                            }
156                                                                        }
157                                                                    }
158                                                                }
159                                                            }
160                                                        }
161
162                                                        Err(_) => {
163                                                            platform_log(LOG_TAG, "dns timeout");
164                                                        }
165                                                    }
166                                                }
167
168                                                if let Ok(name) = Name::from_str(&host) {
169                                                    platform_log(LOG_TAG, "start A query");
170
171                                                    match time::timeout_at(
172                                                        Instant::now() + Duration::from_secs(15),
173                                                        client.query(name, DNSClass::IN, RecordType::A),
174                                                    )
175                                                    .await
176                                                    {
177                                                        Ok(r) => {
178                                                            if let Ok(resp) = r {
179                                                                for r in resp.answers() {
180                                                                    if let Some(&RData::A(addr)) = r.data()
181                                                                    {
182                                                                        successful = true;
183
184                                                                        let addr = IpAddr::V4(addr.0);
185                                                                        let ttl = r.ttl();
186
187                                                                        if ttl > 0 {
188                                                                            update_a_aaaa_cache(
189                                                                                &cache, addr, ttl, &host,
190                                                                            );
191                                                                        }
192
193                                                                        match tx.send(addr).await {
194                                                                            Ok(()) => {}
195
196                                                                            Err(_) => {
197                                                                                return;
198                                                                            }
199                                                                        }
200                                                                    }
201                                                                }
202                                                            }
203                                                        }
204
205                                                        Err(_) => {
206                                                            platform_log(LOG_TAG, "dns timeout");
207                                                        }
208                                                    }
209                                                }
210                                            }
211
212                                            if successful {
213                                                return;
214                                            }
215                                        }
216                                    });
217                                } else {
218                                    for addr in cached {
219                                        match tx.send(addr).await {
220                                            Ok(()) => {}
221
222                                            Err(_) => {
223                                                continue 'next;
224                                            }
225                                        }
226                                    }
227                                }
228                            }
229
230                            DnsRequest::SipNaptr(config, q_name, q_service_type, tx) => {
231
232                                platform_log(LOG_TAG, "getting dns request");
233
234                                let now = Instant::now();
235                                let mut cached = Vec::new();
236
237                                {
238                                    let guard = cache_naptr_srv.lock().unwrap();
239
240                                    if let Some(cached_addresses) = guard.get(&q_name) {
241                                        for (expire, service_type, target, port) in cached_addresses {
242                                            if expire > &now && q_service_type.eq(service_type) {
243                                                cached.push((String::from(target), *port));
244                                            }
245                                        }
246                                    }
247                                }
248
249                                if cached.is_empty() {
250
251                                    let cache = Arc::clone(&cache_naptr_srv);
252
253                                    tokio::spawn(async move {
254                                        for server_addr in config.server_addrs {
255                                            let stream = UdpClientStream::<UdpSocket>::new(server_addr);
256
257                                            let mut successful = false;
258
259                                            if let Ok((mut client, bg)) = AsyncClient::connect(stream).await
260                                            {
261                                                platform_log(LOG_TAG, "dns server connected");
262
263                                                tokio::spawn(async move {
264                                                    bg.await.unwrap();
265                                                });
266
267                                                if let Ok(name) = Name::from_str(&q_name) {
268                                                    platform_log(LOG_TAG, "start NAPTR query");
269
270                                                    match time::timeout_at(
271                                                        Instant::now() + Duration::from_secs(15),
272                                                        client.query(name, DNSClass::IN, RecordType::NAPTR),
273                                                    )
274                                                    .await
275                                                    {
276                                                        Ok(r) => {
277                                                            if let Ok(resp) = r {
278                                                                for r in resp.answers() {
279                                                                    if let Some(rd) =
280                                                                        r.data()
281                                                                    {
282                                                                        if let RData::NAPTR(ptr) = rd {
283
284                                                                            if ptr.services().equals_string(&q_service_type, false) {
285
286                                                                                let replacement = ptr.replacement().clone();
287
288                                                                                platform_log(LOG_TAG, format!("naptr replacement: {:?}", &replacement));
289
290                                                                                platform_log(LOG_TAG, "start SRV query");
291
292                                                                                match time::timeout_at(
293                                                                                    Instant::now() + Duration::from_secs(15),
294                                                                                    client.query(replacement, DNSClass::IN, RecordType::SRV),
295                                                                                )
296                                                                                .await {
297                                                                                    Ok(r) => {
298
299                                                                                        if let Ok(resp) = r {
300                                                                                            for r in resp.answers() {
301                                                                                                if let Some(rd) =
302                                                                                                    r.data()
303                                                                                                {
304                                                                                                    if let RData::SRV(srv) = rd {
305                                                                                                        successful = true;
306
307                                                                                                        let target = srv.target();
308                                                                                                        platform_log(LOG_TAG, format!("srv target: {:?}", target));
309                                                                                                        let target = target.to_string();
310                                                                                                        let target = if target.ends_with('.') {
311                                                                                                            String::from(&target[0..target.len() - 1])
312                                                                                                        } else {
313                                                                                                            target
314                                                                                                        };
315                                                                                                        let port = srv.port();
316
317                                                                                                        let ttl = r.ttl();
318
319                                                                                                        if ttl > 0 {
320                                                                                                            update_naptr_srv_cache(
321                                                                                                                &cache, &target, port, ttl, &q_name, &q_service_type,
322                                                                                                            );
323                                                                                                        }
324
325                                                                                                        match tx.send((target, port)).await {
326                                                                                                            Ok(()) => {}
327
328                                                                                                            Err(_) => {
329                                                                                                                return;
330                                                                                                            }
331                                                                                                        }
332                                                                                                    }
333                                                                                                }
334                                                                                            }
335                                                                                        }
336                                                                                    }
337
338                                                                                    Err(_) => {
339                                                                                        platform_log(LOG_TAG, "dns timeout");
340                                                                                    }
341                                                                                }
342                                                                            }
343                                                                        }
344                                                                    }
345                                                                }
346                                                            }
347                                                        }
348
349                                                        Err(_) => {
350                                                            platform_log(LOG_TAG, "dns timeout");
351                                                        }
352                                                    }
353                                                }
354                                            }
355
356                                            if successful {
357                                                return;
358                                            }
359                                        }
360                                    });
361
362                                } else {
363                                    for res in cached {
364                                        match tx.send(res).await {
365                                            Ok(()) => {}
366
367                                            Err(_) => {
368                                                continue 'next;
369                                            }
370                                        }
371                                    }
372                                }
373                            }
374
375                        }
376                    },
377                    None => break,
378                }
379            }
380        });
381
382        DnsClient {
383            tx,
384            cache_a_aaaa,
385            cache_naptr_srv,
386        }
387    }
388
389    pub async fn resolve(
390        &self,
391        dns_config: DnsConfig,
392        host: String,
393    ) -> Result<ReceiverStream<IpAddr>> {
394        let (tx, rx) = mpsc::channel::<IpAddr>(1);
395
396        match self
397            .tx
398            .send(DnsRequest::Default(dns_config, host, tx))
399            .await
400        {
401            Ok(()) => Ok(ReceiverStream::new(rx)),
402
403            Err(_) => Err(ErrorKind::BrokenPipe),
404        }
405    }
406
407    pub async fn resolve_service(
408        &self,
409        dns_config: DnsConfig,
410        domain: String,
411        service_name: String,
412    ) -> Result<ReceiverStream<(String, u16)>> {
413        let (tx, rx) = mpsc::channel::<(String, u16)>(1);
414
415        match self
416            .tx
417            .send(DnsRequest::SipNaptr(dns_config, domain, service_name, tx))
418            .await
419        {
420            Ok(()) => Ok(ReceiverStream::new(rx)),
421
422            Err(_) => Err(ErrorKind::BrokenPipe),
423        }
424    }
425
426    pub fn clear_cache(&self, name: String, rtype: RecordType) {
427        match rtype {
428            RecordType::A | RecordType::AAAA => {
429                let mut guard = self.cache_a_aaaa.lock().unwrap();
430
431                if let Some(cached_addresses) = guard.get_mut(&name) {
432                    let mut i = 0;
433                    while i < cached_addresses.len() {
434                        let (_, ip) = cached_addresses[i];
435                        match (ip, rtype) {
436                            (IpAddr::V4(_), RecordType::A) | (IpAddr::V6(_), RecordType::AAAA) => {
437                                cached_addresses.swap_remove(i);
438                            }
439
440                            _ => {
441                                i = i + 1;
442                            }
443                        }
444                    }
445                }
446            }
447
448            _ => {}
449        }
450    }
451
452    pub fn clear_naptr_srv_cache(&self, q_name: String, q_service_type: String) {
453        let mut guard = self.cache_naptr_srv.lock().unwrap();
454
455        if let Some(cached_addresses) = guard.get_mut(&q_name) {
456            let mut i = 0;
457            while i < cached_addresses.len() {
458                let (_, service_type, _, _) = &cached_addresses[i];
459                if q_service_type.eq(service_type) {
460                    cached_addresses.swap_remove(i);
461                } else {
462                    i = i + 1;
463                }
464            }
465        }
466    }
467}
468
/// Records `addr` as a resolved address of `host`, valid for `ttl` seconds.
///
/// If `addr` is already cached for `host`, its expiry is refreshed; otherwise
/// it is appended. This keeps every address of a multi-record answer instead
/// of keeping only the last one per IP family. Entries of `host` that have
/// already expired are pruned while the lock is held, which bounds the list.
fn update_a_aaaa_cache(
    cache: &Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>,
    addr: IpAddr,
    ttl: u32,
    host: &str,
) {
    let now = Instant::now();
    let expire = now + Duration::from_secs(ttl.into());

    let mut guard = cache.lock().unwrap();
    let entries = guard.entry(host.to_string()).or_default();

    entries.retain(|(expiry, _)| *expiry > now);

    match entries.iter_mut().find(|(_, ip)| *ip == addr) {
        Some(entry) => entry.0 = expire,
        None => entries.push((expire, addr)),
    }
}
499
/// Records that `r_target:r_port` serves `q_service_type` for the NAPTR query
/// name `q_name`, valid for `ttl` seconds.
///
/// An entry with the same service type, target and port has its expiry
/// refreshed; any other combination is appended. This lets several SRV
/// targets for one service coexist in the cache, rather than each answer
/// overwriting the previous one. Entries of `q_name` that have already
/// expired are pruned while the lock is held.
fn update_naptr_srv_cache(
    cache: &Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>,
    r_target: &str,
    r_port: u16,
    ttl: u32,
    q_name: &str,
    q_service_type: &str,
) {
    let now = Instant::now();
    let expire = now + Duration::from_secs(ttl.into());

    let mut guard = cache.lock().unwrap();
    let entries = guard.entry(q_name.to_string()).or_default();

    entries.retain(|(expiry, ..)| *expiry > now);

    match entries.iter_mut().find(|(_, service_type, target, port)| {
        service_type == q_service_type && target == r_target && *port == r_port
    }) {
        Some(entry) => entry.0 = expire,
        None => entries.push((
            expire,
            q_service_type.to_string(),
            r_target.to_string(),
            r_port,
        )),
    }
}
539
/// Failure reported by the `DnsClient` lookup methods.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// The request could not be queued because the dispatcher task's
    /// receiving end is gone.
    BrokenPipe,
}

/// Result type returned by the `DnsClient` lookup methods.
pub type Result<T> = std::result::Result<T, ErrorKind>;