trust_dns_client/client/
async_client.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::{
9    future::Future,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13    time::Duration,
14};
15
16use futures_util::{
17    ready,
18    stream::{Stream, StreamExt},
19};
20use rand;
21use tracing::debug;
22
23use crate::{
24    client::Signer,
25    error::*,
26    op::{Message, MessageType, OpCode, Query},
27    proto::{
28        error::{ProtoError, ProtoErrorKind},
29        op::{update_message, Edns},
30        xfer::{
31            BufDnsStreamHandle, DnsClientStream, DnsExchange, DnsExchangeBackground,
32            DnsExchangeSend, DnsHandle, DnsMultiplexer, DnsRequest, DnsRequestOptions,
33            DnsRequestSender, DnsResponse,
34        },
35        TokioTime,
36    },
37    rr::{rdata::SOA, DNSClass, Name, RData, Record, RecordSet, RecordType},
38};
39
40/// A DNS Client implemented over futures-rs.
41///
42/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
43///  implementations.
44pub type ClientFuture = AsyncClient;
45
46/// A DNS Client implemented over futures-rs.
47///
48/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
49///  implementations.
50#[derive(Clone)]
51pub struct AsyncClient {
52    exchange: DnsExchange,
53    use_edns: bool,
54}
55
56impl AsyncClient {
57    /// Spawns a new AsyncClient Stream. This uses a default timeout of 5 seconds for all requests.
58    ///
59    /// # Arguments
60    ///
61    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
62    ///              (see TcpClientStream or UdpClientStream)
63    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
64    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
65    #[allow(clippy::new_ret_no_self)]
66    pub async fn new<F, S>(
67        stream: F,
68        stream_handle: BufDnsStreamHandle,
69        signer: Option<Arc<Signer>>,
70    ) -> Result<
71        (
72            Self,
73            DnsExchangeBackground<DnsMultiplexer<S, Signer>, TokioTime>,
74        ),
75        ProtoError,
76    >
77    where
78        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
79        S: DnsClientStream + 'static + Unpin,
80    {
81        Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer).await
82    }
83
84    /// Spawns a new AsyncClient Stream.
85    ///
86    /// # Arguments
87    ///
88    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
89    ///              (see TcpClientStream or UdpClientStream)
90    /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to
91    ///                        wait for a response before canceling the request.
92    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
93    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
94    pub async fn with_timeout<F, S>(
95        stream: F,
96        stream_handle: BufDnsStreamHandle,
97        timeout_duration: Duration,
98        signer: Option<Arc<Signer>>,
99    ) -> Result<
100        (
101            Self,
102            DnsExchangeBackground<DnsMultiplexer<S, Signer>, TokioTime>,
103        ),
104        ProtoError,
105    >
106    where
107        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
108        S: DnsClientStream + 'static + Unpin,
109    {
110        let mp = DnsMultiplexer::with_timeout(stream, stream_handle, timeout_duration, signer);
111        Self::connect(mp).await
112    }
113
114    /// Returns a future, which itself wraps a future which is awaiting connection.
115    ///
116    /// The connect_future should be lazy.
117    ///
118    /// # Returns
119    ///
120    /// This returns a tuple of Self a handle to send dns messages and an optional background.
121    ///  The background task must be run on an executor before handle is used, if it is Some.
122    ///  If it is None, then another thread has already run the background.
123    pub async fn connect<F, S>(
124        connect_future: F,
125    ) -> Result<(Self, DnsExchangeBackground<S, TokioTime>), ProtoError>
126    where
127        S: DnsRequestSender,
128        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
129    {
130        let result = DnsExchange::connect(connect_future).await;
131        let use_edns = true;
132        result.map(|(exchange, bg)| (Self { exchange, use_edns }, bg))
133    }
134
135    /// (Re-)enable usage of EDNS for outgoing messages
136    pub fn enable_edns(&mut self) {
137        self.use_edns = true;
138    }
139
140    /// Disable usage of EDNS for outgoing messages
141    pub fn disable_edns(&mut self) {
142        self.use_edns = false;
143    }
144}
145
146impl DnsHandle for AsyncClient {
147    type Response = DnsExchangeSend;
148    type Error = ProtoError;
149
150    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
151        self.exchange.send(request)
152    }
153
154    fn is_using_edns(&self) -> bool {
155        self.use_edns
156    }
157}
158
159impl<T> ClientHandle for T where T: DnsHandle<Error = ProtoError> {}
160
161/// A trait for implementing high level functions of DNS.
162pub trait ClientHandle: 'static + Clone + DnsHandle<Error = ProtoError> + Send {
163    /// A *classic* DNS query
164    ///
165    /// *Note* As of now, this will not recurse on PTR or CNAME record responses, that is up to
166    ///        the caller.
167    ///
168    /// # Arguments
169    ///
170    /// * `name` - the label to lookup
171    /// * `query_class` - most likely this should always be DNSClass::IN
172    /// * `query_type` - record type to lookup
173    fn query(
174        &mut self,
175        name: Name,
176        query_class: DNSClass,
177        query_type: RecordType,
178    ) -> ClientResponse<<Self as DnsHandle>::Response> {
179        let mut query = Query::query(name, query_type);
180        query.set_query_class(query_class);
181        let mut options = DnsRequestOptions::default();
182        options.use_edns = self.is_using_edns();
183        ClientResponse(self.lookup(query, options))
184    }
185
186    /// Sends a NOTIFY message to the remote system
187    ///
188    /// [RFC 1996](https://tools.ietf.org/html/rfc1996), DNS NOTIFY, August 1996
189    ///
190    ///
191    /// ```text
192    /// 1. Rationale and Scope
193    ///
194    ///   1.1. Slow propagation of new and changed data in a DNS zone can be
195    ///   due to a zone's relatively long refresh times.  Longer refresh times
196    ///   are beneficial in that they reduce load on the Primary Zone Servers, but
197    ///   that benefit comes at the cost of long intervals of incoherence among
198    ///   authority servers whenever the zone is updated.
199    ///
200    ///   1.2. The DNS NOTIFY transaction allows Primary Zone Servers to inform Secondary
201    ///   Zone Servers when the zone has changed -- an interrupt as opposed to poll
202    ///   model -- which it is hoped will reduce propagation delay while not
203    ///   unduly increasing the masters' load.  This specification only allows
204    ///   slaves to be notified of SOA RR changes, but the architecture of
205    ///   NOTIFY is intended to be extensible to other RR types.
206    ///
207    ///   1.3. This document intentionally gives more definition to the roles
208    ///   of "Primary", "Secondary" and "Stealth" servers, their enumeration in NS
209    ///   RRs, and the SOA MNAME field.  In that sense, this document can be
210    ///   considered an addendum to [RFC1035].
211    ///
212    /// ```
213    ///
214    /// The below section describes how the Notify message should be constructed. The function
215    ///  implementation accepts a Record, but the actual data of the record should be ignored by the
216    ///  server, i.e. the server should make a request subsequent to receiving this Notification for
217    ///  the authority record, but could be used to decide to request an update or not:
218    ///
219    /// ```text
220    ///   3.7. A NOTIFY request has QDCOUNT>0, ANCOUNT>=0, AUCOUNT>=0,
221    ///   ADCOUNT>=0.  If ANCOUNT>0, then the answer section represents an
222    ///   unsecure hint at the new RRset for this <QNAME,QCLASS,QTYPE>.  A
223    ///   Secondary receiving such a hint is free to treat equivalence of this
224    ///   answer section with its local data as a "no further work needs to be
225    ///   done" indication.  If ANCOUNT=0, or ANCOUNT>0 and the answer section
226    ///   differs from the Secondary's local data, then the Secondary should query its
227    ///   known Primaries to retrieve the new data.
228    /// ```
229    ///
230    /// Client's should be ready to handle, or be aware of, a server response of NOTIMP:
231    ///
232    /// ```text
233    ///   3.12. If a NOTIFY request is received by a Secondary who does not
234    ///   implement the NOTIFY opcode, it will respond with a NOTIMP
235    ///   (unimplemented feature error) message.  A Primary Zone Server who receives
236    ///   such a NOTIMP should consider the NOTIFY transaction complete for
237    ///   that Secondary.
238    /// ```
239    ///
240    /// # Arguments
241    ///
242    /// * `name` - the label which is being notified
243    /// * `query_class` - most likely this should always be DNSClass::IN
244    /// * `query_type` - record type which has been updated
245    /// * `rrset` - the new version of the record(s) being notified
246    fn notify<R>(
247        &mut self,
248        name: Name,
249        query_class: DNSClass,
250        query_type: RecordType,
251        rrset: Option<R>,
252    ) -> ClientResponse<<Self as DnsHandle>::Response>
253    where
254        R: Into<RecordSet>,
255    {
256        debug!("notifying: {} {:?}", name, query_type);
257
258        // build the message
259        let mut message: Message = Message::new();
260        let id: u16 = rand::random();
261        message
262            .set_id(id)
263            // 3.3. NOTIFY is similar to QUERY in that it has a request message with
264            // the header QR flag "clear" and a response message with QR "set".  The
265            // response message contains no useful information, but its reception by
266            // the Primary is an indication that the Secondary has received the NOTIFY
267            // and that the Primary Zone Server can remove the Secondary from any retry queue for
268            // this NOTIFY event.
269            .set_message_type(MessageType::Query)
270            .set_op_code(OpCode::Notify);
271
272        // Extended dns
273        if self.is_using_edns() {
274            message
275                .extensions_mut()
276                .get_or_insert_with(Edns::new)
277                .set_max_payload(update_message::MAX_PAYLOAD_LEN)
278                .set_version(0);
279        }
280
281        // add the query
282        let mut query: Query = Query::new();
283        query
284            .set_name(name)
285            .set_query_class(query_class)
286            .set_query_type(query_type);
287        message.add_query(query);
288
289        // add the notify message, see https://tools.ietf.org/html/rfc1996, section 3.7
290        if let Some(rrset) = rrset {
291            message.add_answers(rrset.into());
292        }
293
294        ClientResponse(self.send(message))
295    }
296
297    /// Sends a record to create on the server, this will fail if the record exists (atomicity
298    ///  depends on the server)
299    ///
300    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
301    ///
302    /// ```text
303    ///  2.4.3 - RRset Does Not Exist
304    ///
305    ///   No RRs with a specified NAME and TYPE (in the zone and class denoted
306    ///   by the Zone Section) can exist.
307    ///
308    ///   For this prerequisite, a requestor adds to the section a single RR
309    ///   whose NAME and TYPE are equal to that of the RRset whose nonexistence
310    ///   is required.  The RDLENGTH of this record is zero (0), and RDATA
311    ///   field is therefore empty.  CLASS must be specified as NONE in order
312    ///   to distinguish this condition from a valid RR whose RDLENGTH is
313    ///   naturally zero (0) (for example, the NULL RR).  TTL must be specified
314    ///   as zero (0).
315    ///
316    /// 2.5.1 - Add To An RRset
317    ///
318    ///    RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
319    ///    and RDATA are those being added, and CLASS is the same as the zone
320    ///    class.  Any duplicate RRs will be silently ignored by the Primary Zone
321    ///    Server.
322    /// ```
323    ///
324    /// # Arguments
325    ///
326    /// * `rrset` - the record(s) to create
327    /// * `zone_origin` - the zone name to update, i.e. SOA name
328    ///
329    /// The update must go to a zone authority (i.e. the server used in the ClientConnection)
330    fn create<R>(
331        &mut self,
332        rrset: R,
333        zone_origin: Name,
334    ) -> ClientResponse<<Self as DnsHandle>::Response>
335    where
336        R: Into<RecordSet>,
337    {
338        let rrset = rrset.into();
339        let message = update_message::create(rrset, zone_origin, self.is_using_edns());
340
341        ClientResponse(self.send(message))
342    }
343
344    /// Appends a record to an existing rrset, optionally require the rrset to exist (atomicity
345    ///  depends on the server)
346    ///
347    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
348    ///
349    /// ```text
350    /// 2.4.1 - RRset Exists (Value Independent)
351    ///
352    ///   At least one RR with a specified NAME and TYPE (in the zone and class
353    ///   specified in the Zone Section) must exist.
354    ///
355    ///   For this prerequisite, a requestor adds to the section a single RR
356    ///   whose NAME and TYPE are equal to that of the zone RRset whose
357    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
358    ///   empty.  CLASS must be specified as ANY to differentiate this
359    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
360    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
361    ///
362    /// 2.5.1 - Add To An RRset
363    ///
364    ///    RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
365    ///    and RDATA are those being added, and CLASS is the same as the zone
366    ///    class.  Any duplicate RRs will be silently ignored by the Primary Zone
367    ///    Server.
368    /// ```
369    ///
370    /// # Arguments
371    ///
372    /// * `rrset` - the record(s) to append to an RRSet
373    /// * `zone_origin` - the zone name to update, i.e. SOA name
374    /// * `must_exist` - if true, the request will fail if the record does not exist
375    ///
376    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
377    /// the rrset does not exist and must_exist is false, then the RRSet will be created.
378    fn append<R>(
379        &mut self,
380        rrset: R,
381        zone_origin: Name,
382        must_exist: bool,
383    ) -> ClientResponse<<Self as DnsHandle>::Response>
384    where
385        R: Into<RecordSet>,
386    {
387        let rrset = rrset.into();
388        let message = update_message::append(rrset, zone_origin, must_exist, self.is_using_edns());
389
390        ClientResponse(self.send(message))
391    }
392
393    /// Compares and if it matches, swaps it for the new value (atomicity depends on the server)
394    ///
395    /// ```text
396    ///  2.4.2 - RRset Exists (Value Dependent)
397    ///
398    ///   A set of RRs with a specified NAME and TYPE exists and has the same
399    ///   members with the same RDATAs as the RRset specified here in this
400    ///   section.  While RRset ordering is undefined and therefore not
401    ///   significant to this comparison, the sets be identical in their
402    ///   extent.
403    ///
404    ///   For this prerequisite, a requestor adds to the section an entire
405    ///   RRset whose preexistence is required.  NAME and TYPE are that of the
406    ///   RRset being denoted.  CLASS is that of the zone.  TTL must be
407    ///   specified as zero (0) and is ignored when comparing RRsets for
408    ///   identity.
409    ///
410    ///  2.5.4 - Delete An RR From An RRset
411    ///
412    ///   RRs to be deleted are added to the Update Section.  The NAME, TYPE,
413    ///   RDLENGTH and RDATA must match the RR being deleted.  TTL must be
414    ///   specified as zero (0) and will otherwise be ignored by the Primary
415    ///   Zone Server.  CLASS must be specified as NONE to distinguish this from an
416    ///   RR addition.  If no such RRs exist, then this Update RR will be
417    ///   silently ignored by the Primary Zone Server.
418    ///
419    ///  2.5.1 - Add To An RRset
420    ///
421    ///   RRs are added to the Update Section whose NAME, TYPE, TTL, RDLENGTH
422    ///   and RDATA are those being added, and CLASS is the same as the zone
423    ///   class.  Any duplicate RRs will be silently ignored by the Primary
424    ///   Zone Server.
425    /// ```
426    ///
427    /// # Arguments
428    ///
429    /// * `current` - the current rrset which must exist for the swap to complete
430    /// * `new` - the new rrset with which to replace the current rrset
431    /// * `zone_origin` - the zone name to update, i.e. SOA name
432    ///
433    /// The update must go to a zone authority (i.e. the server used in the ClientConnection).
434    fn compare_and_swap<C, N>(
435        &mut self,
436        current: C,
437        new: N,
438        zone_origin: Name,
439    ) -> ClientResponse<<Self as DnsHandle>::Response>
440    where
441        C: Into<RecordSet>,
442        N: Into<RecordSet>,
443    {
444        let current = current.into();
445        let new = new.into();
446
447        let message =
448            update_message::compare_and_swap(current, new, zone_origin, self.is_using_edns());
449        ClientResponse(self.send(message))
450    }
451
452    /// Deletes a record (by rdata) from an rrset, optionally require the rrset to exist.
453    ///
454    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
455    ///
456    /// ```text
457    /// 2.4.1 - RRset Exists (Value Independent)
458    ///
459    ///   At least one RR with a specified NAME and TYPE (in the zone and class
460    ///   specified in the Zone Section) must exist.
461    ///
462    ///   For this prerequisite, a requestor adds to the section a single RR
463    ///   whose NAME and TYPE are equal to that of the zone RRset whose
464    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
465    ///   empty.  CLASS must be specified as ANY to differentiate this
466    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
467    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
468    ///
469    /// 2.5.4 - Delete An RR From An RRset
470    ///
471    ///   RRs to be deleted are added to the Update Section.  The NAME, TYPE,
472    ///   RDLENGTH and RDATA must match the RR being deleted.  TTL must be
473    ///   specified as zero (0) and will otherwise be ignored by the Primary
474    ///   Zone Server.  CLASS must be specified as NONE to distinguish this from an
475    ///   RR addition.  If no such RRs exist, then this Update RR will be
476    ///   silently ignored by the Primary Zone Server.
477    /// ```
478    ///
479    /// # Arguments
480    ///
481    /// * `rrset` - the record(s) to delete from a RRSet, the name, type and rdata must match the
482    ///              record to delete
483    /// * `zone_origin` - the zone name to update, i.e. SOA name
484    /// * `signer` - the signer, with private key, to use to sign the request
485    ///
486    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
487    /// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
488    fn delete_by_rdata<R>(
489        &mut self,
490        rrset: R,
491        zone_origin: Name,
492    ) -> ClientResponse<<Self as DnsHandle>::Response>
493    where
494        R: Into<RecordSet>,
495    {
496        let rrset = rrset.into();
497        let message = update_message::delete_by_rdata(rrset, zone_origin, self.is_using_edns());
498
499        ClientResponse(self.send(message))
500    }
501
502    /// Deletes an entire rrset, optionally require the rrset to exist.
503    ///
504    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
505    ///
506    /// ```text
507    /// 2.4.1 - RRset Exists (Value Independent)
508    ///
509    ///   At least one RR with a specified NAME and TYPE (in the zone and class
510    ///   specified in the Zone Section) must exist.
511    ///
512    ///   For this prerequisite, a requestor adds to the section a single RR
513    ///   whose NAME and TYPE are equal to that of the zone RRset whose
514    ///   existence is required.  RDLENGTH is zero and RDATA is therefore
515    ///   empty.  CLASS must be specified as ANY to differentiate this
516    ///   condition from that of an actual RR whose RDLENGTH is naturally zero
517    ///   (0) (e.g., NULL).  TTL is specified as zero (0).
518    ///
519    /// 2.5.2 - Delete An RRset
520    ///
521    ///   One RR is added to the Update Section whose NAME and TYPE are those
522    ///   of the RRset to be deleted.  TTL must be specified as zero (0) and is
523    ///   otherwise not used by the Primary Zone Server.  CLASS must be specified as
524    ///   ANY.  RDLENGTH must be zero (0) and RDATA must therefore be empty.
525    ///   If no such RRset exists, then this Update RR will be silently ignored
526    ///   by the Primary Zone Server.
527    /// ```
528    ///
529    /// # Arguments
530    ///
531    /// * `record` - The name, class and record_type will be used to match and delete the RecordSet
532    /// * `zone_origin` - the zone name to update, i.e. SOA name
533    ///
534    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). If
535    /// the rrset does not exist and must_exist is false, then the RRSet will be deleted.
536    fn delete_rrset(
537        &mut self,
538        record: Record,
539        zone_origin: Name,
540    ) -> ClientResponse<<Self as DnsHandle>::Response> {
541        assert!(zone_origin.zone_of(record.name()));
542        let message = update_message::delete_rrset(record, zone_origin, self.is_using_edns());
543
544        ClientResponse(self.send(message))
545    }
546
547    /// Deletes all records at the specified name
548    ///
549    /// [RFC 2136](https://tools.ietf.org/html/rfc2136), DNS Update, April 1997
550    ///
551    /// ```text
552    /// 2.5.3 - Delete All RRsets From A Name
553    ///
554    ///   One RR is added to the Update Section whose NAME is that of the name
555    ///   to be cleansed of RRsets.  TYPE must be specified as ANY.  TTL must
556    ///   be specified as zero (0) and is otherwise not used by the Primary
557    ///   Zone Server.  CLASS must be specified as ANY.  RDLENGTH must be zero (0)
558    ///   and RDATA must therefore be empty.  If no such RRsets exist, then
559    ///   this Update RR will be silently ignored by the Primary Zone Server.
560    /// ```
561    ///
562    /// # Arguments
563    ///
564    /// * `name_of_records` - the name of all the record sets to delete
565    /// * `zone_origin` - the zone name to update, i.e. SOA name
566    /// * `dns_class` - the class of the SOA
567    ///
568    /// The update must go to a zone authority (i.e. the server used in the ClientConnection). This
569    /// operation attempts to delete all resource record sets the specified name regardless of
570    /// the record type.
571    fn delete_all(
572        &mut self,
573        name_of_records: Name,
574        zone_origin: Name,
575        dns_class: DNSClass,
576    ) -> ClientResponse<<Self as DnsHandle>::Response> {
577        assert!(zone_origin.zone_of(&name_of_records));
578        let message = update_message::delete_all(
579            name_of_records,
580            zone_origin,
581            dns_class,
582            self.is_using_edns(),
583        );
584
585        ClientResponse(self.send(message))
586    }
587
588    /// Download all records from a zone, or all records modified since given SOA was observed.
589    /// The request will either be a AXFR Query (ask for full zone transfer) if a SOA was not
590    /// provided, or a IXFR Query (incremental zone transfer) if a SOA was provided.
591    ///
592    /// # Arguments
593    /// * `zone_origin` - the zone name to update, i.e. SOA name
594    /// * `last_soa` - the last SOA known, if any. If provided, name must match `zone_origin`
595
596    fn zone_transfer(
597        &mut self,
598        zone_origin: Name,
599        last_soa: Option<SOA>,
600    ) -> ClientStreamXfr<<Self as DnsHandle>::Response> {
601        let ixfr = last_soa.is_some();
602        let message = update_message::zone_transfer(zone_origin, last_soa);
603
604        ClientStreamXfr::new(self.send(message), ixfr)
605    }
606}
607
608/// A stream result of a Client Request
609#[must_use = "stream do nothing unless polled"]
610pub struct ClientStreamingResponse<R>(pub(crate) R)
611where
612    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
613
614impl<R> Stream for ClientStreamingResponse<R>
615where
616    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
617{
618    type Item = Result<DnsResponse, ClientError>;
619
620    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
621        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ClientError::from)))
622    }
623}
624
625/// A future result of a Client Request
626#[must_use = "futures do nothing unless polled"]
627pub struct ClientResponse<R>(pub(crate) R)
628where
629    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static;
630
631impl<R> Future for ClientResponse<R>
632where
633    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
634{
635    type Output = Result<DnsResponse, ClientError>;
636
637    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
638        Poll::Ready(
639            match ready!(self.0.poll_next_unpin(cx)) {
640                Some(r) => r,
641                None => Err(ProtoError::from(ProtoErrorKind::Timeout)),
642            }
643            .map_err(ClientError::from),
644        )
645    }
646}
647
648/// A stream result of a zone transfer Client Request
649/// Accept messages until the end of a zone transfer. For AXFR, it search for a starting and an
650/// ending SOA. For IXFR, it do so taking into account there will be other SOA inbetween
651#[must_use = "stream do nothing unless polled"]
652pub struct ClientStreamXfr<R>
653where
654    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
655{
656    state: ClientStreamXfrState<R>,
657}
658
659impl<R> ClientStreamXfr<R>
660where
661    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
662{
663    fn new(inner: R, maybe_incr: bool) -> Self {
664        Self {
665            state: ClientStreamXfrState::Start { inner, maybe_incr },
666        }
667    }
668}
669
/// State machine for ClientStreamXfr, implementing almost all logic
///
/// Transitions are driven by `process`, which inspects the answer records of each response.
#[derive(Debug)]
enum ClientStreamXfrState<R> {
    /// No answer records seen yet; waiting for the first SOA of the transfer.
    Start {
        /// The underlying response stream.
        inner: R,
        /// Whether the request was an IXFR, i.e. an incremental answer is acceptable.
        maybe_incr: bool,
    },
    /// The opening SOA has been seen; the next record decides between AXFR and IXFR
    /// (or an empty AXFR, when it is an SOA with the same serial).
    Second {
        inner: R,
        /// Serial of the opening SOA record.
        expected_serial: u32,
        maybe_incr: bool,
    },
    /// A full (AXFR) transfer is in progress; it ends at the next SOA record.
    Axfr {
        inner: R,
        /// Serial of the opening SOA record.
        expected_serial: u32,
    },
    /// An incremental (IXFR) transfer is in progress.
    Ixfr {
        inner: R,
        /// Parity of the SOA records processed so far; the transfer may only end when this
        /// is `true` and the last record is an SOA carrying `expected_serial`.
        even: bool,
        /// Serial of the opening SOA record.
        expected_serial: u32,
    },
    /// The transfer is complete (or was rejected); no stream is held anymore.
    Ended,
    /// Temporary placeholder swapped in by `process` while it moves out of the current state.
    Invalid,
}
694
695impl<R> ClientStreamXfrState<R> {
696    /// Helper to get the stream from the enum
697    fn inner(&mut self) -> &mut R {
698        use ClientStreamXfrState::*;
699        match self {
700            Start { inner, .. } => inner,
701            Second { inner, .. } => inner,
702            Axfr { inner, .. } => inner,
703            Ixfr { inner, .. } => inner,
704            Ended | Invalid => unreachable!(),
705        }
706    }
707
708    /// Helper to ingest answer Records
709    // TODO: this is complex enough it should get its own tests
710    fn process(&mut self, answers: &[Record]) -> Result<(), ClientError> {
711        use ClientStreamXfrState::*;
712        fn get_serial(r: &Record) -> Option<u32> {
713            r.data().and_then(RData::as_soa).map(SOA::serial)
714        }
715
716        if answers.is_empty() {
717            return Ok(());
718        }
719        match std::mem::replace(self, Invalid) {
720            Start { inner, maybe_incr } => {
721                if let Some(expected_serial) = get_serial(&answers[0]) {
722                    *self = Second {
723                        inner,
724                        maybe_incr,
725                        expected_serial,
726                    };
727                    self.process(&answers[1..])
728                } else {
729                    *self = Ended;
730                    Ok(())
731                }
732            }
733            Second {
734                inner,
735                maybe_incr,
736                expected_serial,
737            } => {
738                if let Some(serial) = get_serial(&answers[0]) {
739                    // maybe IXFR, or empty AXFR
740                    if serial == expected_serial {
741                        // empty AXFR
742                        *self = Ended;
743                        if answers.len() == 1 {
744                            Ok(())
745                        } else {
746                            // invalid answer : trailing records
747                            Err(ClientErrorKind::Message(
748                                "invalid zone transfer, contains trailing records",
749                            )
750                            .into())
751                        }
752                    } else if maybe_incr {
753                        *self = Ixfr {
754                            inner,
755                            expected_serial,
756                            even: true,
757                        };
758                        self.process(&answers[1..])
759                    } else {
760                        *self = Ended;
761                        Err(ClientErrorKind::Message(
762                            "invalid zone transfer, expected AXFR, got IXFR",
763                        )
764                        .into())
765                    }
766                } else {
767                    // standard AXFR
768                    *self = Axfr {
769                        inner,
770                        expected_serial,
771                    };
772                    self.process(&answers[1..])
773                }
774            }
775            Axfr {
776                inner,
777                expected_serial,
778            } => {
779                let soa_count = answers
780                    .iter()
781                    .filter(|a| a.record_type() == RecordType::SOA)
782                    .count();
783                match soa_count {
784                    0 => {
785                        *self = Axfr {
786                            inner,
787                            expected_serial,
788                        };
789                        Ok(())
790                    }
791                    1 => {
792                        *self = Ended;
793                        match answers.last().map(|r| r.record_type()) {
794                            Some(RecordType::SOA) => Ok(()),
795                            _ => Err(ClientErrorKind::Message(
796                                "invalid zone transfer, contains trailing records",
797                            )
798                            .into()),
799                        }
800                    }
801                    _ => {
802                        *self = Ended;
803                        Err(ClientErrorKind::Message(
804                            "invalid zone transfer, contains trailing records",
805                        )
806                        .into())
807                    }
808                }
809            }
810            Ixfr {
811                inner,
812                even,
813                expected_serial,
814            } => {
815                let even = answers
816                    .iter()
817                    .fold(even, |even, a| even ^ (a.record_type() == RecordType::SOA));
818                if even {
819                    if let Some(serial) = get_serial(answers.last().unwrap()) {
820                        if serial == expected_serial {
821                            *self = Ended;
822                            return Ok(());
823                        }
824                    }
825                }
826                *self = Ixfr {
827                    inner,
828                    even,
829                    expected_serial,
830                };
831                Ok(())
832            }
833            Ended | Invalid => {
834                unreachable!();
835            }
836        }
837    }
838}
839
840impl<R> Stream for ClientStreamXfr<R>
841where
842    R: Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
843{
844    type Item = Result<DnsResponse, ClientError>;
845
846    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
847        use ClientStreamXfrState::*;
848
849        if matches!(self.state, Ended) {
850            return Poll::Ready(None);
851        }
852
853        let message = ready!(self.state.inner().poll_next_unpin(cx)).map(|response| {
854            let ok = response?;
855            self.state.process(ok.answers())?;
856            Ok(ok)
857        });
858        Poll::Ready(message)
859    }
860}
861
#[cfg(test)]
mod tests {
    use super::*;

