srv_rs/client/policy.rs

1use crate::{resolver::SrvResolver, Error, SrvClient, SrvRecord};
2use arc_swap::ArcSwapOption;
3use async_trait::async_trait;
4use http::Uri;
5use std::sync::Arc;
6
7pub use super::Cache;
8
9/// Policy for [`SrvClient`] to use when selecting SRV targets to recommend.
10#[async_trait]
11pub trait Policy: Sized {
12    /// Type of item stored in a client's cache.
13    type CacheItem;
14
15    /// Iterator of indices used to order cache items.
16    type Ordering: Iterator<Item = usize>;
17
18    /// Obtains a refreshed cache for a client.
19    async fn refresh_cache<Resolver: SrvResolver>(
20        &self,
21        client: &SrvClient<Resolver, Self>,
22    ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>>;
23
24    /// Creates an iterator of indices corresponding to cache items in the
25    /// order a [`SrvClient`] should try using them to perform an operation.
26    fn order(&self, items: &[Self::CacheItem]) -> Self::Ordering;
27
28    /// Converts a reference to a cached item into a reference to a [`Uri`].
29    fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri;
30
31    /// Makes any policy adjustments following a successful execution on `uri`.
32    #[allow(unused_variables)]
33    fn note_success(&self, uri: &Uri) {}
34
35    /// Makes any policy adjustments following a failed execution on `uri`.
36    #[allow(unused_variables)]
37    fn note_failure(&self, uri: &Uri) {}
38}
39
40/// Policy that selects targets based on past successes--if a target was used
41/// successfully in a past execution, it will be recommended first.
42#[derive(Default)]
43pub struct Affinity {
44    last_working_target: ArcSwapOption<Uri>,
45}
46
47#[async_trait]
48impl Policy for Affinity {
49    type CacheItem = Uri;
50    type Ordering = AffinityUriIter;
51
52    async fn refresh_cache<Resolver: SrvResolver>(
53        &self,
54        client: &SrvClient<Resolver, Self>,
55    ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>> {
56        let (uris, valid_until) = client.get_fresh_uri_candidates().await?;
57        Ok(Cache::new(uris, valid_until))
58    }
59
60    fn order(&self, uris: &[Uri]) -> Self::Ordering {
61        let preferred = self.last_working_target.load();
62        Affinity::uris_preferring(uris, preferred.as_deref())
63    }
64
65    fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri {
66        item
67    }
68
69    fn note_success(&self, uri: &Uri) {
70        self.last_working_target.store(Some(Arc::new(uri.clone())));
71    }
72}
73
74impl Affinity {
75    fn uris_preferring(uris: &[Uri], preferred: Option<&Uri>) -> AffinityUriIter {
76        let preferred = preferred
77            .as_deref()
78            .and_then(|preferred| uris.as_ref().iter().position(|uri| uri == preferred))
79            .unwrap_or(0);
80        AffinityUriIter {
81            n: uris.len(),
82            preferred,
83            next: None,
84        }
85    }
86}
87
/// Iterator over [`Uri`]s based on affinity. See [`Affinity`].
///
/// Yields cache indices: the preferred index first, then every other index in
/// ascending order.
#[derive(Debug, Clone)]
pub struct AffinityUriIter {
    /// Number of URIs in the cache.
    n: usize,
    /// Index of the URI to produce first (i.e. the preferred URI).
    /// `0` if the first is preferred or there is no preferred URI at all.
    preferred: usize,
    /// Index of the next URI to be produced.
    /// If `None`, the preferred URI will be produced.
    next: Option<usize>,
}

impl Iterator for AffinityUriIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let (idx, next) = match self.next {
            // Nothing produced yet: emit the preferred index, then restart at 0.
            None => (self.preferred, 0),
            // The preferred index was already emitted first, so skip past it.
            Some(next) if next == self.preferred => (next + 1, next + 2),
            // Otherwise, advance normally.
            Some(next) => (next, next + 1),
        };
        self.next = Some(next);
        // Any index at or beyond `n` means every cached URI has been produced.
        if idx < self.n {
            Some(idx)
        } else {
            None
        }
    }
}
120
/// Policy that selects targets based on the algorithm in RFC 2782, reshuffling
/// by weight for each selection.
///
/// The policy is stateless, so it is cheap to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct Rfc2782;
125
126/// Representation of a SRV record with its target and port parsed into a [`Uri`].
127pub struct ParsedRecord {
128    uri: Uri,
129    priority: u16,
130    weight: u16,
131}
132
133impl ParsedRecord {
134    fn new<Record: SrvRecord>(record: &Record, uri: Uri) -> Self {
135        Self {
136            uri,
137            priority: record.priority(),
138            weight: record.weight(),
139        }
140    }
141}
142
143#[async_trait]
144impl Policy for Rfc2782 {
145    type CacheItem = ParsedRecord;
146    type Ordering = <Vec<usize> as IntoIterator>::IntoIter;
147
148    async fn refresh_cache<Resolver: SrvResolver>(
149        &self,
150        client: &SrvClient<Resolver, Self>,
151    ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>> {
152        let (records, valid_until) = client.get_srv_records().await?;
153        let parsed = records
154            .iter()
155            .map(|record| {
156                client
157                    .parse_record(record)
158                    .map(|uri| ParsedRecord::new(record, uri))
159            })
160            .collect::<Result<Vec<_>, _>>()?;
161        Ok(Cache::new(parsed, valid_until))
162    }
163
164    fn order(&self, records: &[ParsedRecord]) -> Self::Ordering {
165        let mut indices = (0..records.len()).collect::<Vec<_>>();
166        let mut rng = rand::thread_rng();
167        indices.sort_by_cached_key(|&idx| {
168            let (priority, weight) = (records[idx].priority, records[idx].weight);
169            crate::record::sort_key(priority, weight, &mut rng)
170        });
171        indices.into_iter()
172    }
173
174    fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri {
175        &item.uri
176    }
177}
178
/// Checks that `Affinity::uris_preferring` yields the preferred URI first and
/// the rest in cache order, including the fallback and empty-cache cases.
#[test]
fn affinity_uris_iter_order() {
    let google: Uri = "https://google.com".parse().unwrap();
    let amazon: Uri = "https://amazon.com".parse().unwrap();
    let desco: Uri = "https://deshaw.com".parse().unwrap();
    let uncached: Uri = "https://example.com".parse().unwrap();
    let cache = vec![google.clone(), amazon.clone(), desco.clone()];
    let order = |preferred| {
        Affinity::uris_preferring(&cache, preferred)
            .map(|idx| &cache[idx])
            .collect::<Vec<_>>()
    };
    assert_eq!(order(None), vec![&google, &amazon, &desco]);
    assert_eq!(order(Some(&google)), vec![&google, &amazon, &desco]);
    assert_eq!(order(Some(&amazon)), vec![&amazon, &google, &desco]);
    assert_eq!(order(Some(&desco)), vec![&desco, &google, &amazon]);
    // A preferred URI that is no longer cached falls back to the natural order.
    assert_eq!(order(Some(&uncached)), vec![&google, &amazon, &desco]);
    // An empty cache yields nothing, with or without a preference.
    assert_eq!(Affinity::uris_preferring(&[], None).count(), 0);
    assert_eq!(Affinity::uris_preferring(&[], Some(&google)).count(), 0);
}
195
/// Checks that `Rfc2782::order` recommends every cached record exactly once,
/// in non-decreasing priority order, across several randomized orderings.
#[test]
fn balance_uris_iter_order() {
    // Clippy doesn't like that Uri has interior mutability and is being used
    // as a HashMap key but we aren't doing anything naughty in the test
    #[allow(clippy::mutable_key_type)]
    let mut priorities = std::collections::HashMap::new();
    priorities.insert("https://google.com".parse::<Uri>().unwrap(), 2);
    priorities.insert("https://cloudflare.com".parse().unwrap(), 2);
    priorities.insert("https://amazon.com".parse().unwrap(), 1);
    priorities.insert("https://deshaw.com".parse().unwrap(), 1);

    let cache = priorities
        .iter()
        .map(|(uri, &priority)| ParsedRecord {
            uri: uri.clone(),
            priority,
            weight: u16::from(rand::random::<u8>()),
        })
        .collect::<Vec<_>>();

    let check_ordering = |iter: <Rfc2782 as Policy>::Ordering| {
        let indices = iter.collect::<Vec<_>>();

        // Every cached record must be recommended exactly once.
        let mut seen = indices.clone();
        seen.sort_unstable();
        assert_eq!(seen, (0..cache.len()).collect::<Vec<_>>());

        // Records must be recommended in non-decreasing priority order.
        for pair in indices.windows(2) {
            let (prev, next) = (&cache[pair[0]].uri, &cache[pair[1]].uri);
            assert!(priorities[prev] <= priorities[next]);
        }
    };

    // Weights are randomized per ordering, so check several orderings.
    for _ in 0..5 {
        check_ordering(Rfc2782.order(&cache));
    }
}