srv_rs/client/
mod.rs

1//! Clients based on SRV lookups.
2
3use crate::{resolver::SrvResolver, SrvRecord};
4use arc_swap::ArcSwap;
5use futures_util::{
6    pin_mut,
7    stream::{self, Stream, StreamExt},
8    FutureExt,
9};
10use http::uri::{Scheme, Uri};
11use std::{fmt::Debug, future::Future, iter::FromIterator, sync::Arc, time::Instant};
12
13mod cache;
14pub use cache::Cache;
15
16/// SRV target selection policies.
17pub mod policy;
18
19/// Errors encountered by a [`SrvClient`].
20#[derive(Debug, thiserror::Error)]
21pub enum Error<Lookup: Debug> {
22    /// SRV lookup errors
23    #[error("SRV lookup error")]
24    Lookup(Lookup),
25    /// SRV record parsing errors
26    #[error("building uri from SRV record: {0}")]
27    RecordParsing(#[from] http::Error),
28    /// Produced when there are no SRV targets for a client to use
29    #[error("no SRV targets to use")]
30    NoTargets,
31}
32
33/// Client for intelligently performing operations on a service located by SRV records.
34///
35/// # Usage
36///
37/// After being created by [`SrvClient::new`] or [`SrvClient::new_with_resolver`],
38/// operations can be performed on the service pointed to by a [`SrvClient`] with
39/// the [`execute`] and [`execute_stream`] methods.
40///
41/// ## DNS Resolvers
42///
43/// The resolver used to lookup SRV records is determined by a client's
44/// [`SrvResolver`], and can be set with [`SrvClient::resolver`].
45///
46/// ## SRV Target Selection Policies
47///
48/// SRV target selection order is determined by a client's [`Policy`],
49/// and can be set with [`SrvClient::policy`].
50///
51/// [`execute`]: SrvClient::execute()
52/// [`execute_stream`]: SrvClient::execute_stream()
53/// [`Policy`]: policy::Policy
54#[derive(Debug)]
55pub struct SrvClient<Resolver, Policy: policy::Policy = policy::Affinity> {
56    srv: String,
57    resolver: Resolver,
58    http_scheme: Scheme,
59    path_prefix: String,
60    policy: Policy,
61    cache: ArcSwap<Cache<Policy::CacheItem>>,
62}
63
/// Execution mode to use when performing an operation on SRV targets.
///
/// Defaults to [`Execution::Serial`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Execution {
    /// Operations are performed *serially* (i.e. one after the other).
    Serial,
    /// Operations are performed *concurrently* (i.e. all at once).
    /// Note that this does not imply parallelism--no additional tasks are spawned.
    Concurrent,
}

