1#![doc = include_str!("../README.md")]
25#![cfg_attr(not(test), warn(clippy::unwrap_used, clippy::expect_used))]
28
29mod attempt;
30mod budget;
31mod channel_pool;
32mod driver;
33mod driver_supervisor;
34mod error;
35mod leader_hint;
36mod response;
37mod retry;
38mod retry_policy;
39mod transport;
40mod worklist;
41
42#[cfg(test)]
43mod test_support;
44
45pub use error::ClientError;
46pub use retry_policy::RetryPolicy;
47pub use transport::BoxError;
48
49use std::sync::Arc;
50use std::time::Duration;
51use tsoracle_core::{Epoch, LOGICAL_MAX, Timestamp};
52
/// Upper bound on the number of timestamps a single RPC may request.
///
/// [`Client::get_ts_batch`] rejects any `count` above this value (and `0`) with
/// [`ClientError::InvalidCount`]. It is defined as `LOGICAL_MAX + 1`.
/// NOTE(review): presumably this is the number of distinct logical values available within one
/// physical tick of a `Timestamp` — confirm against `tsoracle_core`.
pub(crate) const MAX_TIMESTAMPS_PER_RPC: u32 = LOGICAL_MAX + 1;
58
59use crate::channel_pool::ChannelPool;
60
61pub struct ClientBuilder {
62 endpoints: Vec<String>,
63 flush_interval: Duration,
64 connector: Option<Arc<crate::transport::ChannelConnector>>,
65 tls_required: bool,
66 retry_policy: RetryPolicy,
67}
68
69impl ClientBuilder {
70 pub fn endpoints(endpoints: Vec<String>) -> Self {
71 ClientBuilder {
72 endpoints,
73 flush_interval: Duration::from_millis(1),
74 connector: None,
75 tls_required: false,
76 retry_policy: RetryPolicy::default(),
77 }
78 }
79
80 pub fn batch_flush_interval(mut self, flush_interval: Duration) -> Self {
81 self.flush_interval = flush_interval;
82 self
83 }
84
85 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
97 self.retry_policy = policy;
98 self
99 }
100
101 #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
117 pub fn tls_config(mut self, cfg: tonic::transport::ClientTlsConfig) -> Self {
118 self.connector = Some(crate::transport::tls_connector(
119 cfg,
120 self.retry_policy.clone(),
121 ));
122 self.tls_required = true;
123 self
124 }
125
126 pub fn channel_connector<F, Fut>(mut self, connector: F) -> Self
137 where
138 F: Fn(&str) -> Fut + Send + Sync + 'static,
139 Fut: std::future::Future<Output = Result<tonic::transport::Channel, crate::BoxError>>
140 + Send
141 + 'static,
142 {
143 let wrapped: Arc<crate::transport::ChannelConnector> = Arc::new(move |endpoint: &str| {
144 let fut = connector(endpoint);
145 Box::pin(async move { fut.await.map_err(ClientError::Connector) })
146 });
147 self.connector = Some(wrapped);
148 self.tls_required = false;
149 self
150 }
151
152 pub async fn build(self) -> Result<Client, ClientError> {
153 if self.endpoints.is_empty() {
154 return Err(ClientError::NoReachableEndpoints);
155 }
156 let pool = Arc::new(ChannelPool::new(
157 self.endpoints,
158 self.connector,
159 self.tls_required,
160 self.retry_policy,
161 ));
162 let pool_for_rpc = pool.clone();
163 let driver = driver::Driver::spawn(
164 move |count| {
165 let pool = pool_for_rpc.clone();
166 Box::pin(async move { retry::issue_rpc(&pool, count).await })
167 },
168 self.flush_interval,
169 );
170 Ok(Client { pool, driver })
171 }
172}
173
/// Handle for requesting timestamps from a tsoracle cluster.
///
/// Created by [`ClientBuilder::build`] or [`Client::connect`]. `get_ts` / `get_ts_batch` go
/// through the driver spawned at build time. `get_current_max_safe` uses the channel pool
/// directly.
pub struct Client {
    // Channel pool shared with the driver's RPC closure (see `ClientBuilder::build`).
    pool: Arc<ChannelPool>,
    // Driver spawned by `ClientBuilder::build`; it receives timestamp-count requests.
    driver: driver::Driver,
}
178
179impl Client {
180 pub async fn connect(endpoints: Vec<String>) -> Result<Self, ClientError> {
181 ClientBuilder::endpoints(endpoints).build().await
182 }
183
184 pub fn cached_leader(&self) -> Option<String> {
194 self.pool.cached_leader()
195 }
196
197 pub async fn get_ts(&self) -> Result<Timestamp, ClientError> {
198 Ok(self.driver.request(1).await?[0])
199 }
200
201 pub async fn get_ts_batch(&self, count: u32) -> Result<Vec<Timestamp>, ClientError> {
202 if count == 0 || count > MAX_TIMESTAMPS_PER_RPC {
203 return Err(ClientError::InvalidCount(count));
204 }
205 self.driver.request(count).await
206 }
207
208 pub async fn get_current_max_safe(&self) -> Result<MaxSafe, ClientError> {
215 let endpoint = self
216 .pool
217 .cached_leader()
218 .or_else(|| self.pool.iter_round_robin().into_iter().next())
219 .ok_or(ClientError::NoReachableEndpoints)?;
220 let (mut svc, _cell) = self.pool.client_with_cell(&endpoint).await?;
221 let resp = svc
222 .get_current_max_safe(tsoracle_proto::v1::GetCurrentMaxSafeRequest {})
223 .await
224 .map_err(ClientError::Rpc)?;
225 let inner = resp.into_inner();
226 Ok(MaxSafe {
227 max_safe_physical_ms: inner.max_safe_physical_ms,
228 epoch: Epoch::from_wire(inner.epoch_hi, inner.epoch_lo),
229 })
230 }
231}
232
/// Result of [`Client::get_current_max_safe`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct MaxSafe {
    /// Copied from the `max_safe_physical_ms` field of the server's response.
    /// NOTE(review): presumably the highest physical time (in ms) the server may currently hand
    /// out — confirm against the server implementation.
    pub max_safe_physical_ms: u64,
    /// Server epoch, reassembled from the response's `epoch_hi` / `epoch_lo` halves via
    /// `Epoch::from_wire`.
    pub epoch: Epoch,
}
243
// Unit tests for the builder and the public `Client` surface. None of them starts a server:
// they either target loopback ports assumed to have no listener (`127.0.0.1:1..=3`) or install
// connectors that always fail.
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing issues a request here, so no leader can have been cached yet.
    #[tokio::test]
    async fn cached_leader_is_none_before_any_rpc() {
        let client = Client::connect(vec!["http://127.0.0.1:1".into()])
            .await
            .expect("build with a non-empty endpoint list must succeed");
        assert_eq!(client.cached_leader(), None);
    }

    // `build` must refuse an empty endpoint list with the dedicated variant.
    #[tokio::test]
    async fn build_rejects_empty_endpoint_list() {
        match ClientBuilder::endpoints(Vec::new()).build().await {
            Err(ClientError::NoReachableEndpoints) => {}
            Err(other) => panic!("expected NoReachableEndpoints, got {other:?}"),
            Ok(_) => panic!("expected Err, got Ok(Client)"),
        }
    }

    // A failing user connector does not fail `build`. Its error surfaces on the first request,
    // wrapped in `ClientError::Connector` with the original message intact.
    #[tokio::test]
    async fn channel_connector_error_surfaces_as_connector_variant() {
        let builder = ClientBuilder::endpoints(vec!["a:1".into()]).channel_connector(
            |_endpoint: &str| async move {
                Err::<tonic::transport::Channel, crate::BoxError>(
                    std::io::Error::other("boom").into(),
                )
            },
        );
        let client = builder.build().await.expect("build must not fail");
        let result = client.get_ts().await;
        match result {
            Err(ClientError::Connector(inner)) => {
                assert!(inner.to_string().contains("boom"));
            }
            other => panic!("expected ClientError::Connector, got {other:?}"),
        }
    }

    // Connector that always fails with a recognizable message. It is used to tell which
    // connector was actually installed.
    #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
    async fn marker_connector_failure() -> Result<tonic::transport::Channel, crate::BoxError> {
        Err("MARKER".into())
    }

    // `channel_connector` after `tls_config` must replace the TLS connector.
    #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
    #[tokio::test]
    async fn tls_config_then_channel_connector_last_wins() {
        let builder = ClientBuilder::endpoints(vec!["a:1".into()])
            .tls_config(tonic::transport::ClientTlsConfig::new())
            .channel_connector(|_endpoint: &str| marker_connector_failure());
        let client = builder.build().await.expect("build must not fail");
        match client.get_ts().await {
            Err(ClientError::Connector(inner)) => {
                assert!(inner.to_string().contains("MARKER"));
            }
            other => panic!("expected Connector(MARKER), got {other:?}"),
        }
    }

    // `tls_config` after `channel_connector` must replace the custom connector. The request may
    // still fail for other reasons (nothing listens on "a:1"); it just must not fail with the
    // marker connector's error.
    #[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
    #[tokio::test]
    async fn channel_connector_then_tls_config_last_wins() {
        let builder = ClientBuilder::endpoints(vec!["a:1".into()])
            .channel_connector(|_endpoint: &str| marker_connector_failure())
            .tls_config(tonic::transport::ClientTlsConfig::new());
        let client = builder.build().await.expect("build must not fail");
        let result = client.get_ts().await;
        if let Err(ClientError::Connector(inner)) = &result
            && inner.to_string().contains("MARKER")
        {
            panic!("tls_config set last must overwrite the prior channel_connector");
        }
    }

    // The setter stores the custom flush interval on the builder.
    #[tokio::test]
    async fn batch_flush_interval_overrides_default() {
        let custom = Duration::from_millis(25);
        let builder = ClientBuilder::endpoints(vec!["http://127.0.0.1:1".into()])
            .batch_flush_interval(custom);
        assert_eq!(builder.flush_interval, custom);
    }

    // Every `RetryPolicy` field set by the caller is stored unchanged on the builder. Distinct
    // prime values catch accidental field swaps.
    #[tokio::test]
    async fn retry_policy_override_propagates_to_builder() {
        let policy = RetryPolicy {
            max_attempts: 7,
            per_attempt_deadline: Duration::from_millis(11),
            overall_deadline: Duration::from_millis(13),
            base_backoff: Duration::from_millis(17),
            leader_ttl: Duration::from_millis(19),
        };
        let builder = ClientBuilder::endpoints(vec!["http://127.0.0.1:1".into()])
            .retry_policy(policy.clone());
        assert_eq!(builder.retry_policy.max_attempts, policy.max_attempts);
        assert_eq!(
            builder.retry_policy.per_attempt_deadline,
            policy.per_attempt_deadline
        );
        assert_eq!(
            builder.retry_policy.overall_deadline,
            policy.overall_deadline
        );
        assert_eq!(builder.retry_policy.base_backoff, policy.base_backoff);
        assert_eq!(builder.retry_policy.leader_ttl, policy.leader_ttl);
    }

    // With every endpoint unreachable, `get_ts` must give up around the 300 ms overall
    // deadline. The 2 s bound leaves generous slack for slow CI machines while still catching
    // a missing deadline.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn get_ts_returns_within_overall_deadline_when_all_endpoints_unreachable() {
        let policy = RetryPolicy {
            max_attempts: 3,
            per_attempt_deadline: Duration::from_millis(100),
            overall_deadline: Duration::from_millis(300),
            base_backoff: Duration::ZERO,
            leader_ttl: Duration::from_secs(30),
        };
        let client = ClientBuilder::endpoints(vec![
            "http://127.0.0.1:1".into(),
            "http://127.0.0.1:2".into(),
            "http://127.0.0.1:3".into(),
        ])
        .retry_policy(policy)
        .build()
        .await
        .expect("builder must accept the policy");
        let start = std::time::Instant::now();
        let result = client.get_ts().await;
        let elapsed = start.elapsed();
        assert!(result.is_err(), "no listener can reply: {result:?}");
        assert!(
            elapsed < Duration::from_secs(2),
            "deadline must short-circuit; took {elapsed:?}"
        );
    }
}