1use std::future::Future;
8use std::{
9 io,
10 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
11 sync::Arc,
12 time::Duration,
13};
14
15use drain::{Signal, Watch};
16use futures_util::{FutureExt, StreamExt};
17#[cfg(feature = "dns-over-rustls")]
18use rustls::{Certificate, PrivateKey};
19use tokio::{net, task::JoinSet};
20use tracing::{debug, info, warn};
21use trust_dns_proto::{op::MessageType, rr::Record};
22
23#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
24use crate::proto::openssl::tls_server::*;
25use crate::{
26 authority::{MessageRequest, MessageResponseBuilder},
27 proto::{
28 error::ProtoError,
29 iocompat::AsyncIoTokioAsStd,
30 op::{Edns, Header, LowerQuery, Query, ResponseCode},
31 serialize::binary::{BinDecodable, BinDecoder},
32 tcp::TcpStream,
33 udp::UdpStream,
34 xfer::SerialMessage,
35 BufDnsStreamHandle,
36 },
37 server::{Protocol, Request, RequestHandler, ResponseHandle, ResponseHandler, TimeoutStream},
38};
39
/// A DNS server that serves requests from any number of registered sockets and listeners.
///
/// Every `register_*` method spawns one task onto `join_set`; those tasks are stopped via the
/// `drain` signal/watch pair created in `ServerFuture::new`.
pub struct ServerFuture<T: RequestHandler> {
    // Request handler shared (via `Arc` clones) with every spawned task.
    handler: Arc<T>,
    // One task per registered socket/listener; each resolves when that endpoint stops serving.
    join_set: JoinSet<Result<(), ProtoError>>,
    // Signal half of the drain channel; triggering it asks every endpoint task to stop.
    shutdown_signal: ShutdownSignal,
    // Watch half of the drain channel; cloned into each endpoint task.
    shutdown_watch: Watch,
}
48
49impl<T: RequestHandler> ServerFuture<T> {
50 pub fn new(handler: T) -> Self {
52 let (signal, watch) = drain::channel();
53 Self {
54 handler: Arc::new(handler),
55 join_set: JoinSet::new(),
56 shutdown_signal: ShutdownSignal::new(signal),
57 shutdown_watch: watch,
58 }
59 }
60
61 pub fn register_socket(&mut self, socket: net::UdpSocket) {
63 debug!("registering udp: {:?}", socket);
64
65 let (stream, stream_handle) =
68 UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
69 let shutdown = self.shutdown_watch.clone();
70 let mut stream = stream.take_until(Box::pin(shutdown.signaled()));
71 let handler = self.handler.clone();
72
73 self.join_set.spawn({
75 async move {
76 let mut inner_join_set = JoinSet::new();
77 while let Some(message) = stream.next().await {
78 let message = match message {
79 Err(e) => {
80 warn!("error receiving message on udp_socket: {}", e);
81 continue;
82 }
83 Ok(message) => message,
84 };
85
86 let src_addr = message.addr();
87 debug!("received udp request from: {}", src_addr);
88
89 if let Err(e) = sanitize_src_address(src_addr) {
91 warn!(
92 "address can not be responded to {src_addr}: {e}",
93 src_addr = src_addr,
94 e = e
95 );
96 continue;
97 }
98
99 let handler = handler.clone();
100 let stream_handle = stream_handle.with_remote_addr(src_addr);
101
102 inner_join_set.spawn(async move {
103 handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
104 });
105
106 reap_tasks(&mut inner_join_set);
107 }
108
109 if stream.is_stopped() {
110 Ok(())
111 } else {
112 Err(ProtoError::from("unexpected close of UDP socket"))
114 }
115 }
116 });
117 }
118
119 pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
121 self.register_socket(net::UdpSocket::from_std(socket)?);
122 Ok(())
123 }
124
125 pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
138 debug!("register tcp: {:?}", listener);
139
140 let handler = self.handler.clone();
141
142 let shutdown = self.shutdown_watch.clone();
144 self.join_set.spawn(async move {
145 let mut inner_join_set = JoinSet::new();
146 loop {
147 let (tcp_stream, src_addr) = tokio::select! {
148 tcp_stream = listener.accept() => match tcp_stream {
149 Ok((t, s)) => (t, s),
150 Err(e) => {
151 debug!("error receiving TCP tcp_stream error: {}", e);
152 continue;
153 },
154 },
155 _ = shutdown.clone().signaled() => {
156 break;
158 },
159 };
160
161 if let Err(e) = sanitize_src_address(src_addr) {
163 warn!(
164 "address can not be responded to {src_addr}: {e}",
165 src_addr = src_addr,
166 e = e
167 );
168 continue;
169 }
170
171 let handler = handler.clone();
172
173 inner_join_set.spawn(async move {
175 debug!("accepted request from: {}", src_addr);
176 let (buf_stream, stream_handle) =
178 TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
179 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
180
181 while let Some(message) = timeout_stream.next().await {
182 let message = match message {
183 Ok(message) => message,
184 Err(e) => {
185 debug!(
186 "error in TCP request_stream src: {} error: {}",
187 src_addr, e
188 );
189 return;
191 }
192 };
193
194 handle_raw_request(
196 message,
197 Protocol::Tcp,
198 handler.clone(),
199 stream_handle.clone(),
200 )
201 .await;
202 }
203 });
204
205 reap_tasks(&mut inner_join_set);
206 }
207
208 Ok(())
209 });
210 }
211
212 pub fn register_listener_std(
225 &mut self,
226 listener: std::net::TcpListener,
227 timeout: Duration,
228 ) -> io::Result<()> {
229 self.register_listener(net::TcpListener::from_std(listener)?, timeout);
230 Ok(())
231 }
232
233 #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
247 #[cfg_attr(
248 docsrs,
249 doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
250 )]
251 pub fn register_tls_listener(
252 &mut self,
253 listener: net::TcpListener,
254 timeout: Duration,
255 certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
256 ) -> io::Result<()> {
257 use crate::proto::openssl::{tls_server, TlsStream};
258 use openssl::ssl::Ssl;
259 use std::pin::Pin;
260 use tokio_openssl::SslStream as TokioSslStream;
261
262 let ((cert, chain), key) = certificate_and_key;
263
264 let handler = self.handler.clone();
265 debug!("registered tcp: {:?}", listener);
266
267 let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
268
269 let shutdown = self.shutdown_watch.clone();
271 self.join_set.spawn(async move {
272 let mut inner_join_set = JoinSet::new();
273 loop {
274 let (tcp_stream, src_addr) = tokio::select! {
275 tcp_stream = listener.accept() => match tcp_stream {
276 Ok((t, s)) => (t, s),
277 Err(e) => {
278 debug!("error receiving TLS tcp_stream error: {}", e);
279 continue;
280 },
281 },
282 _ = shutdown.clone().signaled() => {
283 break;
285 },
286 };
287
288 if let Err(e) = sanitize_src_address(src_addr) {
290 warn!(
291 "address can not be responded to {src_addr}: {e}",
292 src_addr = src_addr,
293 e = e
294 );
295 continue;
296 }
297
298 let handler = handler.clone();
299 let tls_acceptor = tls_acceptor.clone();
300
301 inner_join_set.spawn(async move {
303 debug!("starting TLS request from: {}", src_addr);
304
305 let mut tls_stream = match Ssl::new(tls_acceptor.context())
307 .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
308 {
309 Ok(tls_stream) => tls_stream,
310 Err(e) => {
311 debug!("tls handshake src: {} error: {}", src_addr, e);
312 return ();
313 }
314 };
315 match Pin::new(&mut tls_stream).accept().await {
316 Ok(()) => {}
317 Err(e) => {
318 debug!("tls handshake src: {} error: {}", src_addr, e);
319 return ();
320 }
321 };
322 debug!("accepted TLS request from: {}", src_addr);
323 let (buf_stream, stream_handle) =
324 TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
325 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
326 while let Some(message) = timeout_stream.next().await {
327 let message = match message {
328 Ok(message) => message,
329 Err(e) => {
330 debug!(
331 "error in TLS request_stream src: {:?} error: {}",
332 src_addr, e
333 );
334
335 return ();
337 }
338 };
339
340 self::handle_raw_request(
341 message,
342 Protocol::Tls,
343 handler.clone(),
344 stream_handle.clone(),
345 )
346 .await;
347 }
348 });
349
350 reap_tasks(&mut inner_join_set);
351 }
352
353 Ok(())
354 });
355
356 Ok(())
357 }
358
359 #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
373 #[cfg_attr(
374 docsrs,
375 doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
376 )]
377 pub fn register_tls_listener_std(
378 &mut self,
379 listener: std::net::TcpListener,
380 timeout: Duration,
381 certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
382 ) -> io::Result<()> {
383 self.register_tls_listener(
384 net::TcpListener::from_std(listener)?,
385 timeout,
386 certificate_and_key,
387 )
388 }
389
390 #[cfg(feature = "dns-over-rustls")]
404 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
405 pub fn register_tls_listener(
406 &mut self,
407 listener: net::TcpListener,
408 timeout: Duration,
409 certificate_and_key: (Vec<Certificate>, PrivateKey),
410 ) -> io::Result<()> {
411 use crate::proto::rustls::{tls_from_stream, tls_server};
412 use tokio_rustls::TlsAcceptor;
413
414 let handler = self.handler.clone();
415
416 debug!("registered tcp: {:?}", listener);
417
418 let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
419 .map_err(|e| {
420 io::Error::new(
421 io::ErrorKind::Other,
422 format!("error creating TLS acceptor: {e}"),
423 )
424 })?;
425 let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
426
427 let shutdown = self.shutdown_watch.clone();
429 self.join_set.spawn(async move {
430 let mut inner_join_set = JoinSet::new();
431 loop {
432 let (tcp_stream, src_addr) = tokio::select! {
433 tcp_stream = listener.accept() => match tcp_stream {
434 Ok((t, s)) => (t, s),
435 Err(e) => {
436 debug!("error receiving TLS tcp_stream error: {}", e);
437 continue;
438 },
439 },
440 _ = shutdown.clone().signaled() => {
441 break;
443 },
444 };
445
446 if let Err(e) = sanitize_src_address(src_addr) {
448 warn!(
449 "address can not be responded to {src_addr}: {e}",
450 src_addr = src_addr,
451 e = e
452 );
453 continue;
454 }
455
456 let handler = handler.clone();
457 let tls_acceptor = tls_acceptor.clone();
458
459 inner_join_set.spawn(async move {
461 debug!("starting TLS request from: {}", src_addr);
462
463 let tls_stream = tls_acceptor.accept(tcp_stream).await;
465
466 let tls_stream = match tls_stream {
467 Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
468 Err(e) => {
469 debug!("tls handshake src: {} error: {}", src_addr, e);
470 return;
471 }
472 };
473 debug!("accepted TLS request from: {}", src_addr);
474 let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
475 let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
476 while let Some(message) = timeout_stream.next().await {
477 let message = match message {
478 Ok(message) => message,
479 Err(e) => {
480 debug!(
481 "error in TLS request_stream src: {:?} error: {}",
482 src_addr, e
483 );
484
485 return;
487 }
488 };
489
490 handle_raw_request(
491 message,
492 Protocol::Tls,
493 handler.clone(),
494 stream_handle.clone(),
495 )
496 .await;
497 }
498 });
499
500 reap_tasks(&mut inner_join_set);
501 }
502
503 Ok(())
504 });
505
506 Ok(())
507 }
508
509 #[cfg(all(
523 feature = "dns-over-https-openssl",
524 not(feature = "dns-over-https-rustls")
525 ))]
526 #[cfg_attr(
527 docsrs,
528 doc(cfg(all(
529 feature = "dns-over-https-openssl",
530 not(feature = "dns-over-https-rustls")
531 )))
532 )]
533 pub fn register_https_listener(
534 &self,
535 listener: tcp::TcpListener,
536 timeout: Duration,
537 pkcs12: ParsedPkcs12,
538 ) -> io::Result<()> {
539 unimplemented!("openssl based `dns-over-https` not yet supported. see the `dns-over-https-rustls` feature")
540 }
541
542 #[cfg(feature = "dns-over-https-rustls")]
556 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https-rustls")))]
557 pub fn register_https_listener(
558 &mut self,
559 listener: net::TcpListener,
560 _timeout: Duration,
562 certificate_and_key: (Vec<Certificate>, PrivateKey),
563 dns_hostname: Option<String>,
564 ) -> io::Result<()> {
565 use tokio_rustls::TlsAcceptor;
566
567 use crate::proto::rustls::tls_server;
568 use crate::server::https_handler::h2_handler;
569
570 let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
571
572 let handler = self.handler.clone();
573 debug!("registered https: {listener:?}");
574
575 let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
576 .map_err(|e| {
577 io::Error::new(
578 io::ErrorKind::Other,
579 format!("error creating TLS acceptor: {e}"),
580 )
581 })?;
582 let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
583
584 let dns_hostname = dns_hostname;
586 let shutdown = self.shutdown_watch.clone();
587 self.join_set.spawn(async move {
588 let mut inner_join_set = JoinSet::new();
589 let dns_hostname = dns_hostname;
590 loop {
591 let shutdown = shutdown.clone();
592 let (tcp_stream, src_addr) = tokio::select! {
593 tcp_stream = listener.accept() => match tcp_stream {
594 Ok((t, s)) => (t, s),
595 Err(e) => {
596 debug!("error receiving HTTPS tcp_stream error: {}", e);
597 continue;
598 },
599 },
600 _ = shutdown.clone().signaled() => {
601 break;
603 },
604 };
605
606 if let Err(e) = sanitize_src_address(src_addr) {
608 warn!("address can not be responded to {src_addr}: {e}");
609 continue;
610 }
611
612 let handler = handler.clone();
613 let tls_acceptor = tls_acceptor.clone();
614 let dns_hostname = dns_hostname.clone();
615
616 inner_join_set.spawn(async move {
617 debug!("starting HTTPS request from: {src_addr}");
618
619 let tls_stream = tls_acceptor.accept(tcp_stream).await;
622
623 let tls_stream = match tls_stream {
624 Ok(tls_stream) => tls_stream,
625 Err(e) => {
626 debug!("https handshake src: {src_addr} error: {e}");
627 return;
628 }
629 };
630 debug!("accepted HTTPS request from: {src_addr}");
631
632 h2_handler(
633 handler,
634 tls_stream,
635 src_addr,
636 dns_hostname,
637 shutdown.clone(),
638 )
639 .await;
640 });
641
642 reap_tasks(&mut inner_join_set);
643 }
644
645 Ok(())
646 });
647
648 Ok(())
649 }
650
651 #[cfg(feature = "dns-over-quic")]
665 #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-quic")))]
666 pub fn register_quic_listener(
667 &mut self,
668 socket: net::UdpSocket,
669 _timeout: Duration,
671 certificate_and_key: (Vec<Certificate>, PrivateKey),
672 dns_hostname: Option<String>,
673 ) -> io::Result<()> {
674 use crate::proto::quic::QuicServer;
675 use crate::server::quic_handler::quic_handler;
676
677 let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
678
679 let handler = self.handler.clone();
680
681 debug!("registered quic: {:?}", socket);
682 let mut server =
683 QuicServer::with_socket(socket, certificate_and_key.0, certificate_and_key.1)?;
684
685 let dns_hostname = dns_hostname;
687 let shutdown = self.shutdown_watch.clone();
688 self.join_set.spawn(async move {
689 let mut inner_join_set = JoinSet::new();
690 let dns_hostname = dns_hostname;
691 loop {
692 let shutdown = shutdown.clone();
693 let (streams, src_addr) = tokio::select! {
694 result = server.next() => match result {
695 Ok(Some(c)) => c,
696 Ok(None) => continue,
697 Err(e) => {
698 debug!("error receiving quic connection: {e}");
699 continue;
700 }
701 },
702 _ = shutdown.clone().signaled() => {
703 break;
705 },
706 };
707
708 if let Err(e) = sanitize_src_address(src_addr) {
711 warn!(
712 "address can not be responded to {src_addr}: {e}",
713 src_addr = src_addr,
714 e = e
715 );
716 continue;
717 }
718
719 let handler = handler.clone();
720 let dns_hostname = dns_hostname.clone();
721
722 inner_join_set.spawn(async move {
723 debug!("starting quic stream request from: {src_addr}");
724
725 let result =
727 quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
728 .await;
729
730 if let Err(e) = result {
731 warn!("quic stream processing failed from {src_addr}: {e}")
732 }
733 });
734
735 reap_tasks(&mut inner_join_set);
736 }
737
738 Ok(())
739 });
740
741 Ok(())
742 }
743
744 pub fn graceful(self) -> (ShutdownSignal, impl Future<Output = Result<(), ProtoError>>) {
750 let signal = self.shutdown_signal;
751 let join_set = self.join_set;
752 (signal, block_until_done(join_set))
753 }
754
755 pub async fn shutdown_gracefully(self) -> Result<(), ProtoError> {
761 let (signal, fut) = self.graceful();
762
763 signal.shutdown().await;
765
766 fut.await
768 }
769
770 pub async fn block_until_done(self) -> Result<(), ProtoError> {
773 block_until_done(self.join_set).await
774 }
775}
776
/// Handle used to trigger a graceful shutdown of a [`ServerFuture`].
///
/// Obtained from `ServerFuture::graceful`; wraps the signal half of the server's `drain`
/// channel.
#[derive(Debug)]
pub struct ShutdownSignal {
    // Signal half of the drain channel created in `ServerFuture::new`.
    signal: Signal,
}
782
783impl ShutdownSignal {
784 fn new(signal: Signal) -> Self {
785 Self { signal }
786 }
787
788 pub async fn shutdown(self) {
791 self.signal.drain().await
792 }
793}
794
795async fn block_until_done(mut join_set: JoinSet<Result<(), ProtoError>>) -> Result<(), ProtoError> {
796 if join_set.is_empty() {
797 warn!("block_until_done called with no pending tasks");
798 return Ok(());
799 }
800
801 let mut out = Ok(());
803 while let Some(join_result) = join_set.join_next().await {
804 match join_result {
805 Ok(result) => {
806 match result {
807 Ok(_) => (),
808 Err(e) => {
809 out = Err(e);
811 }
812 }
813 }
814 Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
815 }
816 }
817 out
818}
819
820fn reap_tasks(join_set: &mut JoinSet<()>) {
822 while FutureExt::now_or_never(join_set.join_next()).is_some() {}
823}
824
825pub(crate) async fn handle_raw_request<T: RequestHandler>(
826 message: SerialMessage,
827 protocol: Protocol,
828 request_handler: Arc<T>,
829 response_handler: BufDnsStreamHandle,
830) {
831 let src_addr = message.addr();
832 let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
833
834 handle_request(
835 message.bytes(),
836 src_addr,
837 protocol,
838 request_handler,
839 response_handler,
840 )
841 .await;
842}
843
/// A [`ResponseHandler`] wrapper that forwards responses to `handler` and logs a one-line
/// summary of each request/response pair.
#[derive(Clone)]
struct ReportingResponseHandler<R: ResponseHandler> {
    // Header of the request being answered; its id is compared with the response id.
    request_header: Header,
    // The request's query (`Query::default()` when the message failed to decode).
    query: LowerQuery,
    // Transport the request arrived on, used only for logging.
    protocol: Protocol,
    // Source address of the request, used only for logging.
    src_addr: SocketAddr,
    // The handler that actually sends the response.
    handler: R,
}
852
// Forwards responses to the wrapped handler and then logs a summary at `info` level.
// The `allow` presumably silences clippy's suggestion to inline plain-variable format
// arguments (e.g. `rflags = rflags`) in the `info!` call below.
#[async_trait::async_trait]
#[allow(clippy::uninlined_format_args)]
impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
    /// Sends `response` through the wrapped handler, then logs request and response details.
    ///
    /// If the response id differs from the request id, a warning is logged (and debug builds
    /// panic via `debug_assert_eq!`).
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the wrapped handler; nothing is logged in that case.
    async fn send_response<'a>(
        &mut self,
        response: crate::authority::MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> io::Result<super::ResponseInfo> {
        let response_info = self.handler.send_response(response).await?;

        let id = self.request_header.id();
        let rid = response_info.id();
        if id != rid {
            warn!("request id:{id} does not match response id:{rid}");
            debug_assert_eq!(id, rid, "request id and response id should match");
        }

        let rflags = response_info.flags();
        let answer_count = response_info.answer_count();
        let authority_count = response_info.name_server_count();
        let additional_count = response_info.additional_count();
        let response_code = response_info.response_code();

        // NOTE(review): the `request:{id}` field is filled from the *response* id (`rid`);
        // the two only differ in the mismatch case warned about above.
        info!("request:{id} src:{proto}://{addr}#{port} {op}:{query}:{qtype}:{class} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
            id = rid,
            proto = self.protocol,
            addr = self.src_addr.ip(),
            port = self.src_addr.port(),
            op = self.request_header.op_code(),
            query = self.query.name(),
            qtype = self.query.query_type(),
            class = self.query.query_class(),
            qflags = self.request_header.flags(),
            code = response_code,
            answers = answer_count,
            authorities = authority_count,
            additionals = additional_count,
            rflags = rflags
        );

        Ok(response_info)
    }
}
902
/// Decodes `message_bytes` as a DNS message and dispatches it to `request_handler`.
///
/// * A successfully decoded message is handled, unless it is itself a response, in which
///   case it is dropped without any reply.
/// * A decode failure that carries a form error is answered with `FORMERR`, built from the
///   header the decoder recovered.
/// * Any other decode failure is logged at `warn` and dropped.
///
/// Replies are routed through `ReportingResponseHandler`, which logs a summary line.
pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
    message_bytes: &[u8],
    src_addr: SocketAddr,
    protocol: Protocol,
    request_handler: Arc<T>,
    response_handler: R,
) {
    let mut decoder = BinDecoder::new(message_bytes);

    // Handles a message that decoded successfully; called at most once, below.
    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
        // Only queries are served: an inbound response message gets no reply.
        if message.message_type() == MessageType::Response {
            return;
        }

        // Capture header details for logging before `message` is moved into the `Request`.
        let id = message.id();
        let qflags = message.header().flags();
        let qop_code = message.op_code();
        let message_type = message.message_type();
        let is_dnssec = message.edns().map_or(false, Edns::dnssec_ok);

        let request = Request::new(message, src_addr, protocol);

        let info = request.request_info();
        let query = info.query.clone();
        let query_name = info.query.name();
        let query_type = info.query.query_type();
        let query_class = info.query.query_class();

        debug!(
            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op}:{query}:{qtype}:{class} qflags:{qflags}",
            id = id,
            proto = protocol,
            addr = src_addr.ip(),
            port = src_addr.port(),
            message_type= message_type,
            is_dnssec = is_dnssec,
            op = qop_code,
            query = query_name,
            qtype = query_type,
            class = query_class,
            qflags = qflags,
        );

        // Wrap the caller's handler so the eventual response is logged.
        let reporter = ReportingResponseHandler {
            request_header: *request.header(),
            query,
            protocol,
            src_addr,
            handler: response_handler,
        };

        request_handler.handle_request(&request, reporter).await;
    };

    match MessageRequest::read(&mut decoder) {
        Ok(message) => {
            inner_handle_request(message, response_handler).await;
        }
        // The message was malformed, but enough of the header survived to answer FORMERR.
        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
            let (header, error) = kind
                .into_form_error()
                .expect("as form_error already confirmed this is a FormError");
            // No query could be recovered; log against a default one.
            let query = LowerQuery::query(Query::default());

            debug!(
                "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
                id = header.id(),
                proto = protocol,
                addr = src_addr.ip(),
                port = src_addr.port(),
                message_type= header.message_type(),
                op = header.op_code(),
                error = error,
            );

            let mut reporter = ReportingResponseHandler {
                request_header: header,
                query,
                protocol,
                src_addr,
                handler: response_handler,
            };

            let response = MessageResponseBuilder::new(None);
            let result = reporter
                .send_response(response.error_msg(&header, ResponseCode::FormErr))
                .await;

            if let Err(e) = result {
                warn!("failed to return FormError to client: {}", e);
            }
        }
        Err(e) => warn!("failed to read message: {}", e),
    }
}
1006
/// Checks whether a response may be sent back to `src`.
///
/// The following sources are rejected:
/// * port 0;
/// * unspecified addresses (`0.0.0.0` and `::`);
/// * the IPv4 limited broadcast address;
/// * multicast addresses.
///
/// A multicast address is never a legitimate sender (RFC 1112 for IPv4, RFC 4291 §2.7 for
/// IPv6). Such a packet must be spoofed, and answering it would reflect the response onto
/// every member of the group.
///
/// # Errors
///
/// Returns a human-readable reason when `src` must not be answered.
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v4 addr: {src}"));
        }

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v6 addr: {src}"));
        }

        Ok(())
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
1048
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authority::Catalog;
    use futures_util::future;
    #[cfg(feature = "dns-over-rustls")]
    use rustls::{Certificate, PrivateKey};
    use std::net::SocketAddr;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborts a server future, then checks that every endpoint address can be bound again.
    ///
    /// NOTE(review): `abort_handle.abort()` runs before `abortable` is first polled, so the
    /// endpoints are likely never registered; the test may not exercise what its name suggests.
    #[tokio::test]
    async fn abort() {
        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = ServerFuture::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        endpoints.rebind_all().await;
    }

    /// Registers every endpoint, shuts down gracefully within 2 seconds, then checks that
    /// every endpoint address has been released.
    #[tokio::test]
    async fn graceful_shutdown() {
        let mut server_future = ServerFuture::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    /// Covers the accepted and rejected source address/port combinations for v4 and v6.
    #[test]
    fn test_sanitize_src_addr() {
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4096))).is_err());

        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4096))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// One loopback address per transport under test (feature-gated ones included).
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        udp_std_addr: SocketAddr,
        tcp_addr: SocketAddr,
        tcp_std_addr: SocketAddr,
        #[cfg(feature = "dns-over-rustls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-https-rustls")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-quic")]
        quic_addr: SocketAddr,
    }

    impl Endpoints {
        /// Reserves a loopback port per transport by binding to port 0 and recording the
        /// assigned address. The probe sockets are dropped on return, so `register` can bind
        /// the same addresses again.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                udp_std_addr: udp_std.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                tcp_std_addr: tcp_std.local_addr().unwrap(),
                #[cfg(feature = "dns-over-rustls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-https-rustls")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-quic")]
                quic_addr: quic.local_addr().unwrap(),
            }
        }

        /// Binds each recorded address and registers it with `server` using the matching
        /// `register_*` method (both tokio and std variants for UDP/TCP).
        async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server
                .register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
                .unwrap();
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
            );
            server
                .register_listener_std(
                    std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
                    Duration::from_secs(1),
                )
                .unwrap();

            #[cfg(feature = "dns-over-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-https-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Binds every recorded address again; panics if any is still held by the server.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            UdpSocket::bind(self.udp_std_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            TcpListener::bind(self.tcp_std_addr).await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
        }
    }

    /// Loads the test certificate and key from `tests/test-data` under `TDNS_WORKSPACE_ROOT`
    /// (defaulting to `../..`).
    #[cfg(feature = "dns-over-rustls")]
    fn rustls_cert_key() -> (Vec<Certificate>, PrivateKey) {
        use std::env;
        use std::path::Path;
        use trust_dns_proto::rustls::tls_server;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());

        let cert = tls_server::read_cert(Path::new(&format!(
            "{}/tests/test-data/cert.pem",
            server_path
        )))
        .map_err(|e| format!("error reading cert: {e}"))
        .unwrap();
        let key = tls_server::read_key_from_pem(Path::new(&format!(
            "{}/tests/test-data/cert.key",
            server_path
        )))
        .unwrap();

        (cert, key)
    }
}