1use std::fmt::{self, Display};
9use std::future::Future;
10use std::io;
11use std::net::SocketAddr;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::str::FromStr;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18use bytes::{Buf, Bytes, BytesMut};
19use futures_util::future::{FutureExt, TryFutureExt};
20use futures_util::ready;
21use futures_util::stream::Stream;
22use h2::client::{Connection, SendRequest};
23use http::header::{self, CONTENT_LENGTH};
24use rustls::ClientConfig;
25use tokio_rustls::{
26 client::TlsStream as TokioTlsClientStream, Connect as TokioTlsConnect, TlsConnector,
27};
28use tracing::{debug, warn};
29
30use crate::error::ProtoError;
31use crate::iocompat::AsyncIoStdAsTokio;
32use crate::op::Message;
33use crate::tcp::{Connect, DnsTcpStream};
34use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream};
35
/// ALPN protocol identifier for HTTP/2; added to the TLS client configuration
/// by the builder when the caller did not configure any ALPN protocols.
const ALPN_H2: &[u8] = b"h2";
37
38#[derive(Clone)]
40#[must_use = "futures do nothing unless polled"]
41pub struct HttpsClientStream {
42 name_server_name: Arc<str>,
44 name_server: SocketAddr,
45 h2: SendRequest<Bytes>,
46 is_shutdown: bool,
47}
48
49impl Display for HttpsClientStream {
50 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
51 write!(
52 formatter,
53 "HTTPS({},{})",
54 self.name_server, self.name_server_name
55 )
56 }
57}
58
59impl HttpsClientStream {
60 async fn inner_send(
61 h2: SendRequest<Bytes>,
62 message: Bytes,
63 name_server_name: Arc<str>,
64 ) -> Result<DnsResponse, ProtoError> {
65 let mut h2 = match h2.ready().await {
66 Ok(h2) => h2,
67 Err(err) => {
68 return Err(ProtoError::from(format!("h2 send_request error: {err}")));
70 }
71 };
72
73 let request = crate::https::request::new(&name_server_name, message.remaining());
75
76 let request =
77 request.map_err(|err| ProtoError::from(format!("bad http request: {err}")))?;
78
79 debug!("request: {:#?}", request);
80
81 let (response_future, mut send_stream) = h2
83 .send_request(request, false)
84 .map_err(|err| ProtoError::from(format!("h2 send_request error: {err}")))?;
85
86 send_stream
87 .send_data(message, true)
88 .map_err(|e| ProtoError::from(format!("h2 send_data error: {e}")))?;
89
90 let mut response_stream = response_future
91 .await
92 .map_err(|err| ProtoError::from(format!("received a stream error: {err}")))?;
93
94 debug!("got response: {:#?}", response_stream);
95
96 let content_length = response_stream
98 .headers()
99 .get(CONTENT_LENGTH)
100 .map(|v| v.to_str())
101 .transpose()
102 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?
103 .map(usize::from_str)
104 .transpose()
105 .map_err(|e| ProtoError::from(format!("bad headers received: {e}")))?;
106
107 let mut response_bytes =
111 BytesMut::with_capacity(content_length.unwrap_or(512).clamp(512, 4096));
112
113 while let Some(partial_bytes) = response_stream.body_mut().data().await {
114 let partial_bytes =
115 partial_bytes.map_err(|e| ProtoError::from(format!("bad http request: {e}")))?;
116
117 debug!("got bytes: {}", partial_bytes.len());
118 response_bytes.extend(partial_bytes);
119
120 if let Some(content_length) = content_length {
122 if response_bytes.len() >= content_length {
123 break;
124 }
125 }
126 }
127
128 if let Some(content_length) = content_length {
130 if response_bytes.len() != content_length {
131 return Err(ProtoError::from(format!(
133 "expected byte length: {}, got: {}",
134 content_length,
135 response_bytes.len()
136 )));
137 }
138 }
139
140 if !response_stream.status().is_success() {
142 let error_string = String::from_utf8_lossy(response_bytes.as_ref());
143
144 return Err(ProtoError::from(format!(
146 "http unsuccessful code: {}, message: {}",
147 response_stream.status(),
148 error_string
149 )));
150 } else {
151 {
153 let content_type = response_stream
155 .headers()
156 .get(header::CONTENT_TYPE)
157 .map(|h| {
158 h.to_str().map_err(|err| {
159 ProtoError::from(format!("ContentType header not a string: {err}"))
161 })
162 })
163 .unwrap_or(Ok(crate::https::MIME_APPLICATION_DNS))?;
164
165 if content_type != crate::https::MIME_APPLICATION_DNS {
166 return Err(ProtoError::from(format!(
167 "ContentType unsupported (must be '{}'): '{}'",
168 crate::https::MIME_APPLICATION_DNS,
169 content_type
170 )));
171 }
172 }
173 };
174
175 let message = Message::from_vec(&response_bytes)?;
177 Ok(DnsResponse::new(message, response_bytes.to_vec()))
178 }
179}
180
181impl DnsRequestSender for HttpsClientStream {
182 fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
230 if self.is_shutdown {
231 panic!("can not send messages after stream is shutdown")
232 }
233
234 message.set_id(0);
236
237 let bytes = match message.to_vec() {
238 Ok(bytes) => bytes,
239 Err(err) => return err.into(),
240 };
241
242 Box::pin(Self::inner_send(
243 self.h2.clone(),
244 Bytes::from(bytes),
245 Arc::clone(&self.name_server_name),
246 ))
247 .into()
248 }
249
250 fn shutdown(&mut self) {
251 self.is_shutdown = true;
252 }
253
254 fn is_shutdown(&self) -> bool {
255 self.is_shutdown
256 }
257}
258
259impl Stream for HttpsClientStream {
260 type Item = Result<(), ProtoError>;
261
262 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
263 if self.is_shutdown {
264 return Poll::Ready(None);
265 }
266
267 match self.h2.poll_ready(cx) {
269 Poll::Ready(Ok(())) => Poll::Ready(Some(Ok(()))),
270 Poll::Pending => Poll::Pending,
271 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!(
272 "h2 stream errored: {e}",
273 ))))),
274 }
275 }
276}
277
278#[derive(Clone)]
280pub struct HttpsClientStreamBuilder {
281 client_config: Arc<ClientConfig>,
282 bind_addr: Option<SocketAddr>,
283}
284
285impl HttpsClientStreamBuilder {
286 pub fn with_client_config(client_config: Arc<ClientConfig>) -> Self {
288 Self {
289 client_config,
290 bind_addr: None,
291 }
292 }
293
294 pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
296 self.bind_addr = Some(bind_addr);
297 }
298
299 pub fn build<S: Connect>(
306 mut self,
307 name_server: SocketAddr,
308 dns_name: String,
309 ) -> HttpsClientConnect<S> {
310 if self.client_config.alpn_protocols.is_empty() {
312 let mut client_config = (*self.client_config).clone();
313 client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
314
315 self.client_config = Arc::new(client_config);
316 }
317
318 let tls = TlsConfig {
319 client_config: self.client_config,
320 dns_name: Arc::from(dns_name),
321 };
322
323 let connect = S::connect_with_bind(name_server, self.bind_addr);
324
325 HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
326 connect,
327 name_server,
328 tls: Some(tls),
329 })
330 }
331
332 pub fn build_with_future<S, F>(
334 future: F,
335 mut client_config: Arc<ClientConfig>,
336 name_server: SocketAddr,
337 dns_name: String,
338 ) -> HttpsClientConnect<S>
339 where
340 S: DnsTcpStream,
341 F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
342 {
343 if client_config.alpn_protocols.is_empty() {
345 let mut client_cfg = (*client_config).clone();
346 client_cfg.alpn_protocols = vec![ALPN_H2.to_vec()];
347
348 client_config = Arc::new(client_cfg);
349 }
350
351 let tls = TlsConfig {
352 client_config,
353 dns_name: Arc::from(dns_name),
354 };
355
356 HttpsClientConnect::<S>(HttpsClientConnectState::TcpConnecting {
357 connect: Box::pin(future),
358 name_server,
359 tls: Some(tls),
360 })
361 }
362}
363
/// Future returned by [`HttpsClientStreamBuilder::build`] and
/// [`HttpsClientStreamBuilder::build_with_future`].
///
/// It resolves to a connected [`HttpsClientStream`] once the TCP connection,
/// TLS handshake and HTTP/2 handshake have completed.
pub struct HttpsClientConnect<S>(HttpsClientConnectState<S>)
where
    S: DnsTcpStream;
