trust_dns_server/server/
server_future.rs

1// Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7use std::future::Future;
8use std::{
9    io,
10    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
11    sync::Arc,
12    time::Duration,
13};
14
15use drain::{Signal, Watch};
16use futures_util::{FutureExt, StreamExt};
17#[cfg(feature = "dns-over-rustls")]
18use rustls::{Certificate, PrivateKey};
19use tokio::{net, task::JoinSet};
20use tracing::{debug, info, warn};
21use trust_dns_proto::{op::MessageType, rr::Record};
22
23#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
24use crate::proto::openssl::tls_server::*;
25use crate::{
26    authority::{MessageRequest, MessageResponseBuilder},
27    proto::{
28        error::ProtoError,
29        iocompat::AsyncIoTokioAsStd,
30        op::{Edns, Header, LowerQuery, Query, ResponseCode},
31        serialize::binary::{BinDecodable, BinDecoder},
32        tcp::TcpStream,
33        udp::UdpStream,
34        xfer::SerialMessage,
35        BufDnsStreamHandle,
36    },
37    server::{Protocol, Request, RequestHandler, ResponseHandle, ResponseHandler, TimeoutStream},
38};
39
40// TODO, would be nice to have a Slab for buffers here...
41/// A Futures based implementation of a DNS server
42pub struct ServerFuture<T: RequestHandler> {
43    handler: Arc<T>,
44    join_set: JoinSet<Result<(), ProtoError>>,
45    shutdown_signal: ShutdownSignal,
46    shutdown_watch: Watch,
47}
48
49impl<T: RequestHandler> ServerFuture<T> {
50    /// Creates a new ServerFuture with the specified Handler.
51    pub fn new(handler: T) -> Self {
52        let (signal, watch) = drain::channel();
53        Self {
54            handler: Arc::new(handler),
55            join_set: JoinSet::new(),
56            shutdown_signal: ShutdownSignal::new(signal),
57            shutdown_watch: watch,
58        }
59    }
60
61    /// Register a UDP socket. Should be bound before calling this function.
62    pub fn register_socket(&mut self, socket: net::UdpSocket) {
63        debug!("registering udp: {:?}", socket);
64
65        // create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where.
66        //   the address used is acquired from the inbound queries
67        let (stream, stream_handle) =
68            UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into());
69        let shutdown = self.shutdown_watch.clone();
70        let mut stream = stream.take_until(Box::pin(shutdown.signaled()));
71        let handler = self.handler.clone();
72
73        // this spawns a ForEach future which handles all the requests into a Handler.
74        self.join_set.spawn({
75            async move {
76                let mut inner_join_set = JoinSet::new();
77                while let Some(message) = stream.next().await {
78                    let message = match message {
79                        Err(e) => {
80                            warn!("error receiving message on udp_socket: {}", e);
81                            continue;
82                        }
83                        Ok(message) => message,
84                    };
85
86                    let src_addr = message.addr();
87                    debug!("received udp request from: {}", src_addr);
88
89                    // verify that the src address is safe for responses
90                    if let Err(e) = sanitize_src_address(src_addr) {
91                        warn!(
92                            "address can not be responded to {src_addr}: {e}",
93                            src_addr = src_addr,
94                            e = e
95                        );
96                        continue;
97                    }
98
99                    let handler = handler.clone();
100                    let stream_handle = stream_handle.with_remote_addr(src_addr);
101
102                    inner_join_set.spawn(async move {
103                        handle_raw_request(message, Protocol::Udp, handler, stream_handle).await;
104                    });
105
106                    reap_tasks(&mut inner_join_set);
107                }
108
109                if stream.is_stopped() {
110                    Ok(())
111                } else {
112                    // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated...
113                    Err(ProtoError::from("unexpected close of UDP socket"))
114                }
115            }
116        });
117    }
118
119    /// Register a UDP socket. Should be bound before calling this function.
120    pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> {
121        self.register_socket(net::UdpSocket::from_std(socket)?);
122        Ok(())
123    }
124
125    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
126    ///  IPv4 address.
127    ///
128    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
129    ///  to not make this too low depending on use cases.
130    ///
131    /// # Arguments
132    /// * `listener` - a bound TCP socket
133    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
134    ///               requests within this time period will be closed. In the future it should be
135    ///               possible to create long-lived queries, but these should be from trusted sources
136    ///               only, this would require some type of whitelisting.
137    pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) {
138        debug!("register tcp: {:?}", listener);
139
140        let handler = self.handler.clone();
141
142        // for each incoming request...
143        let shutdown = self.shutdown_watch.clone();
144        self.join_set.spawn(async move {
145            let mut inner_join_set = JoinSet::new();
146            loop {
147                let (tcp_stream, src_addr) = tokio::select! {
148                    tcp_stream = listener.accept() => match tcp_stream {
149                        Ok((t, s)) => (t, s),
150                        Err(e) => {
151                            debug!("error receiving TCP tcp_stream error: {}", e);
152                            continue;
153                        },
154                    },
155                    _ = shutdown.clone().signaled() => {
156                        // A graceful shutdown was initiated. Break out of the loop.
157                        break;
158                    },
159                };
160
161                // verify that the src address is safe for responses
162                if let Err(e) = sanitize_src_address(src_addr) {
163                    warn!(
164                        "address can not be responded to {src_addr}: {e}",
165                        src_addr = src_addr,
166                        e = e
167                    );
168                    continue;
169                }
170
171                let handler = handler.clone();
172
173                // and spawn to the io_loop
174                inner_join_set.spawn(async move {
175                    debug!("accepted request from: {}", src_addr);
176                    // take the created stream...
177                    let (buf_stream, stream_handle) =
178                        TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr);
179                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
180
181                    while let Some(message) = timeout_stream.next().await {
182                        let message = match message {
183                            Ok(message) => message,
184                            Err(e) => {
185                                debug!(
186                                    "error in TCP request_stream src: {} error: {}",
187                                    src_addr, e
188                                );
189                                // we're going to bail on this connection...
190                                return;
191                            }
192                        };
193
194                        // we don't spawn here to limit clients from getting too many resources
195                        handle_raw_request(
196                            message,
197                            Protocol::Tcp,
198                            handler.clone(),
199                            stream_handle.clone(),
200                        )
201                        .await;
202                    }
203                });
204
205                reap_tasks(&mut inner_join_set);
206            }
207
208            Ok(())
209        });
210    }
211
212    /// Register a TcpListener to the Server. This should already be bound to either an IPv6 or an
213    ///  IPv4 address.
214    ///
215    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
216    ///  to not make this too low depending on use cases.
217    ///
218    /// # Arguments
219    /// * `listener` - a bound TCP socket
220    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
221    ///               requests within this time period will be closed. In the future it should be
222    ///               possible to create long-lived queries, but these should be from trusted sources
223    ///               only, this would require some type of whitelisting.
224    pub fn register_listener_std(
225        &mut self,
226        listener: std::net::TcpListener,
227        timeout: Duration,
228    ) -> io::Result<()> {
229        self.register_listener(net::TcpListener::from_std(listener)?, timeout);
230        Ok(())
231    }
232
233    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
234    /// IPv6 or an IPv4 address.
235    ///
236    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
237    ///  to not make this too low depending on use cases.
238    ///
239    /// # Arguments
240    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
241    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
242    ///               requests within this time period will be closed. In the future it should be
243    ///               possible to create long-lived queries, but these should be from trusted sources
244    ///               only, this would require some type of whitelisting.
245    /// * `pkcs12` - certificate used to announce to clients
246    #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
247    #[cfg_attr(
248        docsrs,
249        doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
250    )]
251    pub fn register_tls_listener(
252        &mut self,
253        listener: net::TcpListener,
254        timeout: Duration,
255        certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
256    ) -> io::Result<()> {
257        use crate::proto::openssl::{tls_server, TlsStream};
258        use openssl::ssl::Ssl;
259        use std::pin::Pin;
260        use tokio_openssl::SslStream as TokioSslStream;
261
262        let ((cert, chain), key) = certificate_and_key;
263
264        let handler = self.handler.clone();
265        debug!("registered tcp: {:?}", listener);
266
267        let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?);
268
269        // for each incoming request...
270        let shutdown = self.shutdown_watch.clone();
271        self.join_set.spawn(async move {
272            let mut inner_join_set = JoinSet::new();
273            loop {
274                let (tcp_stream, src_addr) = tokio::select! {
275                    tcp_stream = listener.accept() => match tcp_stream {
276                        Ok((t, s)) => (t, s),
277                        Err(e) => {
278                            debug!("error receiving TLS tcp_stream error: {}", e);
279                            continue;
280                        },
281                    },
282                    _ = shutdown.clone().signaled() => {
283                        // A graceful shutdown was initiated. Break out of the loop.
284                        break;
285                    },
286                };
287
288                // verify that the src address is safe for responses
289                if let Err(e) = sanitize_src_address(src_addr) {
290                    warn!(
291                        "address can not be responded to {src_addr}: {e}",
292                        src_addr = src_addr,
293                        e = e
294                    );
295                    continue;
296                }
297
298                let handler = handler.clone();
299                let tls_acceptor = tls_acceptor.clone();
300
301                // kick out to a different task immediately, let them do the TLS handshake
302                inner_join_set.spawn(async move {
303                    debug!("starting TLS request from: {}", src_addr);
304
305                    // perform the TLS
306                    let mut tls_stream = match Ssl::new(tls_acceptor.context())
307                        .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream))
308                    {
309                        Ok(tls_stream) => tls_stream,
310                        Err(e) => {
311                            debug!("tls handshake src: {} error: {}", src_addr, e);
312                            return ();
313                        }
314                    };
315                    match Pin::new(&mut tls_stream).accept().await {
316                        Ok(()) => {}
317                        Err(e) => {
318                            debug!("tls handshake src: {} error: {}", src_addr, e);
319                            return ();
320                        }
321                    };
322                    debug!("accepted TLS request from: {}", src_addr);
323                    let (buf_stream, stream_handle) =
324                        TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr);
325                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
326                    while let Some(message) = timeout_stream.next().await {
327                        let message = match message {
328                            Ok(message) => message,
329                            Err(e) => {
330                                debug!(
331                                    "error in TLS request_stream src: {:?} error: {}",
332                                    src_addr, e
333                                );
334
335                                // kill this connection
336                                return ();
337                            }
338                        };
339
340                        self::handle_raw_request(
341                            message,
342                            Protocol::Tls,
343                            handler.clone(),
344                            stream_handle.clone(),
345                        )
346                        .await;
347                    }
348                });
349
350                reap_tasks(&mut inner_join_set);
351            }
352
353            Ok(())
354        });
355
356        Ok(())
357    }
358
359    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
360    /// IPv6 or an IPv4 address.
361    ///
362    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
363    ///  to not make this too low depending on use cases.
364    ///
365    /// # Arguments
366    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
367    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
368    ///               requests within this time period will be closed. In the future it should be
369    ///               possible to create long-lived queries, but these should be from trusted sources
370    ///               only, this would require some type of whitelisting.
371    /// * `pkcs12` - certificate used to announce to clients
372    #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
373    #[cfg_attr(
374        docsrs,
375        doc(cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls"))))
376    )]
377    pub fn register_tls_listener_std(
378        &mut self,
379        listener: std::net::TcpListener,
380        timeout: Duration,
381        certificate_and_key: ((X509, Option<Stack<X509>>), PKey<Private>),
382    ) -> io::Result<()> {
383        self.register_tls_listener(
384            net::TcpListener::from_std(listener)?,
385            timeout,
386            certificate_and_key,
387        )
388    }
389
390    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
391    /// IPv6 or an IPv4 address.
392    ///
393    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
394    ///  to not make this too low depending on use cases.
395    ///
396    /// # Arguments
397    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
398    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
399    ///               requests within this time period will be closed. In the future it should be
400    ///               possible to create long-lived queries, but these should be from trusted sources
401    ///               only, this would require some type of whitelisting.
402    /// * `pkcs12` - certificate used to announce to clients
403    #[cfg(feature = "dns-over-rustls")]
404    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))]
405    pub fn register_tls_listener(
406        &mut self,
407        listener: net::TcpListener,
408        timeout: Duration,
409        certificate_and_key: (Vec<Certificate>, PrivateKey),
410    ) -> io::Result<()> {
411        use crate::proto::rustls::{tls_from_stream, tls_server};
412        use tokio_rustls::TlsAcceptor;
413
414        let handler = self.handler.clone();
415
416        debug!("registered tcp: {:?}", listener);
417
418        let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
419            .map_err(|e| {
420            io::Error::new(
421                io::ErrorKind::Other,
422                format!("error creating TLS acceptor: {e}"),
423            )
424        })?;
425        let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
426
427        // for each incoming request...
428        let shutdown = self.shutdown_watch.clone();
429        self.join_set.spawn(async move {
430            let mut inner_join_set = JoinSet::new();
431            loop {
432                let (tcp_stream, src_addr) = tokio::select! {
433                    tcp_stream = listener.accept() => match tcp_stream {
434                        Ok((t, s)) => (t, s),
435                        Err(e) => {
436                            debug!("error receiving TLS tcp_stream error: {}", e);
437                            continue;
438                        },
439                    },
440                    _ = shutdown.clone().signaled() => {
441                        // A graceful shutdown was initiated. Break out of the loop.
442                        break;
443                    },
444                };
445
446                // verify that the src address is safe for responses
447                if let Err(e) = sanitize_src_address(src_addr) {
448                    warn!(
449                        "address can not be responded to {src_addr}: {e}",
450                        src_addr = src_addr,
451                        e = e
452                    );
453                    continue;
454                }
455
456                let handler = handler.clone();
457                let tls_acceptor = tls_acceptor.clone();
458
459                // kick out to a different task immediately, let them do the TLS handshake
460                inner_join_set.spawn(async move {
461                    debug!("starting TLS request from: {}", src_addr);
462
463                    // perform the TLS
464                    let tls_stream = tls_acceptor.accept(tcp_stream).await;
465
466                    let tls_stream = match tls_stream {
467                        Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream),
468                        Err(e) => {
469                            debug!("tls handshake src: {} error: {}", src_addr, e);
470                            return;
471                        }
472                    };
473                    debug!("accepted TLS request from: {}", src_addr);
474                    let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr);
475                    let mut timeout_stream = TimeoutStream::new(buf_stream, timeout);
476                    while let Some(message) = timeout_stream.next().await {
477                        let message = match message {
478                            Ok(message) => message,
479                            Err(e) => {
480                                debug!(
481                                    "error in TLS request_stream src: {:?} error: {}",
482                                    src_addr, e
483                                );
484
485                                // kill this connection
486                                return;
487                            }
488                        };
489
490                        handle_raw_request(
491                            message,
492                            Protocol::Tls,
493                            handler.clone(),
494                            stream_handle.clone(),
495                        )
496                        .await;
497                    }
498                });
499
500                reap_tasks(&mut inner_join_set);
501            }
502
503            Ok(())
504        });
505
506        Ok(())
507    }
508
509    /// Register a TlsListener to the Server. The TlsListener should already be bound to either an
510    /// IPv6 or an IPv4 address.
511    ///
512    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
513    ///  to not make this too low depending on use cases.
514    ///
515    /// # Arguments
516    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
517    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
518    ///               requests within this time period will be closed. In the future it should be
519    ///               possible to create long-lived queries, but these should be from trusted sources
520    ///               only, this would require some type of whitelisting.
521    /// * `pkcs12` - certificate used to announce to clients
522    #[cfg(all(
523        feature = "dns-over-https-openssl",
524        not(feature = "dns-over-https-rustls")
525    ))]
526    #[cfg_attr(
527        docsrs,
528        doc(cfg(all(
529            feature = "dns-over-https-openssl",
530            not(feature = "dns-over-https-rustls")
531        )))
532    )]
533    pub fn register_https_listener(
534        &self,
535        listener: tcp::TcpListener,
536        timeout: Duration,
537        pkcs12: ParsedPkcs12,
538    ) -> io::Result<()> {
539        unimplemented!("openssl based `dns-over-https` not yet supported. see the `dns-over-https-rustls` feature")
540    }
541
542    /// Register a TcpListener for HTTPS (h2) to the Server for supporting DoH (dns-over-https). The TcpListener should already be bound to either an
543    /// IPv6 or an IPv4 address.
544    ///
545    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
546    ///  to not make this too low depending on use cases.
547    ///
548    /// # Arguments
549    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
550    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
551    ///               requests within this time period will be closed. In the future it should be
552    ///               possible to create long-lived queries, but these should be from trusted sources
553    ///               only, this would require some type of whitelisting.
554    /// * `certificate_and_key` - certificate and key used to announce to clients
555    #[cfg(feature = "dns-over-https-rustls")]
556    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https-rustls")))]
557    pub fn register_https_listener(
558        &mut self,
559        listener: net::TcpListener,
560        // TODO: need to set a timeout between requests.
561        _timeout: Duration,
562        certificate_and_key: (Vec<Certificate>, PrivateKey),
563        dns_hostname: Option<String>,
564    ) -> io::Result<()> {
565        use tokio_rustls::TlsAcceptor;
566
567        use crate::proto::rustls::tls_server;
568        use crate::server::https_handler::h2_handler;
569
570        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
571
572        let handler = self.handler.clone();
573        debug!("registered https: {listener:?}");
574
575        let tls_acceptor = tls_server::new_acceptor(certificate_and_key.0, certificate_and_key.1)
576            .map_err(|e| {
577            io::Error::new(
578                io::ErrorKind::Other,
579                format!("error creating TLS acceptor: {e}"),
580            )
581        })?;
582        let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor));
583
584        // for each incoming request...
585        let dns_hostname = dns_hostname;
586        let shutdown = self.shutdown_watch.clone();
587        self.join_set.spawn(async move {
588            let mut inner_join_set = JoinSet::new();
589            let dns_hostname = dns_hostname;
590            loop {
591                let shutdown = shutdown.clone();
592                let (tcp_stream, src_addr) = tokio::select! {
593                    tcp_stream = listener.accept() => match tcp_stream {
594                        Ok((t, s)) => (t, s),
595                        Err(e) => {
596                            debug!("error receiving HTTPS tcp_stream error: {}", e);
597                            continue;
598                        },
599                    },
600                    _ = shutdown.clone().signaled() => {
601                        // A graceful shutdown was initiated. Break out of the loop.
602                        break;
603                    },
604                };
605
606                // verify that the src address is safe for responses
607                if let Err(e) = sanitize_src_address(src_addr) {
608                    warn!("address can not be responded to {src_addr}: {e}");
609                    continue;
610                }
611
612                let handler = handler.clone();
613                let tls_acceptor = tls_acceptor.clone();
614                let dns_hostname = dns_hostname.clone();
615
616                inner_join_set.spawn(async move {
617                    debug!("starting HTTPS request from: {src_addr}");
618
619                    // TODO: need to consider timeout of total connect...
620                    // take the created stream...
621                    let tls_stream = tls_acceptor.accept(tcp_stream).await;
622
623                    let tls_stream = match tls_stream {
624                        Ok(tls_stream) => tls_stream,
625                        Err(e) => {
626                            debug!("https handshake src: {src_addr} error: {e}");
627                            return;
628                        }
629                    };
630                    debug!("accepted HTTPS request from: {src_addr}");
631
632                    h2_handler(
633                        handler,
634                        tls_stream,
635                        src_addr,
636                        dns_hostname,
637                        shutdown.clone(),
638                    )
639                    .await;
640                });
641
642                reap_tasks(&mut inner_join_set);
643            }
644
645            Ok(())
646        });
647
648        Ok(())
649    }
650
651    /// Register a UdpSocket to the Server for supporting DoQ (dns-over-quic). The UdpSocket should already be bound to either an
652    /// IPv6 or an IPv4 address.
653    ///
654    /// To make the server more resilient to DOS issues, there is a timeout. Care should be taken
655    ///  to not make this too low depending on use cases.
656    ///
657    /// # Arguments
658    /// * `listener` - a bound TCP (needs to be on a different port from standard TCP connections) socket
659    /// * `timeout` - timeout duration of incoming requests, any connection that does not send
660    ///               requests within this time period will be closed. In the future it should be
661    ///               possible to create long-lived queries, but these should be from trusted sources
662    ///               only, this would require some type of whitelisting.
663    /// * `pkcs12` - certificate used to announce to clients
664    #[cfg(feature = "dns-over-quic")]
665    #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-quic")))]
666    pub fn register_quic_listener(
667        &mut self,
668        socket: net::UdpSocket,
669        // TODO: need to set a timeout between requests.
670        _timeout: Duration,
671        certificate_and_key: (Vec<Certificate>, PrivateKey),
672        dns_hostname: Option<String>,
673    ) -> io::Result<()> {
674        use crate::proto::quic::QuicServer;
675        use crate::server::quic_handler::quic_handler;
676
677        let dns_hostname: Option<Arc<str>> = dns_hostname.map(|n| n.into());
678
679        let handler = self.handler.clone();
680
681        debug!("registered quic: {:?}", socket);
682        let mut server =
683            QuicServer::with_socket(socket, certificate_and_key.0, certificate_and_key.1)?;
684
685        // for each incoming request...
686        let dns_hostname = dns_hostname;
687        let shutdown = self.shutdown_watch.clone();
688        self.join_set.spawn(async move {
689            let mut inner_join_set = JoinSet::new();
690            let dns_hostname = dns_hostname;
691            loop {
692                let shutdown = shutdown.clone();
693                let (streams, src_addr) = tokio::select! {
694                    result = server.next() => match result {
695                        Ok(Some(c)) => c,
696                        Ok(None) => continue,
697                        Err(e) => {
698                            debug!("error receiving quic connection: {e}");
699                            continue;
700                        }
701                    },
702                    _ = shutdown.clone().signaled() => {
703                        // A graceful shutdown was initiated. Break out of the loop.
704                        break;
705                    },
706                };
707
708                // verify that the src address is safe for responses
709                // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing
710                if let Err(e) = sanitize_src_address(src_addr) {
711                    warn!(
712                        "address can not be responded to {src_addr}: {e}",
713                        src_addr = src_addr,
714                        e = e
715                    );
716                    continue;
717                }
718
719                let handler = handler.clone();
720                let dns_hostname = dns_hostname.clone();
721
722                inner_join_set.spawn(async move {
723                    debug!("starting quic stream request from: {src_addr}");
724
725                    // TODO: need to consider timeout of total connect...
726                    let result =
727                        quic_handler(handler, streams, src_addr, dns_hostname, shutdown.clone())
728                            .await;
729
730                    if let Err(e) = result {
731                        warn!("quic stream processing failed from {src_addr}: {e}")
732                    }
733                });
734
735                reap_tasks(&mut inner_join_set);
736            }
737
738            Ok(())
739        });
740
741        Ok(())
742    }
743
744    /// Returns a signal used for initiating a graceful shutdown of the server and the future
745    /// used for awaiting completion of the shutdown.
746    ///
747    /// This allows the application to have separate code paths that are responsible for
748    /// triggering shutdown and awaiting application completion.
749    pub fn graceful(self) -> (ShutdownSignal, impl Future<Output = Result<(), ProtoError>>) {
750        let signal = self.shutdown_signal;
751        let join_set = self.join_set;
752        (signal, block_until_done(join_set))
753    }
754
755    /// Triggers a graceful shutdown the server. All background tasks will stop accepting
756    /// new connections and the returned future will complete once all tasks have terminated.
757    ///
758    /// This is equivalent to calling [Self::graceful], then triggering the graceful
759    /// shutdown (via [ShutdownSignal::shutdown]) and awaiting completion of the server.
760    pub async fn shutdown_gracefully(self) -> Result<(), ProtoError> {
761        let (signal, fut) = self.graceful();
762
763        // Trigger shutdown.
764        signal.shutdown().await;
765
766        // Wait for the server to complete.
767        fut.await
768    }
769
770    /// This will run until all background tasks complete. If one or more tasks return an error,
771    /// one will be chosen as the returned error for this future.
772    pub async fn block_until_done(self) -> Result<(), ProtoError> {
773        block_until_done(self.join_set).await
774    }
775}
776
/// Signals the start of a graceful shutdown.
///
/// Obtained from [`ServerFuture::graceful`]. Consuming it with [`ShutdownSignal::shutdown`]
/// notifies every background task that is waiting on the server's shutdown watch.
#[derive(Debug)]
pub struct ShutdownSignal {
    // Sending half of the `drain` channel that was created together with the server's `Watch`.
    signal: Signal,
}
782
783impl ShutdownSignal {
784    fn new(signal: Signal) -> Self {
785        Self { signal }
786    }
787
788    /// Asynchronously sends the shutdown command to all server threads and
789    /// waits for them to complete.
790    pub async fn shutdown(self) {
791        self.signal.drain().await
792    }
793}
794
795async fn block_until_done(mut join_set: JoinSet<Result<(), ProtoError>>) -> Result<(), ProtoError> {
796    if join_set.is_empty() {
797        warn!("block_until_done called with no pending tasks");
798        return Ok(());
799    }
800
801    // Now wait for all of the tasks to complete.
802    let mut out = Ok(());
803    while let Some(join_result) = join_set.join_next().await {
804        match join_result {
805            Ok(result) => {
806                match result {
807                    Ok(_) => (),
808                    Err(e) => {
809                        // Save the last error.
810                        out = Err(e);
811                    }
812                }
813            }
814            Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))),
815        }
816    }
817    out
818}
819
820/// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
821fn reap_tasks(join_set: &mut JoinSet<()>) {
822    while FutureExt::now_or_never(join_set.join_next()).is_some() {}
823}
824
825pub(crate) async fn handle_raw_request<T: RequestHandler>(
826    message: SerialMessage,
827    protocol: Protocol,
828    request_handler: Arc<T>,
829    response_handler: BufDnsStreamHandle,
830) {
831    let src_addr = message.addr();
832    let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);
833
834    handle_request(
835        message.bytes(),
836        src_addr,
837        protocol,
838        request_handler,
839        response_handler,
840    )
841    .await;
842}
843
/// Wraps another [`ResponseHandler`] and logs a one-line summary of every response it sends.
///
/// Details of the originating request are captured up front so the summary can describe the
/// request that each response answers.
#[derive(Clone)]
struct ReportingResponseHandler<R: ResponseHandler> {
    /// Header of the request being answered; its id is compared against the response id.
    request_header: Header,
    /// Query of the request (a default query when the request body failed to parse).
    query: LowerQuery,
    /// Transport protocol the request arrived on.
    protocol: Protocol,
    /// Address of the client that sent the request.
    src_addr: SocketAddr,
    /// Wrapped handler that actually delivers the response.
    handler: R,
}
852
853#[async_trait::async_trait]
854#[allow(clippy::uninlined_format_args)]
855impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
856    async fn send_response<'a>(
857        &mut self,
858        response: crate::authority::MessageResponse<
859            '_,
860            'a,
861            impl Iterator<Item = &'a Record> + Send + 'a,
862            impl Iterator<Item = &'a Record> + Send + 'a,
863            impl Iterator<Item = &'a Record> + Send + 'a,
864            impl Iterator<Item = &'a Record> + Send + 'a,
865        >,
866    ) -> io::Result<super::ResponseInfo> {
867        let response_info = self.handler.send_response(response).await?;
868
869        let id = self.request_header.id();
870        let rid = response_info.id();
871        if id != rid {
872            warn!("request id:{id} does not match response id:{rid}");
873            debug_assert_eq!(id, rid, "request id and response id should match");
874        }
875
876        let rflags = response_info.flags();
877        let answer_count = response_info.answer_count();
878        let authority_count = response_info.name_server_count();
879        let additional_count = response_info.additional_count();
880        let response_code = response_info.response_code();
881
882        info!("request:{id} src:{proto}://{addr}#{port} {op}:{query}:{qtype}:{class} qflags:{qflags} response:{code:?} rr:{answers}/{authorities}/{additionals} rflags:{rflags}",
883            id = rid,
884            proto = self.protocol,
885            addr = self.src_addr.ip(),
886            port = self.src_addr.port(),
887            op = self.request_header.op_code(),
888            query = self.query.name(),
889            qtype = self.query.query_type(),
890            class = self.query.query_class(),
891            qflags = self.request_header.flags(),
892            code = response_code,
893            answers = answer_count,
894            authorities = authority_count,
895            additionals = additional_count,
896            rflags = rflags
897        );
898
899        Ok(response_info)
900    }
901}
902
903pub(crate) async fn handle_request<R: ResponseHandler, T: RequestHandler>(
904    // TODO: allow Message here...
905    message_bytes: &[u8],
906    src_addr: SocketAddr,
907    protocol: Protocol,
908    request_handler: Arc<T>,
909    response_handler: R,
910) {
911    let mut decoder = BinDecoder::new(message_bytes);
912
913    // method to handle the request
914    let inner_handle_request = |message: MessageRequest, response_handler: R| async move {
915        if message.message_type() == MessageType::Response {
916            // Don't process response messages to avoid DoS attacks from reflection.
917            return;
918        }
919
920        let id = message.id();
921        let qflags = message.header().flags();
922        let qop_code = message.op_code();
923        let message_type = message.message_type();
924        let is_dnssec = message.edns().map_or(false, Edns::dnssec_ok);
925
926        let request = Request::new(message, src_addr, protocol);
927
928        let info = request.request_info();
929        let query = info.query.clone();
930        let query_name = info.query.name();
931        let query_type = info.query.query_type();
932        let query_class = info.query.query_class();
933
934        debug!(
935            "request:{id} src:{proto}://{addr}#{port} type:{message_type} dnssec:{is_dnssec} {op}:{query}:{qtype}:{class} qflags:{qflags}",
936            id = id,
937            proto = protocol,
938            addr = src_addr.ip(),
939            port = src_addr.port(),
940            message_type= message_type,
941            is_dnssec = is_dnssec,
942            op = qop_code,
943            query = query_name,
944            qtype = query_type,
945            class = query_class,
946            qflags = qflags,
947        );
948
949        // The reporter will handle making sure to log the result of the request
950        let reporter = ReportingResponseHandler {
951            request_header: *request.header(),
952            query,
953            protocol,
954            src_addr,
955            handler: response_handler,
956        };
957
958        request_handler.handle_request(&request, reporter).await;
959    };
960
961    // Attempt to decode the message
962    match MessageRequest::read(&mut decoder) {
963        Ok(message) => {
964            inner_handle_request(message, response_handler).await;
965        }
966        Err(ProtoError { kind, .. }) if kind.as_form_error().is_some() => {
967            // We failed to parse the request due to some issue in the message, but the header is available, so we can respond
968            let (header, error) = kind
969                .into_form_error()
970                .expect("as form_error already confirmed this is a FormError");
971            let query = LowerQuery::query(Query::default());
972
973            // debug for more info on why the message parsing failed
974            debug!(
975                "request:{id} src:{proto}://{addr}#{port} type:{message_type} {op}:FormError:{error}",
976                id = header.id(),
977                proto = protocol,
978                addr = src_addr.ip(),
979                port = src_addr.port(),
980                message_type= header.message_type(),
981                op = header.op_code(),
982                error = error,
983            );
984
985            // The reporter will handle making sure to log the result of the request
986            let mut reporter = ReportingResponseHandler {
987                request_header: header,
988                query,
989                protocol,
990                src_addr,
991                handler: response_handler,
992            };
993
994            let response = MessageResponseBuilder::new(None);
995            let result = reporter
996                .send_response(response.error_msg(&header, ResponseCode::FormErr))
997                .await;
998
999            if let Err(e) = result {
1000                warn!("failed to return FormError to client: {}", e);
1001            }
1002        }
1003        Err(e) => warn!("failed to read message: {}", e),
1004    }
1005}
1006
/// Checks if the IP address is safe for returning messages
///
/// The following sources are rejected:
/// * any address with port `0`;
/// * the unspecified address (`0.0.0.0` / `::`);
/// * the IPv4 broadcast address;
/// * any multicast address.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are checked against the IPv4 rules.
///
/// # Returns
///
/// Error if the address should not be used for returned requests
fn sanitize_src_address(src: SocketAddr) -> Result<(), String> {
    if src.port() == 0 {
        return Err(format!("cannot respond to src on port 0: {src}"));
    }

    fn verify_v4(src: Ipv4Addr) -> Result<(), String> {
        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v4 addr: {src}"));
        }

        if src.is_broadcast() {
            return Err(format!("cannot respond to broadcast v4 addr: {src}"));
        }

        // A multicast group address is never a legitimate source address.
        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v4 addr: {src}"));
        }

        // TODO: add check for is_reserved when that stabilizes

        Ok(())
    }

    fn verify_v6(src: Ipv6Addr) -> Result<(), String> {
        // Dual-stack sockets report IPv4 peers as `::ffff:a.b.c.d`. Apply the IPv4 rules so
        // that mapped broadcast/unspecified sources cannot slip past them.
        if let [0, 0, 0, 0, 0, 0xffff, high, low] = src.segments() {
            let [a, b] = high.to_be_bytes();
            let [c, d] = low.to_be_bytes();
            return verify_v4(Ipv4Addr::new(a, b, c, d));
        }

        if src.is_unspecified() {
            return Err(format!("cannot respond to unspecified v6 addr: {src}"));
        }

        if src.is_multicast() {
            return Err(format!("cannot respond to multicast v6 addr: {src}"));
        }

        Ok(())
    }

    match src.ip() {
        IpAddr::V4(v4) => verify_v4(v4),
        IpAddr::V6(v6) => verify_v6(v6),
    }
}
1048
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authority::Catalog;
    use futures_util::future;
    #[cfg(feature = "dns-over-rustls")]
    use rustls::{Certificate, PrivateKey};
    use std::net::SocketAddr;
    use tokio::net::{TcpListener, UdpSocket};
    use tokio::time::timeout;

    /// Aborting the future that owns a `ServerFuture` must leave every address bindable again.
    #[tokio::test]
    async fn abort() {
        let endpoints = Endpoints::new().await;

        let endpoints2 = endpoints.clone();
        let (abortable, abort_handle) = future::abortable(async move {
            let mut server_future = ServerFuture::new(Catalog::new());
            endpoints2.register(&mut server_future).await;
            server_future.block_until_done().await
        });

        // NOTE(review): the abort is requested before `abortable` is first polled, so the
        // server body above presumably never runs; confirm this exercises what is intended.
        abort_handle.abort();
        abortable.await.expect_err("expected abort");

        // Panics if any of the listener addresses is still held by a socket.
        endpoints.rebind_all().await;
    }

    /// `shutdown_gracefully` must complete within two seconds without a task error and must
    /// release every bound address.
    #[tokio::test]
    async fn graceful_shutdown() {
        let mut server_future = ServerFuture::new(Catalog::new());
        let endpoints = Endpoints::new().await;
        endpoints.register(&mut server_future).await;

        timeout(Duration::from_secs(2), server_future.shutdown_gracefully())
            .await
            .expect("timed out waiting for the server to complete")
            .expect("error while awaiting tasks");

        endpoints.rebind_all().await;
    }

    /// Port 0, unspecified and broadcast sources are rejected; ordinary unicast is accepted.
    #[test]
    fn test_sanitize_src_addr() {
        // ipv4 tests
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 4096))).is_ok());
        assert!(sanitize_src_address(SocketAddr::from(([127, 0, 0, 1], 53))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([192, 168, 1, 1], 0))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([255, 255, 255, 255], 4096))).is_err());

        // ipv6 tests
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 4096))).is_ok()
        );
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 4096))).is_ok());

        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 4096))).is_err());
        assert!(sanitize_src_address(SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0))).is_err());
        assert!(
            sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err()
        );
    }

    /// Loopback addresses, one per listener kind, picked by binding to port 0.
    #[derive(Clone)]
    struct Endpoints {
        udp_addr: SocketAddr,
        udp_std_addr: SocketAddr,
        tcp_addr: SocketAddr,
        tcp_std_addr: SocketAddr,
        #[cfg(feature = "dns-over-rustls")]
        rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-https-rustls")]
        https_rustls_addr: SocketAddr,
        #[cfg(feature = "dns-over-quic")]
        quic_addr: SocketAddr,
    }

    impl Endpoints {
        /// Binds each socket to an OS-assigned port and records the address. The sockets are
        /// locals, so they are dropped when this returns and only the addresses are kept.
        async fn new() -> Self {
            let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let udp_std = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let tcp_std = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            let rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            let https_rustls = TcpListener::bind("127.0.0.1:0").await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            let quic = UdpSocket::bind("127.0.0.1:0").await.unwrap();

            Self {
                udp_addr: udp.local_addr().unwrap(),
                udp_std_addr: udp_std.local_addr().unwrap(),
                tcp_addr: tcp.local_addr().unwrap(),
                tcp_std_addr: tcp_std.local_addr().unwrap(),
                #[cfg(feature = "dns-over-rustls")]
                rustls_addr: rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-https-rustls")]
                https_rustls_addr: https_rustls.local_addr().unwrap(),
                #[cfg(feature = "dns-over-quic")]
                quic_addr: quic.local_addr().unwrap(),
            }
        }

        /// Registers one listener of every enabled kind on `server`, bound to the recorded
        /// addresses.
        async fn register<T: RequestHandler>(&self, server: &mut ServerFuture<T>) {
            server.register_socket(UdpSocket::bind(self.udp_addr).await.unwrap());
            server
                .register_socket_std(std::net::UdpSocket::bind(self.udp_std_addr).unwrap())
                .unwrap();
            server.register_listener(
                TcpListener::bind(self.tcp_addr).await.unwrap(),
                Duration::from_secs(1),
            );
            server
                .register_listener_std(
                    std::net::TcpListener::bind(self.tcp_std_addr).unwrap(),
                    Duration::from_secs(1),
                )
                .unwrap();

            #[cfg(feature = "dns-over-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_tls_listener(
                        TcpListener::bind(self.rustls_addr).await.unwrap(),
                        Duration::from_secs(30),
                        cert_key,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-https-rustls")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_https_listener(
                        TcpListener::bind(self.https_rustls_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }

            #[cfg(feature = "dns-over-quic")]
            {
                let cert_key = rustls_cert_key();
                server
                    .register_quic_listener(
                        UdpSocket::bind(self.quic_addr).await.unwrap(),
                        Duration::from_secs(1),
                        cert_key,
                        None,
                    )
                    .unwrap();
            }
        }

        /// Binds every recorded address again. Each socket is a temporary that is dropped
        /// immediately. Panics if any address is still in use, i.e. the server leaked a socket.
        async fn rebind_all(&self) {
            UdpSocket::bind(self.udp_addr).await.unwrap();
            UdpSocket::bind(self.udp_std_addr).await.unwrap();
            TcpListener::bind(self.tcp_addr).await.unwrap();
            TcpListener::bind(self.tcp_std_addr).await.unwrap();
            #[cfg(feature = "dns-over-rustls")]
            TcpListener::bind(self.rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-https-rustls")]
            TcpListener::bind(self.https_rustls_addr).await.unwrap();
            #[cfg(feature = "dns-over-quic")]
            UdpSocket::bind(self.quic_addr).await.unwrap();
        }
    }

    /// Loads the test certificate chain and private key from `tests/test-data`.
    ///
    /// The paths are resolved relative to `TDNS_WORKSPACE_ROOT`, or `../..` when that variable
    /// is unset.
    #[cfg(feature = "dns-over-rustls")]
    fn rustls_cert_key() -> (Vec<Certificate>, PrivateKey) {
        use std::env;
        use std::path::Path;
        use trust_dns_proto::rustls::tls_server;

        let server_path = env::var("TDNS_WORKSPACE_ROOT").unwrap_or_else(|_| "../..".to_owned());

        let cert = tls_server::read_cert(Path::new(&format!(
            "{}/tests/test-data/cert.pem",
            server_path
        )))
        .map_err(|e| format!("error reading cert: {e}"))
        .unwrap();
        let key = tls_server::read_key_from_pem(Path::new(&format!(
            "{}/tests/test-data/cert.key",
            server_path
        )))
        .unwrap();

        (cert, key)
    }
}