// trust_dns_client/client/memoize_client_handle.rs

// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use futures_util::future::FutureExt;
12use futures_util::lock::Mutex;
13use futures_util::stream::Stream;
14use trust_dns_proto::{
15    error::ProtoError,
16    xfer::{DnsHandle, DnsRequest, DnsResponse},
17};
18
19use crate::client::rc_stream::{rc_stream, RcStream};
20use crate::client::ClientHandle;
21use crate::op::Query;
22
23// TODO: move to proto
24/// A ClientHandle for memoized (cached) responses to queries.
25///
26/// This wraps a ClientHandle, changing the implementation `send()` to store the response against
27///  the Message.Query that was sent. This should reduce network traffic especially during things
28///  like DNSSEC validation. *Warning* this will currently cache for the life of the Client.
29#[derive(Clone)]
30#[must_use = "queries can only be sent through a ClientHandle"]
31pub struct MemoizeClientHandle<H: ClientHandle> {
32    client: H,
33    active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
34}
35
36impl<H> MemoizeClientHandle<H>
37where
38    H: ClientHandle,
39{
40    /// Returns a new handle wrapping the specified client
41    pub fn new(client: H) -> Self {
42        Self {
43            client,
44            active_queries: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    async fn inner_send(
49        request: DnsRequest,
50        active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
51        mut client: H,
52    ) -> impl Stream<Item = Result<DnsResponse, ProtoError>> {
53        // TODO: what if we want to support multiple queries (non-standard)?
54        let query = request.queries().first().expect("no query!").clone();
55
56        // lock all the currently running queries
57        let mut active_queries = active_queries.lock().await;
58
59        // TODO: we need to consider TTL on the records here at some point
60        // If the query is running, grab that existing one...
61        if let Some(rc_stream) = active_queries.get(&query) {
62            return rc_stream.clone();
63        };
64
65        // Otherwise issue a new query and store in the map
66        active_queries
67            .entry(query)
68            .or_insert_with(|| rc_stream(client.send(request)))
69            .clone()
70    }
71}
72
73impl<H> DnsHandle for MemoizeClientHandle<H>
74where
75    H: ClientHandle,
76{
77    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
78    type Error = ProtoError;
79
80    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
81        let request = request.into();
82
83        Box::pin(
84            Self::inner_send(
85                request,
86                Arc::clone(&self.active_queries),
87                self.client.clone(),
88            )
89            .flatten_stream(),
90        )
91    }
92}
93
#[cfg(test)]
mod test {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use std::pin::Pin;
    use std::sync::Arc;

    use futures::lock::Mutex;
    use futures::*;
    use trust_dns_proto::{
        error::ProtoError,
        xfer::{DnsHandle, DnsRequest, DnsResponse},
    };

    use crate::client::*;
    use crate::op::*;
    use crate::rr::*;
    use trust_dns_proto::xfer::FirstAnswer;

    /// Fake upstream handle used by the tests.
    ///
    /// Every `send` answers with an empty message whose id is the current value of the shared
    /// counter, then increments the counter. The counter therefore records how many requests
    /// actually reached this "upstream".
    #[derive(Clone)]
    struct TestClient {
        i: Arc<Mutex<u16>>,
    }

    impl DnsHandle for TestClient {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest> + Send + 'static>(&mut self, request: R) -> Self::Response {
            let i = Arc::clone(&self.i);
            // `async move` takes ownership of `i` and `request` for the future's lifetime.
            let future = async move {
                let mut message = Message::new();

                let mut i = i.lock().await;

                message.set_id(*i);
                println!(
                    "sending {}: {}",
                    *i,
                    request.into().queries().first().expect("no query!").clone()
                );

                *i += 1;

                Ok(DnsResponse::from_message(message).unwrap())
            };

            Box::pin(stream::once(future))
        }
    }

    #[test]
    fn test_memoized() {
        use futures::executor::block_on;

        // Keep a handle on the counter so we can check how many requests went upstream.
        let sent = Arc::new(Mutex::new(0));
        let mut client = MemoizeClientHandle::new(TestClient {
            i: Arc::clone(&sent),
        });

        let mut test1 = Message::new();
        test1.add_query(Query::new().set_query_type(RecordType::A).clone());

        let mut test2 = Message::new();
        test2.add_query(Query::new().set_query_type(RecordType::AAAA).clone());

        let result = block_on(client.send(test1.clone()).first_answer())
            .expect("first A query failed");
        assert_eq!(result.id(), 0);

        let result = block_on(client.send(test2.clone()).first_answer())
            .expect("first AAAA query failed");
        assert_eq!(result.id(), 1);

        // should get the same result for each...
        let result =
            block_on(client.send(test1).first_answer()).expect("memoized A query failed");
        assert_eq!(result.id(), 0);

        let result =
            block_on(client.send(test2).first_answer()).expect("memoized AAAA query failed");
        assert_eq!(result.id(), 1);

        // ...and only the two distinct queries should have reached the upstream client.
        assert_eq!(*block_on(sent.lock()), 2);
    }
}
178}