1use crate::{resolver::SrvResolver, Error, SrvClient, SrvRecord};
2use arc_swap::ArcSwapOption;
3use async_trait::async_trait;
4use http::Uri;
5use std::sync::Arc;
6
7pub use super::Cache;
8
9#[async_trait]
11pub trait Policy: Sized {
12 type CacheItem;
14
15 type Ordering: Iterator<Item = usize>;
17
18 async fn refresh_cache<Resolver: SrvResolver>(
20 &self,
21 client: &SrvClient<Resolver, Self>,
22 ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>>;
23
24 fn order(&self, items: &[Self::CacheItem]) -> Self::Ordering;
27
28 fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri;
30
31 #[allow(unused_variables)]
33 fn note_success(&self, uri: &Uri) {}
34
35 #[allow(unused_variables)]
37 fn note_failure(&self, uri: &Uri) {}
38}
39
40#[derive(Default)]
43pub struct Affinity {
44 last_working_target: ArcSwapOption<Uri>,
45}
46
47#[async_trait]
48impl Policy for Affinity {
49 type CacheItem = Uri;
50 type Ordering = AffinityUriIter;
51
52 async fn refresh_cache<Resolver: SrvResolver>(
53 &self,
54 client: &SrvClient<Resolver, Self>,
55 ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>> {
56 let (uris, valid_until) = client.get_fresh_uri_candidates().await?;
57 Ok(Cache::new(uris, valid_until))
58 }
59
60 fn order(&self, uris: &[Uri]) -> Self::Ordering {
61 let preferred = self.last_working_target.load();
62 Affinity::uris_preferring(uris, preferred.as_deref())
63 }
64
65 fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri {
66 item
67 }
68
69 fn note_success(&self, uri: &Uri) {
70 self.last_working_target.store(Some(Arc::new(uri.clone())));
71 }
72}
73
74impl Affinity {
75 fn uris_preferring(uris: &[Uri], preferred: Option<&Uri>) -> AffinityUriIter {
76 let preferred = preferred
77 .as_deref()
78 .and_then(|preferred| uris.as_ref().iter().position(|uri| uri == preferred))
79 .unwrap_or(0);
80 AffinityUriIter {
81 n: uris.len(),
82 preferred,
83 next: None,
84 }
85 }
86}
87
/// Iterator over the indices `0..n` that yields one preferred index first and
/// then every other index in ascending order.
pub struct AffinityUriIter {
    /// Number of items being indexed; only indices below this are yielded.
    n: usize,
    /// Index yielded first and skipped during the ascending sweep.
    preferred: usize,
    /// `None` until the first call to `next`; afterwards the position the
    /// ascending sweep will look at on the following call.
    next: Option<usize>,
}

impl Iterator for AffinityUriIter {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let candidate = match self.next {
            // First call: hand out the preferred index, then sweep from 0.
            None => {
                self.next = Some(0);
                self.preferred
            }
            Some(cursor) => {
                // The preferred index was already handed out, so step over it.
                let idx = if cursor == self.preferred {
                    cursor + 1
                } else {
                    cursor
                };
                self.next = Some(idx + 1);
                idx
            }
        };
        // Anything at or past `n` means the sweep is finished.
        Some(candidate).filter(|&idx| idx < self.n)
    }
}
120
/// Stateless policy that orders cached SRV records by a key derived from each
/// record's priority and weight (see `crate::record::sort_key`).
///
/// The name suggests the selection rules of RFC 2782; the exact key semantics
/// live in `crate::record` and should be confirmed there.
#[derive(Default)]
pub struct Rfc2782;
125
126pub struct ParsedRecord {
128 uri: Uri,
129 priority: u16,
130 weight: u16,
131}
132
133impl ParsedRecord {
134 fn new<Record: SrvRecord>(record: &Record, uri: Uri) -> Self {
135 Self {
136 uri,
137 priority: record.priority(),
138 weight: record.weight(),
139 }
140 }
141}
142
143#[async_trait]
144impl Policy for Rfc2782 {
145 type CacheItem = ParsedRecord;
146 type Ordering = <Vec<usize> as IntoIterator>::IntoIter;
147
148 async fn refresh_cache<Resolver: SrvResolver>(
149 &self,
150 client: &SrvClient<Resolver, Self>,
151 ) -> Result<Cache<Self::CacheItem>, Error<Resolver::Error>> {
152 let (records, valid_until) = client.get_srv_records().await?;
153 let parsed = records
154 .iter()
155 .map(|record| {
156 client
157 .parse_record(record)
158 .map(|uri| ParsedRecord::new(record, uri))
159 })
160 .collect::<Result<Vec<_>, _>>()?;
161 Ok(Cache::new(parsed, valid_until))
162 }
163
164 fn order(&self, records: &[ParsedRecord]) -> Self::Ordering {
165 let mut indices = (0..records.len()).collect::<Vec<_>>();
166 let mut rng = rand::thread_rng();
167 indices.sort_by_cached_key(|&idx| {
168 let (priority, weight) = (records[idx].priority, records[idx].weight);
169 crate::record::sort_key(priority, weight, &mut rng)
170 });
171 indices.into_iter()
172 }
173
174 fn cache_item_to_uri(item: &Self::CacheItem) -> &Uri {
175 &item.uri
176 }
177}
178
/// Checks that `Affinity::uris_preferring` yields the preferred URI first and
/// keeps the others in their original order, including the fallback cases.
#[test]
fn affinity_uris_iter_order() {
    let google: Uri = "https://google.com".parse().unwrap();
    let amazon: Uri = "https://amazon.com".parse().unwrap();
    let desco: Uri = "https://deshaw.com".parse().unwrap();
    // Never placed in the cache; exercises the "preferred not found" branch.
    let unlisted: Uri = "https://example.com".parse().unwrap();
    let cache = vec![google.clone(), amazon.clone(), desco.clone()];
    let order = |preferred| {
        Affinity::uris_preferring(&cache, preferred)
            .map(|idx| &cache[idx])
            .collect::<Vec<_>>()
    };
    assert_eq!(order(None), vec![&google, &amazon, &desco]);
    assert_eq!(order(Some(&google)), vec![&google, &amazon, &desco]);
    assert_eq!(order(Some(&amazon)), vec![&amazon, &google, &desco]);
    assert_eq!(order(Some(&desco)), vec![&desco, &google, &amazon]);
    // A preferred URI that is not among the candidates leaves the order as-is.
    assert_eq!(order(Some(&unlisted)), vec![&google, &amazon, &desco]);
    // With no candidates nothing is yielded, whatever the preference.
    assert_eq!(Affinity::uris_preferring(&[], Some(&google)).count(), 0);
}
195
/// Checks that `Rfc2782::order` never yields a record with a lower priority
/// value after one with a higher priority value, whatever the random weights.
#[test]
fn balance_uris_iter_order() {
    // Clippy's `mutable_key_type` lint is silenced here, presumably because it
    // fires for `Uri` keys; the keys are never mutated in this test.
    #[allow(clippy::mutable_key_type)]
    let mut priorities = std::collections::HashMap::new();
    for &(target, priority) in &[
        ("https://google.com", 2),
        ("https://cloudflare.com", 2),
        ("https://amazon.com", 1),
        ("https://deshaw.com", 1),
    ] {
        priorities.insert(target.parse::<Uri>().unwrap(), priority);
    }

    // One record per URI, each with a random weight in the `u8` range.
    let mut cache = Vec::new();
    for (uri, &priority) in &priorities {
        cache.push(ParsedRecord {
            uri: uri.clone(),
            priority,
            weight: u16::from(rand::random::<u8>()),
        });
    }

    let assert_priorities_ascend = |ordering: <Rfc2782 as Policy>::Ordering| {
        let visited: Vec<&Uri> = ordering.map(|idx| &cache[idx].uri).collect();
        // Every adjacent pair must be in non-decreasing priority order.
        for pair in visited.windows(2) {
            assert!(priorities[pair[0]] <= priorities[pair[1]]);
        }
    };

    // The ordering is randomized, so sample it several times.
    for _ in 0..5 {
        assert_priorities_ascend(Rfc2782.order(&cache));
    }
}
229}