trust_dns_client/client/
memoize_client_handle.rs1use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use futures_util::future::FutureExt;
12use futures_util::lock::Mutex;
13use futures_util::stream::Stream;
14use trust_dns_proto::{
15 error::ProtoError,
16 xfer::{DnsHandle, DnsRequest, DnsResponse},
17};
18
19use crate::client::rc_stream::{rc_stream, RcStream};
20use crate::client::ClientHandle;
21use crate::op::Query;
22
23#[derive(Clone)]
30#[must_use = "queries can only be sent through a ClientHandle"]
31pub struct MemoizeClientHandle<H: ClientHandle> {
32 client: H,
33 active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
34}
35
36impl<H> MemoizeClientHandle<H>
37where
38 H: ClientHandle,
39{
40 pub fn new(client: H) -> Self {
42 Self {
43 client,
44 active_queries: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47
48 async fn inner_send(
49 request: DnsRequest,
50 active_queries: Arc<Mutex<HashMap<Query, RcStream<<H as DnsHandle>::Response>>>>,
51 mut client: H,
52 ) -> impl Stream<Item = Result<DnsResponse, ProtoError>> {
53 let query = request.queries().first().expect("no query!").clone();
55
56 let mut active_queries = active_queries.lock().await;
58
59 if let Some(rc_stream) = active_queries.get(&query) {
62 return rc_stream.clone();
63 };
64
65 active_queries
67 .entry(query)
68 .or_insert_with(|| rc_stream(client.send(request)))
69 .clone()
70 }
71}
72
73impl<H> DnsHandle for MemoizeClientHandle<H>
74where
75 H: ClientHandle,
76{
77 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
78 type Error = ProtoError;
79
80 fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
81 let request = request.into();
82
83 Box::pin(
84 Self::inner_send(
85 request,
86 Arc::clone(&self.active_queries),
87 self.client.clone(),
88 )
89 .flatten_stream(),
90 )
91 }
92}
93
#[cfg(test)]
mod test {
    // `println!` is used deliberately: the test output then shows which queries actually
    // reached the underlying client.
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use std::pin::Pin;
    use std::sync::Arc;

    use futures::lock::Mutex;
    use futures::*;
    use trust_dns_proto::{
        error::ProtoError,
        xfer::{DnsHandle, DnsRequest, DnsResponse},
    };

    use crate::client::*;
    use crate::op::*;
    use crate::rr::*;
    use trust_dns_proto::xfer::FirstAnswer;

    /// Fake client that answers every request with an empty message.
    ///
    /// Each answer's id is the current value of the shared counter `i`, which is incremented
    /// once per request this client receives.
    #[derive(Clone)]
    struct TestClient {
        i: Arc<Mutex<u16>>,
    }

    impl DnsHandle for TestClient {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest> + Send + 'static>(&mut self, request: R) -> Self::Response {
            let i = Arc::clone(&self.i);
            // `async move` takes ownership of `i` and `request`, so the returned stream does
            // not borrow from `self`.
            let future = async move {
                let request: DnsRequest = request.into();
                let mut i = i.lock().await;

                let mut message = Message::new();
                message.set_id(*i);
                println!(
                    "sending {}: {}",
                    *i,
                    request.queries().first().expect("no query!")
                );

                *i += 1;

                Ok(DnsResponse::from_message(message).unwrap())
            };

            Box::pin(stream::once(future))
        }
    }

    #[test]
    fn test_memoized() {
        use futures::executor::block_on;

        let counter = Arc::new(Mutex::new(0));
        let mut client = MemoizeClientHandle::new(TestClient {
            i: Arc::clone(&counter),
        });

        let mut test1 = Message::new();
        test1.add_query(Query::new().set_query_type(RecordType::A).clone());

        let mut test2 = Message::new();
        test2.add_query(Query::new().set_query_type(RecordType::AAAA).clone());

        // The first time each query is seen, it goes through to the underlying client.
        let result = block_on(client.send(test1.clone()).first_answer()).expect("test1 failed");
        assert_eq!(result.id(), 0);

        let result = block_on(client.send(test2.clone()).first_answer()).expect("test2 failed");
        assert_eq!(result.id(), 1);

        // Repeats must be served from the memoized streams, so the ids do not advance.
        let result =
            block_on(client.send(test1).first_answer()).expect("memoized test1 failed");
        assert_eq!(result.id(), 0);

        let result =
            block_on(client.send(test2).first_answer()).expect("memoized test2 failed");
        assert_eq!(result.id(), 1);

        // Only the two distinct queries ever reached the underlying client.
        assert_eq!(*block_on(counter.lock()), 2);
    }
}