srv_rs/client/mod.rs
1//! Clients based on SRV lookups.
2
3use crate::{resolver::SrvResolver, SrvRecord};
4use arc_swap::ArcSwap;
5use futures_util::{
6 pin_mut,
7 stream::{self, Stream, StreamExt},
8 FutureExt,
9};
10use http::uri::{Scheme, Uri};
11use std::{fmt::Debug, future::Future, iter::FromIterator, sync::Arc, time::Instant};
12
13mod cache;
14pub use cache::Cache;
15
16/// SRV target selection policies.
17pub mod policy;
18
19/// Errors encountered by a [`SrvClient`].
20#[derive(Debug, thiserror::Error)]
21pub enum Error<Lookup: Debug> {
22 /// SRV lookup errors
23 #[error("SRV lookup error")]
24 Lookup(Lookup),
25 /// SRV record parsing errors
26 #[error("building uri from SRV record: {0}")]
27 RecordParsing(#[from] http::Error),
28 /// Produced when there are no SRV targets for a client to use
29 #[error("no SRV targets to use")]
30 NoTargets,
31}
32
33/// Client for intelligently performing operations on a service located by SRV records.
34///
35/// # Usage
36///
37/// After being created by [`SrvClient::new`] or [`SrvClient::new_with_resolver`],
38/// operations can be performed on the service pointed to by a [`SrvClient`] with
39/// the [`execute`] and [`execute_stream`] methods.
40///
41/// ## DNS Resolvers
42///
43/// The resolver used to lookup SRV records is determined by a client's
44/// [`SrvResolver`], and can be set with [`SrvClient::resolver`].
45///
46/// ## SRV Target Selection Policies
47///
48/// SRV target selection order is determined by a client's [`Policy`],
49/// and can be set with [`SrvClient::policy`].
50///
51/// [`execute`]: SrvClient::execute()
52/// [`execute_stream`]: SrvClient::execute_stream()
53/// [`Policy`]: policy::Policy
54#[derive(Debug)]
55pub struct SrvClient<Resolver, Policy: policy::Policy = policy::Affinity> {
56 srv: String,
57 resolver: Resolver,
58 http_scheme: Scheme,
59 path_prefix: String,
60 policy: Policy,
61 cache: ArcSwap<Cache<Policy::CacheItem>>,
62}
63
/// Execution mode to use when performing an operation on SRV targets.
///
/// This is a plain, fieldless mode selector, so it is `Copy` and can be
/// compared, hashed and logged freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Execution {
    /// Operations are performed *serially* (i.e. one after the other).
    Serial,
    /// Operations are performed *concurrently* (i.e. all at once).
    /// Note that this does not imply parallelism--no additional tasks are spawned.
    Concurrent,
}
72
73impl Default for Execution {
74 fn default() -> Self {
75 Self::Serial
76 }
77}
78
79impl<Resolver: Default, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
80 /// Creates a new client for communicating with services located by `srv_name`.
81 ///
82 /// # Examples
83 /// ```
84 /// use srv_rs::{SrvClient, resolver::libresolv::LibResolv};
85 /// let client = SrvClient::<LibResolv>::new("_http._tcp.example.com");
86 /// ```
87 pub fn new(srv_name: impl ToString) -> Self {
88 Self::new_with_resolver(srv_name, Resolver::default())
89 }
90}
91
92impl<Resolver, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
93 /// Creates a new client for communicating with services located by `srv_name`.
94 pub fn new_with_resolver(srv_name: impl ToString, resolver: Resolver) -> Self {
95 Self {
96 srv: srv_name.to_string(),
97 resolver,
98 http_scheme: Scheme::HTTPS,
99 path_prefix: String::from("/"),
100 policy: Default::default(),
101 cache: Default::default(),
102 }
103 }
104}
105
106impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
107 /// Gets a fresh set of SRV records from a client's DNS resolver, returning
108 /// them along with the time they're valid until.
109 pub async fn get_srv_records(
110 &self,
111 ) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
112 self.resolver
113 .get_srv_records(&self.srv)
114 .await
115 .map_err(Error::Lookup)
116 }
117
118 /// Gets a fresh set of SRV records from a client's DNS resolver and parses
119 /// their target/port pairs into URIs, which are returned along with the
120 /// time they're valid until--i.e., the time a cache containing these URIs
121 /// should expire.
122 pub async fn get_fresh_uri_candidates(
123 &self,
124 ) -> Result<(Vec<Uri>, Instant), Error<Resolver::Error>> {
125 // Query DNS for the SRV record
126 let (records, valid_until) = self.get_srv_records().await?;
127
128 // Create URIs from SRV records
129 let uris = records
130 .iter()
131 .map(|record| self.parse_record(record))
132 .collect::<Result<Vec<Uri>, _>>()?;
133
134 Ok((uris, valid_until))
135 }
136
137 async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
138 let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
139 self.cache.store(new_cache.clone());
140 Ok(new_cache)
141 }
142
143 /// Gets a client's cached items, refreshing the existing cache if it is invalid.
144 async fn get_valid_cache(
145 &self,
146 ) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
147 match self.cache.load_full() {
148 cache if cache.valid() => Ok(cache),
149 _ => self.refresh_cache().await,
150 }
151 }
152
153 /// Performs an operation on all of a client's SRV targets, producing a
154 /// stream of results (one for each target). If the serial execution mode is
155 /// specified, the operation will be performed on each target in the order
156 /// determined by the current [`Policy`], and the results will be returned
157 /// in the same order. If the concurrent execution mode is specified, the
158 /// operation will be performed on all targets concurrently, and results
159 /// will be returned in the order they become available.
160 ///
161 /// # Examples
162 ///
163 /// ```
164 /// # use srv_rs::EXAMPLE_SRV;
165 /// use srv_rs::{SrvClient, Error, Execution};
166 /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
167 ///
168 /// # #[tokio::main]
169 /// # async fn main() -> Result<(), Error<LibResolvError>> {
170 /// # let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
171 /// let results_stream = client.execute_stream(Execution::Serial, |address| async move {
172 /// Ok::<_, std::convert::Infallible>(address.to_string())
173 /// })
174 /// .await?;
175 /// // Do something with the stream, for example collect all results into a `Vec`:
176 /// use futures::stream::StreamExt;
177 /// let results: Vec<Result<_, _>> = results_stream.collect().await;
178 /// for result in results {
179 /// assert!(result.is_ok());
180 /// }
181 /// # Ok(())
182 /// # }
183 /// ```
184 ///
185 /// [`Policy`]: policy::Policy
186 pub async fn execute_stream<'a, T, E, Fut>(
187 &'a self,
188 execution_mode: Execution,
189 func: impl FnMut(Uri) -> Fut + 'a,
190 ) -> Result<impl Stream<Item = Result<T, E>> + 'a, Error<Resolver::Error>>
191 where
192 E: std::error::Error,
193 Fut: Future<Output = Result<T, E>> + 'a,
194 {
195 let mut func = func;
196 let cache = self.get_valid_cache().await?;
197 let order = self.policy.order(cache.items());
198 let func = {
199 let cache = cache.clone();
200 move |idx| {
201 let candidate = Policy::cache_item_to_uri(&cache.items()[idx]);
202 func(candidate.to_owned()).map(move |res| (idx, res))
203 }
204 };
205 let results = match execution_mode {
206 Execution::Serial => stream::iter(order).then(func).left_stream(),
207 #[allow(clippy::from_iter_instead_of_collect)]
208 Execution::Concurrent => {
209 stream::FuturesUnordered::from_iter(order.map(func)).right_stream()
210 }
211 };
212 let results = results.map(move |(candidate_idx, result)| {
213 let candidate = Policy::cache_item_to_uri(&cache.items()[candidate_idx]);
214 match result {
215 Ok(res) => {
216 #[cfg(feature = "log")]
217 tracing::info!(URI = %candidate, "execution attempt succeeded");
218 self.policy.note_success(candidate);
219 Ok(res)
220 }
221 Err(err) => {
222 #[cfg(feature = "log")]
223 tracing::info!(URI = %candidate, error = %err, "execution attempt failed");
224 self.policy.note_failure(candidate);
225 Err(err)
226 }
227 }
228 });
229 Ok(results)
230 }
231
232 /// Performs an operation on a client's SRV targets, producing the first
233 /// successful result or the last error encountered if every execution of
234 /// the operation was unsuccessful.
235 ///
236 /// # Examples
237 ///
238 /// ```
239 /// # use srv_rs::EXAMPLE_SRV;
240 /// use srv_rs::{SrvClient, Error, Execution};
241 /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
242 ///
243 /// # #[tokio::main]
244 /// # async fn main() -> Result<(), Error<LibResolvError>> {
245 /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
246 ///
247 /// let res = client.execute(Execution::Serial, |address| async move {
248 /// Ok::<_, std::convert::Infallible>(address.to_string())
249 /// })
250 /// .await?;
251 /// assert!(res.is_ok());
252 ///
253 /// let res = client.execute(Execution::Concurrent, |address| async move {
254 /// address.to_string().parse::<usize>()
255 /// })
256 /// .await?;
257 /// assert!(res.is_err());
258 /// # Ok(())
259 /// # }
260 /// ```
261 pub async fn execute<T, E, Fut>(
262 &self,
263 execution_mode: Execution,
264 func: impl FnMut(Uri) -> Fut,
265 ) -> Result<Result<T, E>, Error<Resolver::Error>>
266 where
267 E: std::error::Error,
268 Fut: Future<Output = Result<T, E>>,
269 {
270 let results = self.execute_stream(execution_mode, func).await?;
271 pin_mut!(results);
272
273 let mut last_error = None;
274 while let Some(result) = results.next().await {
275 match result {
276 Ok(res) => return Ok(Ok(res)),
277 Err(err) => last_error = Some(err),
278 }
279 }
280
281 if let Some(err) = last_error {
282 Ok(Err(err))
283 } else {
284 Err(Error::NoTargets)
285 }
286 }
287
288 fn parse_record(&self, record: &Resolver::Record) -> Result<Uri, http::Error> {
289 record.parse(self.http_scheme.clone(), self.path_prefix.as_str())
290 }
291}
292
293impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
294 /// Sets the SRV name of the client.
295 pub fn srv_name(self, srv_name: impl ToString) -> Self {
296 Self {
297 srv: srv_name.to_string(),
298 ..self
299 }
300 }
301
302 /// Sets the resolver of the client.
303 pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
304 SrvClient {
305 resolver,
306 cache: Default::default(),
307 policy: self.policy,
308 srv: self.srv,
309 http_scheme: self.http_scheme,
310 path_prefix: self.path_prefix,
311 }
312 }
313
314 /// Sets the policy of the client.
315 ///
316 /// # Examples
317 ///
318 /// ```
319 /// # use srv_rs::EXAMPLE_SRV;
320 /// use srv_rs::{SrvClient, policy::Rfc2782, resolver::libresolv::LibResolv};
321 /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV).policy(Rfc2782);
322 /// ```
323 pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
324 SrvClient {
325 policy,
326 cache: Default::default(),
327 resolver: self.resolver,
328 srv: self.srv,
329 http_scheme: self.http_scheme,
330 path_prefix: self.path_prefix,
331 }
332 }
333
334 /// Sets the http scheme of the client.
335 pub fn http_scheme(self, http_scheme: Scheme) -> Self {
336 Self {
337 http_scheme,
338 ..self
339 }
340 }
341
342 /// Sets the path prefix of the client.
343 pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
344 Self {
345 path_prefix: path_prefix.to_string(),
346 ..self
347 }
348 }
349}