impl Default for Execution {
    /// Serial execution is the default mode.
    fn default() -> Self {
        Self::Serial
    }
}
78
79impl<Resolver: Default, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
80    /// Creates a new client for communicating with services located by `srv_name`.
81    ///
82    /// # Examples
83    /// ```
84    /// use srv_rs::{SrvClient, resolver::libresolv::LibResolv};
85    /// let client = SrvClient::<LibResolv>::new("_http._tcp.example.com");
86    /// ```
87    pub fn new(srv_name: impl ToString) -> Self {
88        Self::new_with_resolver(srv_name, Resolver::default())
89    }
90}
91
92impl<Resolver, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
93    /// Creates a new client for communicating with services located by `srv_name`.
94    pub fn new_with_resolver(srv_name: impl ToString, resolver: Resolver) -> Self {
95        Self {
96            srv: srv_name.to_string(),
97            resolver,
98            http_scheme: Scheme::HTTPS,
99            path_prefix: String::from("/"),
100            policy: Default::default(),
101            cache: Default::default(),
102        }
103    }
104}
105
106impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
107    /// Gets a fresh set of SRV records from a client's DNS resolver, returning
108    /// them along with the time they're valid until.
109    pub async fn get_srv_records(
110        &self,
111    ) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
112        self.resolver
113            .get_srv_records(&self.srv)
114            .await
115            .map_err(Error::Lookup)
116    }
117
118    /// Gets a fresh set of SRV records from a client's DNS resolver and parses
119    /// their target/port pairs into URIs, which are returned along with the
120    /// time they're valid until--i.e., the time a cache containing these URIs
121    /// should expire.
122    pub async fn get_fresh_uri_candidates(
123        &self,
124    ) -> Result<(Vec<Uri>, Instant), Error<Resolver::Error>> {
125        // Query DNS for the SRV record
126        let (records, valid_until) = self.get_srv_records().await?;
127
128        // Create URIs from SRV records
129        let uris = records
130            .iter()
131            .map(|record| self.parse_record(record))
132            .collect::<Result<Vec<Uri>, _>>()?;
133
134        Ok((uris, valid_until))
135    }
136
137    async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
138        let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
139        self.cache.store(new_cache.clone());
140        Ok(new_cache)
141    }
142
143    /// Gets a client's cached items, refreshing the existing cache if it is invalid.
144    async fn get_valid_cache(
145        &self,
146    ) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
147        match self.cache.load_full() {
148            cache if cache.valid() => Ok(cache),
149            _ => self.refresh_cache().await,
150        }
151    }
152
153    /// Performs an operation on all of a client's SRV targets, producing a
154    /// stream of results (one for each target). If the serial execution mode is
155    /// specified, the operation will be performed on each target in the order
156    /// determined by the current [`Policy`], and the results will be returned
157    /// in the same order. If the concurrent execution mode is specified, the
158    /// operation will be performed on all targets concurrently, and results
159    /// will be returned in the order they become available.
160    ///
161    /// # Examples
162    ///
163    /// ```
164    /// # use srv_rs::EXAMPLE_SRV;
165    /// use srv_rs::{SrvClient, Error, Execution};
166    /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
167    ///
168    /// # #[tokio::main]
169    /// # async fn main() -> Result<(), Error<LibResolvError>> {
170    /// # let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
171    /// let results_stream = client.execute_stream(Execution::Serial, |address| async move {
172    ///     Ok::<_, std::convert::Infallible>(address.to_string())
173    /// })
174    /// .await?;
175    /// // Do something with the stream, for example collect all results into a `Vec`:
176    /// use futures::stream::StreamExt;
177    /// let results: Vec<Result<_, _>> = results_stream.collect().await;
178    /// for result in results {
179    ///     assert!(result.is_ok());
180    /// }
181    /// # Ok(())
182    /// # }
183    /// ```
184    ///
185    /// [`Policy`]: policy::Policy
186    pub async fn execute_stream<'a, T, E, Fut>(
187        &'a self,
188        execution_mode: Execution,
189        func: impl FnMut(Uri) -> Fut + 'a,
190    ) -> Result<impl Stream<Item = Result<T, E>> + 'a, Error<Resolver::Error>>
191    where
192        E: std::error::Error,
193        Fut: Future<Output = Result<T, E>> + 'a,
194    {
195        let mut func = func;
196        let cache = self.get_valid_cache().await?;
197        let order = self.policy.order(cache.items());
198        let func = {
199            let cache = cache.clone();
200            move |idx| {
201                let candidate = Policy::cache_item_to_uri(&cache.items()[idx]);
202                func(candidate.to_owned()).map(move |res| (idx, res))
203            }
204        };
205        let results = match execution_mode {
206            Execution::Serial => stream::iter(order).then(func).left_stream(),
207            #[allow(clippy::from_iter_instead_of_collect)]
208            Execution::Concurrent => {
209                stream::FuturesUnordered::from_iter(order.map(func)).right_stream()
210            }
211        };
212        let results = results.map(move |(candidate_idx, result)| {
213            let candidate = Policy::cache_item_to_uri(&cache.items()[candidate_idx]);
214            match result {
215                Ok(res) => {
216                    #[cfg(feature = "log")]
217                    tracing::info!(URI = %candidate, "execution attempt succeeded");
218                    self.policy.note_success(candidate);
219                    Ok(res)
220                }
221                Err(err) => {
222                    #[cfg(feature = "log")]
223                    tracing::info!(URI = %candidate, error = %err, "execution attempt failed");
224                    self.policy.note_failure(candidate);
225                    Err(err)
226                }
227            }
228        });
229        Ok(results)
230    }
231
232    /// Performs an operation on a client's SRV targets, producing the first
233    /// successful result or the last error encountered if every execution of
234    /// the operation was unsuccessful.
235    ///
236    /// # Examples
237    ///
238    /// ```
239    /// # use srv_rs::EXAMPLE_SRV;
240    /// use srv_rs::{SrvClient, Error, Execution};
241    /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
242    ///
243    /// # #[tokio::main]
244    /// # async fn main() -> Result<(), Error<LibResolvError>> {
245    /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
246    ///
247    /// let res = client.execute(Execution::Serial, |address| async move {
248    ///     Ok::<_, std::convert::Infallible>(address.to_string())
249    /// })
250    /// .await?;
251    /// assert!(res.is_ok());
252    ///
253    /// let res = client.execute(Execution::Concurrent, |address| async move {
254    ///     address.to_string().parse::<usize>()
255    /// })
256    /// .await?;
257    /// assert!(res.is_err());
258    /// # Ok(())
259    /// # }
260    /// ```
261    pub async fn execute<T, E, Fut>(
262        &self,
263        execution_mode: Execution,
264        func: impl FnMut(Uri) -> Fut,
265    ) -> Result<Result<T, E>, Error<Resolver::Error>>
266    where
267        E: std::error::Error,
268        Fut: Future<Output = Result<T, E>>,
269    {
270        let results = self.execute_stream(execution_mode, func).await?;
271        pin_mut!(results);
272
273        let mut last_error = None;
274        while let Some(result) = results.next().await {
275            match result {
276                Ok(res) => return Ok(Ok(res)),
277                Err(err) => last_error = Some(err),
278            }
279        }
280
281        if let Some(err) = last_error {
282            Ok(Err(err))
283        } else {
284            Err(Error::NoTargets)
285        }
286    }
287
288    fn parse_record(&self, record: &Resolver::Record) -> Result<Uri, http::Error> {
289        record.parse(self.http_scheme.clone(), self.path_prefix.as_str())
290    }
291}
292
293impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
294    /// Sets the SRV name of the client.
295    pub fn srv_name(self, srv_name: impl ToString) -> Self {
296        Self {
297            srv: srv_name.to_string(),
298            ..self
299        }
300    }
301
302    /// Sets the resolver of the client.
303    pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
304        SrvClient {
305            resolver,
306            cache: Default::default(),
307            policy: self.policy,
308            srv: self.srv,
309            http_scheme: self.http_scheme,
310            path_prefix: self.path_prefix,
311        }
312    }
313
314    /// Sets the policy of the client.
315    ///
316    /// # Examples
317    ///
318    /// ```
319    /// # use srv_rs::EXAMPLE_SRV;
320    /// use srv_rs::{SrvClient, policy::Rfc2782, resolver::libresolv::LibResolv};
321    /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV).policy(Rfc2782);
322    /// ```
323    pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
324        SrvClient {
325            policy,
326            cache: Default::default(),
327            resolver: self.resolver,
328            srv: self.srv,
329            http_scheme: self.http_scheme,
330            path_prefix: self.path_prefix,
331        }
332    }
333
334    /// Sets the http scheme of the client.
335    pub fn http_scheme(self, http_scheme: Scheme) -> Self {
336        Self {
337            http_scheme,
338            ..self
339        }
340    }
341
342    /// Sets the path prefix of the client.
343    pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
344        Self {
345            path_prefix: path_prefix.to_string(),
346            ..self
347        }
348    }
349}