368
369impl<S> Future for HttpsClientConnect<S>
370where
371 S: DnsTcpStream,
372{
373 type Output = Result<HttpsClientStream, ProtoError>;
374
375 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
376 self.0.poll_unpin(cx)
377 }
378}
379
380struct TlsConfig {
381 client_config: Arc<ClientConfig>,
382 dns_name: Arc<str>,
383}
384
/// States of the connection set-up driven by [`HttpsClientConnect`].
// The variants hold in-flight futures of very different sizes; the state is
// short-lived, so the size lint is silenced rather than boxing variants.
#[allow(clippy::large_enum_variant)]
// The boxed h2 handshake future type below is unavoidably long.
#[allow(clippy::type_complexity)]
enum HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    /// Waiting for the underlying stream to connect.
    TcpConnecting {
        connect: Pin<Box<dyn Future<Output = io::Result<S>> + Send>>,
        name_server: SocketAddr,
        // An `Option` so the config can be moved out when this state is left.
        tls: Option<TlsConfig>,
    },
    /// Waiting for the TLS handshake to complete.
    TlsConnecting {
        tls: TokioTlsConnect<AsyncIoStdAsTokio<S>>,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
    },
    /// Waiting for the HTTP/2 handshake on top of the TLS stream.
    H2Handshake {
        handshake: Pin<
            Box<
                dyn Future<
                        Output = Result<
                            (
                                SendRequest<Bytes>,
                                Connection<TokioTlsClientStream<AsyncIoStdAsTokio<S>>, Bytes>,
                            ),
                            h2::Error,
                        >,
                    > + Send,
            >,
        >,
        name_server_name: Arc<str>,
        name_server: SocketAddr,
    },
    /// Finished successfully; the stream is handed out exactly once.
    Connected(Option<HttpsClientStream>),
    /// Finished with an error; the error is handed out exactly once.
    Errored(Option<ProtoError>),
}
422
impl<S> Future for HttpsClientConnectState<S>
where
    S: DnsTcpStream,
{
    type Output = Result<HttpsClientStream, ProtoError>;

    /// Drives the set-up sequence
    /// `TcpConnecting -> TlsConnecting -> H2Handshake -> Connected`.
    ///
    /// Each arm computes the next state. The loop then stores that state and
    /// polls it immediately, so progress continues until an inner future is
    /// `Pending` (via `ready!`) or a terminal state returns. Failures of the
    /// TCP connect, TLS handshake and h2 handshake propagate through `?`. An
    /// unusable DNS name moves to `Errored`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let next = match *self {
                Self::TcpConnecting {
                    ref mut connect,
                    name_server,
                    ref mut tls,
                } => {
                    let tcp = ready!(connect.poll_unpin(cx))?;

                    debug!("tcp connection established to: {}", name_server);
                    // Move the TLS settings out; the builders always set `Some`.
                    let tls = tls
                        .take()
                        .expect("programming error, tls should not be None here");
                    let name_server_name = Arc::clone(&tls.dns_name);

                    // Convert the configured name into the server-name type that
                    // `TlsConnector::connect` expects; failure ends in `Errored`.
                    match tls.dns_name.as_ref().try_into() {
                        Ok(dns_name) => {
                            let tls = TlsConnector::from(tls.client_config);
                            let tls = tls.connect(dns_name, AsyncIoStdAsTokio(tcp));
                            Self::TlsConnecting {
                                name_server_name,
                                name_server,
                                tls,
                            }
                        }
                        Err(_) => Self::Errored(Some(ProtoError::from(format!(
                            "bad dns_name: {}",
                            &tls.dns_name
                        )))),
                    }
                }
                Self::TlsConnecting {
                    ref name_server_name,
                    name_server,
                    ref mut tls,
                } => {
                    let tls = ready!(tls.poll_unpin(cx))?;
                    debug!("tls connection established to: {}", name_server);
                    // Server push is disabled: only request/response is used.
                    let mut handshake = h2::client::Builder::new();
                    handshake.enable_push(false);

                    let handshake = handshake.handshake(tls);
                    Self::H2Handshake {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        handshake: Box::pin(handshake),
                    }
                }
                Self::H2Handshake {
                    ref name_server_name,
                    name_server,
                    ref mut handshake,
                } => {
                    let (send_request, connection) = ready!(handshake
                        .poll_unpin(cx)
                        .map_err(|e| ProtoError::from(format!("h2 handshake error: {e}"))))?;

                    debug!("h2 connection established to: {}", name_server);
                    // The `Connection` half performs the socket I/O for all
                    // requests; run it as a detached tokio task (the JoinHandle
                    // is dropped) that only logs when it fails.
                    tokio::spawn(
                        connection
                            .map_err(|e| warn!("h2 connection failed: {e}"))
                            .map(|_: Result<(), ()>| ()),
                    );

                    Self::Connected(Some(HttpsClientStream {
                        name_server_name: Arc::clone(name_server_name),
                        name_server,
                        h2: send_request,
                        is_shutdown: false,
                    }))
                }
                Self::Connected(ref mut conn) => {
                    return Poll::Ready(Ok(conn.take().expect("cannot poll after complete")))
                }
                Self::Errored(ref mut err) => {
                    return Poll::Ready(Err(err.take().expect("cannot poll after complete")))
                }
            };

            // Install the next state and loop to poll it right away.
            *self.as_mut().deref_mut() = next;
        }
    }
}
514
/// Wrapper around a boxed, `Send` future that yields either a [`DnsResponse`]
/// or a [`ProtoError`].
pub struct HttpsClientResponse(
    Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>,
);
519
520impl Future for HttpsClientResponse {
521 type Output = Result<DnsResponse, ProtoError>;
522
523 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
524 self.0.as_mut().poll(cx).map_err(ProtoError::from)
525 }
526}
527
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::str::FromStr;

    use rustls::KeyLogFile;
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::op::{Message, Query, ResponseCode};
    use crate::rr::rdata::{A, AAAA};
    use crate::rr::{Name, RData, RecordType};
    use crate::xfer::{DnsRequestOptions, FirstAnswer};

    use super::*;

    /// Builds a request with a single `www.example.com.` question of `record_type`.
    fn example_request(record_type: RecordType) -> DnsRequest {
        let mut message = Message::new();
        let query = Query::query(Name::from_str("www.example.com.").unwrap(), record_type);
        message.add_query(query);
        DnsRequest::new(message, DnsRequestOptions::default())
    }

    /// Queries AAAA for `www.example.com.` and checks the first non-SERVFAIL
    /// answer.
    ///
    /// Public resolvers occasionally answer SERVFAIL, so up to three attempts
    /// are made. As before, a run where every attempt is SERVFAIL is tolerated.
    fn assert_example_aaaa_with_retry(runtime: &Runtime, https: &mut HttpsClientStream) {
        let request = example_request(RecordType::AAAA);

        for _ in 0..3 {
            let response = runtime
                .block_on(https.send_message(request.clone()).first_answer())
                .expect("send_message failed");
            if response.response_code() == ResponseCode::ServFail {
                continue;
            }

            let record = &response.answers()[0];
            let addr = record
                .data()
                .and_then(RData::as_aaaa)
                .expect("invalid response, expected AAAA record");

            assert_eq!(
                addr,
                &AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
            );
            // A real answer was checked; further attempts would only repeat it.
            break;
        }
    }

    #[test]
    fn test_https_google() {
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let request = example_request(RecordType::A);

        let mut client_config = client_config_tls12_webpki_roots();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let https_builder = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config));
        let connect = https_builder
            .build::<AsyncIoTokioAsStd<TokioTcpStream>>(google, "dns.google".to_string());

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(https.send_message(request).first_answer())
            .expect("send_message failed");

        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("Expected A record");

        assert_eq!(addr, &A::new(93, 184, 216, 34));

        assert_example_aaaa_with_retry(&runtime, &mut https);
    }

    #[test]
    fn test_https_google_with_pure_ip_address_server() {
        let google = SocketAddr::from(([8, 8, 8, 8], 443));
        let request = example_request(RecordType::A);

        let mut client_config = client_config_tls12_webpki_roots();
        client_config.key_log = Arc::new(KeyLogFile::new());

        let https_builder = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config));
        // Use the bare IP address as the TLS server name.
        let connect = https_builder
            .build::<AsyncIoTokioAsStd<TokioTcpStream>>(google, google.ip().to_string());

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(https.send_message(request).first_answer())
            .expect("send_message failed");

        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("Expected A record");

        assert_eq!(addr, &A::new(93, 184, 216, 34));

        assert_example_aaaa_with_retry(&runtime, &mut https);
    }

    #[test]
    #[ignore]
    fn test_https_cloudflare() {
        let cloudflare = SocketAddr::from(([1, 1, 1, 1], 443));
        let request = example_request(RecordType::A);

        let client_config = client_config_tls12_webpki_roots();
        let https_builder = HttpsClientStreamBuilder::with_client_config(Arc::new(client_config));
        let connect = https_builder.build::<AsyncIoTokioAsStd<TokioTcpStream>>(
            cloudflare,
            "cloudflare-dns.com".to_string(),
        );

        let runtime = Runtime::new().expect("could not start runtime");
        let mut https = runtime.block_on(connect).expect("https connect failed");

        let response = runtime
            .block_on(https.send_message(request).first_answer())
            .expect("send_message failed");

        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_a)
            .expect("invalid response, expected A record");

        assert_eq!(addr, &A::new(93, 184, 216, 34));

        let request = example_request(RecordType::AAAA);
        let response = runtime
            .block_on(https.send_message(request).first_answer())
            .expect("send_message failed");

        let record = &response.answers()[0];
        let addr = record
            .data()
            .and_then(RData::as_aaaa)
            .expect("invalid response, expected AAAA record");

        assert_eq!(
            addr,
            &AAAA::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );
    }

    /// rustls client config trusting the webpki root set, using rustls' safe
    /// default cipher suites, key-exchange groups and protocol versions, and
    /// advertising `h2` via ALPN.
    fn client_config_tls12_webpki_roots() -> ClientConfig {
        use rustls::{OwnedTrustAnchor, RootCertStore};
        let mut root_store = RootCertStore::empty();
        root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
            OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));

        let mut client_config = ClientConfig::builder()
            .with_safe_default_cipher_suites()
            .with_safe_default_kx_groups()
            .with_safe_default_protocol_versions()
            .unwrap()
            .with_root_certificates(root_store)
            .with_no_client_auth();

        client_config.alpn_protocols = vec![ALPN_H2.to_vec()];
        client_config
    }
}