rust_rcs_core/dns/mod.rs
1// Copyright 2023 宋昊文
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15extern crate hickory_client;
16extern crate tokio;
17extern crate tokio_stream;
18
19use std::collections::HashMap;
20use std::net::{IpAddr, SocketAddr};
21use std::str::FromStr;
22use std::sync::{Arc, Mutex};
23
24use tokio::net::UdpSocket;
25use tokio::runtime::Runtime;
26use tokio::sync::mpsc;
27use tokio::time;
28use tokio::time::{Duration, Instant};
29
30use tokio_stream::wrappers::ReceiverStream;
31
32use hickory_client::client::{AsyncClient, ClientHandle as _};
33use hickory_client::rr::{DNSClass, Name, RData, RecordType};
34use hickory_client::udp::UdpClientStream;
35
36use crate::ffi::log::platform_log;
37use crate::util::raw_string::StrEq;
38
/// Tag passed to `platform_log` for every message emitted by this module.
const LOG_TAG: &str = "dns";

/// Capacity of the channel that carries requests from `DnsClient` to its dispatcher task.
const REQUEST_BUFFER_SIZE: usize = 16;
42
/// DNS servers to use for a single resolution request.
///
/// Lookups try `server_addrs` in order and stop at the first server that returns
/// matching records.
#[derive(Clone, Debug)]
pub struct DnsConfig {
    /// Addresses of the DNS servers; each is queried over UDP.
    pub server_addrs: Vec<SocketAddr>,
}
54
55pub enum DnsRequest {
56 Default(DnsConfig, String, mpsc::Sender<IpAddr>),
57 SipNaptr(DnsConfig, String, String, mpsc::Sender<(String, u16)>),
58}
59
60pub struct DnsClient {
61 tx: mpsc::Sender<DnsRequest>,
62 cache_a_aaaa: Arc<Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>>,
63 cache_naptr_srv: Arc<Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>>,
64}
65
66impl DnsClient {
67 pub fn new(rt: Arc<Runtime>) -> DnsClient {
68 let (tx, mut rx) = mpsc::channel::<DnsRequest>(REQUEST_BUFFER_SIZE);
69
70 let cache_a_aaaa: Arc<Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>> =
71 Arc::new(Mutex::new(HashMap::new()));
72 let cache_a_aaaa_ = Arc::clone(&cache_a_aaaa);
73
74 let cache_naptr_srv: Arc<Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>> =
75 Arc::new(Mutex::new(HashMap::new()));
76 let cache_naptr_srv_ = Arc::clone(&cache_naptr_srv);
77
78 rt.spawn(async move {
79 let cache_a_aaaa = cache_a_aaaa_;
80 let cache_naptr_srv = cache_naptr_srv_;
81
82 'next: loop {
83 match rx.recv().await {
84 Some(dr) => {
85 match dr {
86 DnsRequest::Default(config, host, tx) => {
87
88 platform_log(LOG_TAG, "getting dns request");
89
90 let now = Instant::now();
91 let mut cached = Vec::new();
92
93 {
94 let guard = cache_a_aaaa.lock().unwrap();
95
96 if let Some(cached_addresses) = guard.get(&host) {
97 for (expire, ip) in cached_addresses {
98 if expire > &now {
99 cached.push(*ip);
100 }
101 }
102 }
103 }
104
105 if cached.is_empty() {
106 let cache = Arc::clone(&cache_a_aaaa);
107
108 platform_log(LOG_TAG, "start dns");
109
110 tokio::spawn(async move {
111 for server_addr in config.server_addrs {
112 let stream = UdpClientStream::<UdpSocket>::new(server_addr);
113
114 let mut successful = false;
115
116 if let Ok((mut client, bg)) = AsyncClient::connect(stream).await
117 {
118 platform_log(LOG_TAG, "dns server connected");
119
120 tokio::spawn(async move {
121 bg.await.unwrap();
122 });
123
124 if let Ok(name) = Name::from_str(&host) {
125 platform_log(LOG_TAG, "start AAAA query");
126
127 match time::timeout_at(
128 Instant::now() + Duration::from_secs(15),
129 client.query(name, DNSClass::IN, RecordType::AAAA),
130 )
131 .await
132 {
133 Ok(r) => {
134 if let Ok(resp) = r {
135 for r in resp.answers() {
136 if let Some(&RData::AAAA(addr)) =
137 r.data()
138 {
139 successful = true;
140
141 let addr = IpAddr::V6(addr.0);
142 let ttl = r.ttl();
143
144 if ttl > 0 {
145 update_a_aaaa_cache(
146 &cache, addr, ttl, &host,
147 );
148 }
149
150 match tx.send(addr).await {
151 Ok(()) => {}
152
153 Err(_) => {
154 return;
155 }
156 }
157 }
158 }
159 }
160 }
161
162 Err(_) => {
163 platform_log(LOG_TAG, "dns timeout");
164 }
165 }
166 }
167
168 if let Ok(name) = Name::from_str(&host) {
169 platform_log(LOG_TAG, "start A query");
170
171 match time::timeout_at(
172 Instant::now() + Duration::from_secs(15),
173 client.query(name, DNSClass::IN, RecordType::A),
174 )
175 .await
176 {
177 Ok(r) => {
178 if let Ok(resp) = r {
179 for r in resp.answers() {
180 if let Some(&RData::A(addr)) = r.data()
181 {
182 successful = true;
183
184 let addr = IpAddr::V4(addr.0);
185 let ttl = r.ttl();
186
187 if ttl > 0 {
188 update_a_aaaa_cache(
189 &cache, addr, ttl, &host,
190 );
191 }
192
193 match tx.send(addr).await {
194 Ok(()) => {}
195
196 Err(_) => {
197 return;
198 }
199 }
200 }
201 }
202 }
203 }
204
205 Err(_) => {
206 platform_log(LOG_TAG, "dns timeout");
207 }
208 }
209 }
210 }
211
212 if successful {
213 return;
214 }
215 }
216 });
217 } else {
218 for addr in cached {
219 match tx.send(addr).await {
220 Ok(()) => {}
221
222 Err(_) => {
223 continue 'next;
224 }
225 }
226 }
227 }
228 }
229
230 DnsRequest::SipNaptr(config, q_name, q_service_type, tx) => {
231
232 platform_log(LOG_TAG, "getting dns request");
233
234 let now = Instant::now();
235 let mut cached = Vec::new();
236
237 {
238 let guard = cache_naptr_srv.lock().unwrap();
239
240 if let Some(cached_addresses) = guard.get(&q_name) {
241 for (expire, service_type, target, port) in cached_addresses {
242 if expire > &now && q_service_type.eq(service_type) {
243 cached.push((String::from(target), *port));
244 }
245 }
246 }
247 }
248
249 if cached.is_empty() {
250
251 let cache = Arc::clone(&cache_naptr_srv);
252
253 tokio::spawn(async move {
254 for server_addr in config.server_addrs {
255 let stream = UdpClientStream::<UdpSocket>::new(server_addr);
256
257 let mut successful = false;
258
259 if let Ok((mut client, bg)) = AsyncClient::connect(stream).await
260 {
261 platform_log(LOG_TAG, "dns server connected");
262
263 tokio::spawn(async move {
264 bg.await.unwrap();
265 });
266
267 if let Ok(name) = Name::from_str(&q_name) {
268 platform_log(LOG_TAG, "start NAPTR query");
269
270 match time::timeout_at(
271 Instant::now() + Duration::from_secs(15),
272 client.query(name, DNSClass::IN, RecordType::NAPTR),
273 )
274 .await
275 {
276 Ok(r) => {
277 if let Ok(resp) = r {
278 for r in resp.answers() {
279 if let Some(rd) =
280 r.data()
281 {
282 if let RData::NAPTR(ptr) = rd {
283
284 if ptr.services().equals_string(&q_service_type, false) {
285
286 let replacement = ptr.replacement().clone();
287
288 platform_log(LOG_TAG, format!("naptr replacement: {:?}", &replacement));
289
290 platform_log(LOG_TAG, "start SRV query");
291
292 match time::timeout_at(
293 Instant::now() + Duration::from_secs(15),
294 client.query(replacement, DNSClass::IN, RecordType::SRV),
295 )
296 .await {
297 Ok(r) => {
298
299 if let Ok(resp) = r {
300 for r in resp.answers() {
301 if let Some(rd) =
302 r.data()
303 {
304 if let RData::SRV(srv) = rd {
305 successful = true;
306
307 let target = srv.target();
308 platform_log(LOG_TAG, format!("srv target: {:?}", target));
309 let target = target.to_string();
310 let target = if target.ends_with('.') {
311 String::from(&target[0..target.len() - 1])
312 } else {
313 target
314 };
315 let port = srv.port();
316
317 let ttl = r.ttl();
318
319 if ttl > 0 {
320 update_naptr_srv_cache(
321 &cache, &target, port, ttl, &q_name, &q_service_type,
322 );
323 }
324
325 match tx.send((target, port)).await {
326 Ok(()) => {}
327
328 Err(_) => {
329 return;
330 }
331 }
332 }
333 }
334 }
335 }
336 }
337
338 Err(_) => {
339 platform_log(LOG_TAG, "dns timeout");
340 }
341 }
342 }
343 }
344 }
345 }
346 }
347 }
348
349 Err(_) => {
350 platform_log(LOG_TAG, "dns timeout");
351 }
352 }
353 }
354 }
355
356 if successful {
357 return;
358 }
359 }
360 });
361
362 } else {
363 for res in cached {
364 match tx.send(res).await {
365 Ok(()) => {}
366
367 Err(_) => {
368 continue 'next;
369 }
370 }
371 }
372 }
373 }
374
375 }
376 },
377 None => break,
378 }
379 }
380 });
381
382 DnsClient {
383 tx,
384 cache_a_aaaa,
385 cache_naptr_srv,
386 }
387 }
388
389 pub async fn resolve(
390 &self,
391 dns_config: DnsConfig,
392 host: String,
393 ) -> Result<ReceiverStream<IpAddr>> {
394 let (tx, rx) = mpsc::channel::<IpAddr>(1);
395
396 match self
397 .tx
398 .send(DnsRequest::Default(dns_config, host, tx))
399 .await
400 {
401 Ok(()) => Ok(ReceiverStream::new(rx)),
402
403 Err(_) => Err(ErrorKind::BrokenPipe),
404 }
405 }
406
407 pub async fn resolve_service(
408 &self,
409 dns_config: DnsConfig,
410 domain: String,
411 service_name: String,
412 ) -> Result<ReceiverStream<(String, u16)>> {
413 let (tx, rx) = mpsc::channel::<(String, u16)>(1);
414
415 match self
416 .tx
417 .send(DnsRequest::SipNaptr(dns_config, domain, service_name, tx))
418 .await
419 {
420 Ok(()) => Ok(ReceiverStream::new(rx)),
421
422 Err(_) => Err(ErrorKind::BrokenPipe),
423 }
424 }
425
426 pub fn clear_cache(&self, name: String, rtype: RecordType) {
427 match rtype {
428 RecordType::A | RecordType::AAAA => {
429 let mut guard = self.cache_a_aaaa.lock().unwrap();
430
431 if let Some(cached_addresses) = guard.get_mut(&name) {
432 let mut i = 0;
433 while i < cached_addresses.len() {
434 let (_, ip) = cached_addresses[i];
435 match (ip, rtype) {
436 (IpAddr::V4(_), RecordType::A) | (IpAddr::V6(_), RecordType::AAAA) => {
437 cached_addresses.swap_remove(i);
438 }
439
440 _ => {
441 i = i + 1;
442 }
443 }
444 }
445 }
446 }
447
448 _ => {}
449 }
450 }
451
452 pub fn clear_naptr_srv_cache(&self, q_name: String, q_service_type: String) {
453 let mut guard = self.cache_naptr_srv.lock().unwrap();
454
455 if let Some(cached_addresses) = guard.get_mut(&q_name) {
456 let mut i = 0;
457 while i < cached_addresses.len() {
458 let (_, service_type, _, _) = &cached_addresses[i];
459 if q_service_type.eq(service_type) {
460 cached_addresses.swap_remove(i);
461 } else {
462 i = i + 1;
463 }
464 }
465 }
466 }
467}
468
/// Records `addr` as a cached answer for `host`, valid for `ttl` seconds from now.
///
/// Expired entries for `host` are dropped first. If `addr` is already cached, its expiry is
/// refreshed in place; otherwise it is appended. Several addresses of the same family can
/// therefore be cached for one host, in the order they were first seen.
fn update_a_aaaa_cache(
    cache: &Arc<Mutex<HashMap<String, Vec<(Instant, IpAddr)>>>>,
    addr: IpAddr,
    ttl: u32,
    host: &str,
) {
    let now = Instant::now();
    let expire = now + Duration::from_secs(u64::from(ttl));

    let mut guard = cache.lock().unwrap();

    match guard.get_mut(host) {
        Some(entries) => {
            entries.retain(|&(expiry, _)| expiry > now);

            // Match the exact address (not just its family) so every record of a
            // multi-address answer stays cached.
            match entries.iter_mut().find(|(_, ip)| *ip == addr) {
                Some(entry) => entry.0 = expire,
                None => entries.push((expire, addr)),
            }
        }
        None => {
            guard.insert(host.to_owned(), vec![(expire, addr)]);
        }
    }
}
499
/// Records an SRV result `(r_target, r_port)` for the NAPTR query `q_name` and service type
/// `q_service_type`, valid for `ttl` seconds from now.
///
/// Expired entries for `q_name` are dropped first. An existing entry with the same service
/// type, target and port has its expiry refreshed; otherwise a new entry is appended. Every
/// SRV target of a multi-target answer therefore stays cached.
fn update_naptr_srv_cache(
    cache: &Arc<Mutex<HashMap<String, Vec<(Instant, String, String, u16)>>>>,
    r_target: &str,
    r_port: u16,
    ttl: u32,
    q_name: &str,
    q_service_type: &str,
) {
    let now = Instant::now();
    let expire = now + Duration::from_secs(u64::from(ttl));

    let mut guard = cache.lock().unwrap();

    match guard.get_mut(q_name) {
        Some(entries) => {
            entries.retain(|(expiry, _, _, _)| *expiry > now);

            let existing = entries.iter_mut().find(|(_, service_type, target, port)| {
                service_type == q_service_type && target == r_target && *port == r_port
            });

            match existing {
                Some(entry) => entry.0 = expire,
                None => entries.push((
                    expire,
                    q_service_type.to_owned(),
                    r_target.to_owned(),
                    r_port,
                )),
            }
        }
        None => {
            guard.insert(
                q_name.to_owned(),
                vec![(
                    expire,
                    q_service_type.to_owned(),
                    r_target.to_owned(),
                    r_port,
                )],
            );
        }
    }
}
539
/// Errors returned by this module's asynchronous resolution APIs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// The request could not be handed to the dispatcher task because its channel is closed.
    BrokenPipe,
}

impl std::fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ErrorKind::BrokenPipe => f.write_str("dns request channel is closed"),
        }
    }
}

impl std::error::Error for ErrorKind {}

/// Result type used by this module's asynchronous APIs.
pub type Result<T> = std::result::Result<T, ErrorKind>;