    use crate::rr::rdata::{A, SOA};
    use futures_util::stream::iter;
    use ClientStreamXfrState::*;

    fn soa_record(serial: u32) -> Record {
        let soa = RData::SOA(SOA::new(
            Name::from_ascii("example.com.").unwrap(),
            Name::from_ascii("admin.example.com.").unwrap(),
            serial,
            60,
            60,
            60,
            60,
        ));
        Record::from_rdata(Name::from_ascii("example.com.").unwrap(), 600, soa)
    }

    fn a_record(ip: u8) -> Record {
        let a = RData::A(A::new(0, 0, 0, ip));
        Record::from_rdata(Name::from_ascii("www.example.com.").unwrap(), 600, a)
    }

    fn get_stream_testcase(
        records: Vec<Vec<Record>>,
    ) -> impl Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static {
        let stream = records.into_iter().map(|r| {
            Ok({
                let mut m = Message::new();
                m.insert_answers(r);
                DnsResponse::from_message(m).unwrap()
            })
        });
        iter(stream)
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_axfr() {
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            a_record(1),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 4);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_axfr_multipart() {
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![a_record(1)],
            vec![soa_record(3)],
            vec![a_record(2)], // will be ignored as connection is dropped before reading this message
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Axfr { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_empty_axfr() {
        let stream = get_stream_testcase(vec![vec![soa_record(3)], vec![soa_record(3)]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_axfr_with_ixfr_reply() {
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            soa_record(2),
            a_record(1),
            soa_record(3),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        stream.next().await.unwrap().unwrap_err();
        assert!(matches!(stream.state, Ended));

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_axfr_with_non_xfr_reply() {
        let stream = get_stream_testcase(vec![
            vec![a_record(1)], // assume this is an error response, not a zone transfer
            vec![a_record(2)],
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_invalid_axfr_multipart() {
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![a_record(1)],
            vec![soa_record(3), a_record(2)],
            vec![soa_record(3)],
        ]);
        let mut stream = ClientStreamXfr::new(stream, false);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Axfr { .. }));
        assert_eq!(response.answers().len(), 1);

        stream.next().await.unwrap().unwrap_err();
        assert!(matches!(stream.state, Ended));

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_ixfr() {
        let stream = get_stream_testcase(vec![vec![
            soa_record(3),
            soa_record(2),
            a_record(1),
            soa_record(3),
            a_record(2),
            soa_record(3),
        ]]);
        let mut stream = ClientStreamXfr::new(stream, true);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 6);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_xfr_valid_ixfr_multipart() {
        let stream = get_stream_testcase(vec![
            vec![soa_record(3)],
            vec![soa_record(2)],
            vec![a_record(1)],
            vec![soa_record(3)],
            vec![a_record(2)],
            vec![soa_record(3)],
            vec![a_record(3)], // will be ignored as the transfer has ended before reading this message
        ]);
        let mut stream = ClientStreamXfr::new(stream, true);
        assert!(matches!(stream.state, Start { .. }));

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Second { .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: true, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: true, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: false, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ixfr { even: false, .. }));
        assert_eq!(response.answers().len(), 1);

        let response = stream.next().await.unwrap().unwrap();
        assert!(matches!(stream.state, Ended));
        assert_eq!(response.answers().len(), 1);

        assert!(stream.next().await.is_none());
    }

    /// End-to-end smoke test against a public resolver.
    ///
    /// Ignored by default because it needs outbound network access; run it explicitly with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to 8.8.8.8:53"]
    async fn async_client() {
        use crate::client::{AsyncClient, ClientHandle};
        use crate::proto::iocompat::AsyncIoTokioAsStd;
        use crate::rr::{DNSClass, Name, RecordType};
        use crate::tcp::TcpClientStream;
        use std::str::FromStr;
        use tokio::net::TcpStream as TokioTcpStream;

        // Connect over TCP to a public recursive resolver
        let (stream, sender) =
            TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(([8, 8, 8, 8], 53).into());

        // Create a new client, the bg is a background future which handles
        //   the multiplexing of the DNS requests to the server.
        //   the client is a handle to an unbounded queue for sending requests via the
        //   background. The background must be scheduled to run before the client can
        //   send any dns requests
        let client = AsyncClient::new(stream, sender, None);

        // await the connection to be established
        let (mut client, bg) = client.await.expect("connection failed");

        // make sure to run the background task
        tokio::spawn(bg);

        // Create a query future
        let query = client.query(
            Name::from_str("www.example.com.").unwrap(),
            DNSClass::IN,
            RecordType::A,
        );

        // wait for its response
        let (message_returned, buffer) = query.await.unwrap().into_parts();

        // The concrete addresses published for www.example.com change over time, so only
        // check that the resolver actually answered rather than pinning an IP address.
        assert!(
            !message_returned.answers().is_empty(),
            "expected at least one answer for www.example.com."
        );

        let message_parsed = Message::from_vec(&buffer)
            .expect("buffer was parsed already by AsyncClient so we should be able to do it again");

        // Re-parsing the raw buffer must yield exactly the answers the client returned
        assert_eq!(message_parsed.answers(), message_returned.answers());
    }
}