1use std::{collections::BinaryHeap, net::IpAddr, sync::Arc, time::Duration};
2
3use chrono::{DateTime, TimeDelta, Utc};
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use tokio::{
6 sync::RwLock,
7 task::{AbortHandle, JoinHandle},
8 time::interval,
9};
10use tonic::transport::Channel;
11use tracing::{debug, trace};
12
13use crate::{
14 broken_endpoints::{BrokenEndpoints, DelayedAddress},
15 dns::resolve_domain,
16 endpoint_template::EndpointTemplate,
17 ready_channels::ReadyChannels,
18};
19
20#[derive(Debug, Clone)]
22pub struct ChannelPoolBuilder {
23 endpoint: EndpointTemplate,
24 dns_interval: Duration,
25}
26
27impl ChannelPoolBuilder {
28 #[must_use]
30 pub fn new(endpoint: impl Into<EndpointTemplate>) -> Self {
31 Self {
32 endpoint: endpoint.into(),
33 dns_interval: Duration::from_secs(5),
35 }
36 }
37
38 #[must_use]
43 pub fn dns_interval(&mut self, dns_interval: impl Into<Duration>) -> &mut Self {
44 self.dns_interval = dns_interval.into();
45 self
46 }
47
48 #[must_use]
54 pub fn build(self) -> ChannelPool {
55 let ready_clients = Arc::new(ReadyChannels::default());
56 let broken_endpoints = Arc::new(BrokenEndpoints::default());
57
58 let dns_lookup_task = {
59 let ready_clients = ready_clients.clone();
61 let broken_endpoints = broken_endpoints.clone();
62 let endpoint = self.endpoint.clone();
63
64 tokio::spawn(async move {
65 let mut interval = interval(self.dns_interval);
66 loop {
67 check_dns(&endpoint, &ready_clients, &broken_endpoints).await;
68
69 interval.tick().await;
70 }
71 })
72 };
73
74 let doctor_task = {
75 let ready_clients = ready_clients.clone();
77 let broken_endpoints = broken_endpoints.clone();
78 let endpoint = self.endpoint.clone();
79
80 tokio::spawn(async move {
81 loop {
82 recheck_broken_endpoint(
84 broken_endpoints.next_broken_ip_address().await,
85 &endpoint,
86 &ready_clients,
87 &broken_endpoints,
88 )
89 .await;
90 }
91 })
92 };
93
94 ChannelPool {
95 template: Arc::new(self.endpoint),
96 ready_clients,
97 broken_endpoints,
98 _dns_lookup_task: Arc::new(dns_lookup_task.into()),
99 _doctor_task: Arc::new(doctor_task.into()),
100 }
101 }
102}
103
104async fn check_dns(
105 endpoint_template: &EndpointTemplate,
106 ready_clients: &ReadyChannels,
107 broken_endpoints: &BrokenEndpoints,
108) {
109 let Ok(addresses) = resolve_domain(endpoint_template.domain()) else {
111 return;
116 };
117
118 let mut ready = Vec::new();
119 let mut broken = BinaryHeap::new();
120
121 for address in addresses {
122 if let Some(channel) = ready_clients.find(address).await {
124 trace!("Skipping {:?} as already ready", address);
125 ready.push((address, channel));
126 continue;
127 }
128
129 if let Some(entry) = broken_endpoints.get_address(address).await {
131 trace!("Skipping {:?} as already broken", address);
132 broken.push(entry);
133 continue;
134 }
135
136 debug!("Connecting to: {:?}", address);
137 let channel = endpoint_template.build(address).connect().await;
138 if let Ok(channel) = channel {
139 ready.push((address, channel));
140 } else {
141 broken.push(address.into());
142 }
143 }
144
145 ready_clients.replace_with(ready).await;
147 broken_endpoints.replace_with(broken).await;
148}
149
150async fn recheck_broken_endpoint(
151 address: DelayedAddress,
152 endpoint: &EndpointTemplate,
153 ready_clients: &ReadyChannels,
154 broken_endpoints: &BrokenEndpoints,
155) {
156 let connection_test_result = endpoint.build(*address).connect().await;
157
158 if let Ok(channel) = connection_test_result {
159 debug!("Connection established to {:?}", *address);
160 ready_clients.add(*address, channel).await;
161 } else {
162 debug!("Can't connect to {:?}", *address);
163 broken_endpoints.re_add_address(address).await;
164 }
165}
166
167#[derive(Debug, Default)]
168struct AbortOnDrop(Option<AbortHandle>);
169
170impl<T> From<JoinHandle<T>> for AbortOnDrop {
171 fn from(handle: JoinHandle<T>) -> Self {
172 Self(Some(handle.abort_handle()))
173 }
174}
175
176impl Drop for AbortOnDrop {
177 fn drop(&mut self) {
178 if let Some(handle) = self.0.take() {
179 handle.abort();
180 }
181 }
182}
183
184#[derive(Debug)]
187pub struct ChannelPool {
188 template: Arc<EndpointTemplate>,
189 ready_clients: Arc<ReadyChannels>,
190 broken_endpoints: Arc<BrokenEndpoints>,
191
192 _dns_lookup_task: Arc<AbortOnDrop>,
193 _doctor_task: Arc<AbortOnDrop>,
194}
195
196impl ChannelPool {
197 pub async fn get_channel(&self) -> Option<(IpAddr, Channel)> {
221 static RECHECK_BROKEN_ENDPOINTS: RwLock<DateTime<Utc>> =
222 RwLock::const_new(DateTime::<Utc>::MIN_UTC);
223 const MIN_INTERVAL: TimeDelta = TimeDelta::milliseconds(500);
224
225 if let Some(entry) = self.ready_clients.get_any().await {
226 return Some(entry);
227 }
228
229 let _guard = match RECHECK_BROKEN_ENDPOINTS.try_read() {
231 Ok(last_recheck_time)
232 if Utc::now().signed_duration_since(*last_recheck_time) < MIN_INTERVAL =>
233 {
234 return None;
235 }
236 Ok(guard) => {
237 drop(guard);
238 let mut guard = RECHECK_BROKEN_ENDPOINTS.write().await;
239 if let Some(entry) = self.ready_clients.get_any().await {
240 return Some(entry);
241 }
242 *guard = Utc::now();
243 guard
244 }
245 Err(_) => {
246 let _ = RECHECK_BROKEN_ENDPOINTS.write().await;
250 return self.ready_clients.get_any().await;
251 }
252 };
253
254 trace!("Force recheck of broken endpoints");
255
256 let mut fut = FuturesUnordered::new();
257 fut.push(
258 async {
259 check_dns(&self.template, &self.ready_clients, &self.broken_endpoints).await;
260 self.ready_clients.get_any().await
261 }
262 .boxed(),
263 );
264
265 for address in self.broken_endpoints.addresses().await.iter().copied() {
266 fut.push(
267 async move {
268 recheck_broken_endpoint(
269 address,
270 &self.template,
271 &self.ready_clients,
272 &self.broken_endpoints,
273 )
274 .await;
275 self.ready_clients.get_any().await
276 }
277 .boxed(),
278 );
279 }
280
281 fut.select_next_some().await
282 }
283
284 pub async fn report_broken(&self, ip_address: impl Into<IpAddr>) {
288 let ip_address = ip_address.into();
289 self.ready_clients.remove(ip_address).await;
290 self.broken_endpoints.add_address(ip_address).await;
291 }
292}
293
294impl Clone for ChannelPool {
297 fn clone(&self) -> Self {
298 #[allow(clippy::used_underscore_binding)]
299 Self {
300 template: self.template.clone(),
301 ready_clients: self.ready_clients.clone(),
302 broken_endpoints: self.broken_endpoints.clone(),
303 _dns_lookup_task: self._dns_lookup_task.clone(),
304 _doctor_task: self._doctor_task.clone(),
305 }
306 }
307}