1use std::error::Error as StdError;
45use std::future::Future;
46use std::result::Result as StdResult;
47use std::time::Duration;
48
49use sift_error::prelude::*;
50use tonic;
51
/// Attempt limit and exponential-backoff schedule used by [`Retrying`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts to make, including the first one.
    pub max_attempts: usize,
    /// Delay returned for the first retry; later delays grow from this value.
    pub base_delay: Duration,
    /// Upper bound applied to every delay returned by [`RetryConfig::backoff`].
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each additional attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at a 100 ms delay that doubles per attempt,
    /// capped at 5 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns how long to wait after failed attempt number `attempt`
    /// (1-based) before making the next one.
    ///
    /// The delay is `base_delay * backoff_multiplier^(attempt - 1)`, always
    /// capped at `max_delay`. Attempts `0` and `1` both yield `base_delay`
    /// (capped as well).
    ///
    /// The product is computed in nanoseconds, so sub-millisecond components
    /// of `base_delay` are preserved. The float-to-integer cast saturates
    /// (non-finite/negative products become `u64::MAX`/`0` nanoseconds), so
    /// very large attempt numbers cannot panic.
    pub fn backoff(&self, attempt: usize) -> Duration {
        let exponent = attempt.saturating_sub(1);
        if exponent == 0 {
            return self.base_delay.min(self.max_delay);
        }

        let factor = self.backoff_multiplier.powf(exponent as f64);
        let delay_nanos = self.base_delay.as_nanos() as f64 * factor;
        // `f64 as u64` saturates instead of wrapping or panicking.
        Duration::from_nanos(delay_nanos as u64).min(self.max_delay)
    }
}
102
/// Policy that classifies an error of type `E` as transient (worth retrying)
/// or permanent.
///
/// [`Retrying::call`] consults it after a failed attempt when more attempts
/// remain.
pub trait RetryDecider<E> {
    /// Returns `true` if the operation that produced `err` should be tried
    /// again.
    fn should_retry(&self, err: &E) -> bool;
}
108
/// Default [`RetryDecider`] installed by `Retrying::new`.
///
/// It retries on the transient gRPC codes `Unavailable`, `ResourceExhausted`
/// and `DeadlineExceeded`. When the error carries no `tonic::Status`, it falls
/// back to a fixed set of error kinds.
///
/// `Retrying` derives `Clone` and `Debug`, and those derives require the
/// decider type to implement them too. Deriving them here makes
/// `Retrying<T>` (with this default decider) cloneable and printable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultGrpcRetry;
115
116impl RetryDecider<sift_error::Error> for DefaultGrpcRetry {
117 fn should_retry(&self, err: &sift_error::Error) -> bool {
118 let mut source = err.source();
120 while let Some(err_ref) = source {
121 if let Some(status) = err_ref.downcast_ref::<tonic::Status>() {
122 return matches!(
123 status.code(),
124 tonic::Code::Unavailable
125 | tonic::Code::ResourceExhausted
126 | tonic::Code::DeadlineExceeded
127 );
128 }
129 source = err_ref.source();
130 }
131
132 matches!(
134 err.kind(),
135 ErrorKind::GrpcConnectError
136 | ErrorKind::RetrieveAssetError
137 | ErrorKind::RetrieveIngestionConfigError
138 | ErrorKind::RetrieveRunError
139 )
140 }
141}
142
143#[derive(Clone, Debug)]
145pub struct Retrying<T, D = DefaultGrpcRetry> {
146 inner: T,
147 cfg: RetryConfig,
148 decider: D,
149}
150
151impl<T> Retrying<T> {
152 pub fn new(inner: T, cfg: RetryConfig) -> Self {
154 Self {
155 inner,
156 cfg,
157 decider: DefaultGrpcRetry,
158 }
159 }
160}
161
162impl<T, D> Retrying<T, D> {
163 pub fn with_decider<D2>(self, decider: D2) -> Retrying<T, D2> {
165 Retrying {
166 inner: self.inner,
167 cfg: self.cfg,
168 decider,
169 }
170 }
171
172 pub fn inner(&self) -> &T {
174 &self.inner
175 }
176}
177
178impl<T, D> Retrying<T, D>
179where
180 T: Clone,
181{
182 pub async fn call<F, Fut, R, E>(&self, mut f: F) -> StdResult<R, E>
199 where
200 F: FnMut(T) -> Fut,
201 Fut: Future<Output = StdResult<R, E>>,
202 D: RetryDecider<E>,
203 {
204 let mut last_err = None;
205
206 for attempt in 1..=self.cfg.max_attempts {
207 let wrapper = self.inner.clone();
208 match f(wrapper).await {
209 Ok(result) => return Ok(result),
210 Err(e) => {
211 last_err = Some(e);
212 if attempt < self.cfg.max_attempts
213 && self.decider.should_retry(last_err.as_ref().unwrap())
214 {
215 let delay = self.cfg.backoff(attempt);
216 tokio::time::sleep(delay).await;
217 continue;
218 }
219 break;
220 }
221 }
222 }
223
224 Err(last_err.expect("retry loop invariant violated"))
225 }
226}
227
228pub trait RetryExt: Sized {
230 fn retrying(self, cfg: RetryConfig) -> Retrying<Self> {
232 Retrying::new(self, cfg)
233 }
234}
235
236impl<T> RetryExt for T {}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Three attempts with a short base delay so the async tests stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }

    #[test]
    fn test_backoff_calculation() {
        let cfg = RetryConfig::default();

        assert_eq!(cfg.backoff(1), Duration::from_millis(100));
        assert_eq!(cfg.backoff(2), Duration::from_millis(200));
        assert_eq!(cfg.backoff(3), Duration::from_millis(400));
        assert_eq!(cfg.backoff(4), Duration::from_millis(800));
    }

    #[test]
    fn test_backoff_caps_at_max() {
        let cfg = RetryConfig {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            backoff_multiplier: 2.0,
        };

        // 100ms * 2^9 is far above the 500ms cap.
        let delay = cfg.backoff(10);
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_retry_loop_succeeds_after_failures() {
        let counter = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_config());

        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    let attempts = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    if attempts < 3 {
                        Err::<(), sift_error::Error>(Error::new_msg(
                            ErrorKind::RetrieveAssetError,
                            "temporary failure",
                        ))
                    } else {
                        Ok(())
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_loop_exhausts_attempts() {
        let counter = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_config());

        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), sift_error::Error>(Error::new_msg(
                        ErrorKind::RetrieveAssetError,
                        "persistent failure",
                    ))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_default_grpc_retry_with_tonic_status() {
        let decider = DefaultGrpcRetry;

        let unavailable = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::unavailable("service unavailable"),
        );
        assert!(decider.should_retry(&unavailable));

        let resource_exhausted = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::resource_exhausted("resource exhausted"),
        );
        assert!(decider.should_retry(&resource_exhausted));

        let deadline_exceeded = Error::new(
            ErrorKind::RetrieveAssetError,
            tonic::Status::deadline_exceeded("deadline exceeded"),
        );
        assert!(decider.should_retry(&deadline_exceeded));

        let invalid_argument = Error::new(
            ErrorKind::ArgumentValidationError,
            tonic::Status::invalid_argument("invalid argument"),
        );
        assert!(!decider.should_retry(&invalid_argument));

        let not_found = Error::new(
            ErrorKind::NotFoundError,
            tonic::Status::not_found("not found"),
        );
        assert!(!decider.should_retry(&not_found));
    }

    #[test]
    fn test_default_grpc_retry_with_error_kind_fallback() {
        let decider = DefaultGrpcRetry;

        let grpc_connect_error = Error::new_msg(ErrorKind::GrpcConnectError, "connection failed");
        assert!(decider.should_retry(&grpc_connect_error));

        let retrieve_asset_error =
            Error::new_msg(ErrorKind::RetrieveAssetError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_asset_error));

        let retrieve_ingestion_config_error =
            Error::new_msg(ErrorKind::RetrieveIngestionConfigError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_ingestion_config_error));

        let retrieve_run_error = Error::new_msg(ErrorKind::RetrieveRunError, "retrieval failed");
        assert!(decider.should_retry(&retrieve_run_error));

        let argument_error = Error::new_msg(ErrorKind::ArgumentValidationError, "bad argument");
        assert!(!decider.should_retry(&argument_error));

        let not_found_error = Error::new_msg(ErrorKind::NotFoundError, "not found");
        assert!(!decider.should_retry(&not_found_error));
    }

    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let counter = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_config());

        let result = retrying
            .call(|_| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), sift_error::Error>(Error::new(
                        ErrorKind::ArgumentValidationError,
                        tonic::Status::invalid_argument("invalid argument"),
                    ))
                }
            })
            .await;

        assert!(result.is_err());
        // A non-retryable status must stop the loop after the first attempt.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}