1use std::time::Duration;
21
/// Strategy for computing the delay to wait before each retry attempt.
///
/// Implementations must be cheaply cloneable and thread-safe, since a
/// `RetryConfig` holding one may be shared across tasks.
pub trait BackoffStrategy: Clone + Send + Sync + 'static {
    /// Returns the delay before the next retry. `attempt` is the number of
    /// retries already performed (0 for the first retry — see
    /// `RetryConfig::execute`, which calls `delay(attempt)` before
    /// incrementing).
    fn delay(&self, attempt: u32) -> Duration;
}
30
/// Backoff strategy that waits no time at all between retries.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoBackoff;

impl NoBackoff {
    /// Creates a new `NoBackoff`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
46
impl BackoffStrategy for NoBackoff {
    /// Always returns a zero delay, regardless of attempt number.
    fn delay(&self, _attempt: u32) -> Duration {
        Duration::ZERO
    }
}
52
/// Backoff strategy that waits the same fixed duration between every retry.
#[derive(Debug, Clone, Copy)]
pub struct FixedBackoff {
    // The constant delay returned for every attempt.
    delay: Duration,
}

impl FixedBackoff {
    /// Creates a strategy that always waits `delay`.
    #[must_use]
    pub fn new(delay: Duration) -> Self {
        FixedBackoff { delay }
    }

    /// Creates a strategy with a fixed delay of `millis` milliseconds.
    #[must_use]
    pub fn from_millis(millis: u64) -> Self {
        FixedBackoff {
            delay: Duration::from_millis(millis),
        }
    }

    /// Creates a strategy with a fixed delay of `secs` seconds.
    #[must_use]
    pub fn from_secs(secs: u64) -> Self {
        FixedBackoff {
            delay: Duration::from_secs(secs),
        }
    }
}
82
impl Default for FixedBackoff {
    /// Defaults to a fixed 100 ms delay.
    fn default() -> Self {
        Self::new(Duration::from_millis(100))
    }
}
88
impl BackoffStrategy for FixedBackoff {
    /// Returns the configured delay, ignoring the attempt number.
    fn delay(&self, _attempt: u32) -> Duration {
        self.delay
    }
}
94
/// Backoff strategy whose delay grows by a fixed increment per attempt,
/// up to a configurable cap.
#[derive(Debug, Clone, Copy)]
pub struct LinearBackoff {
    // Delay for attempt 0.
    initial_delay: Duration,
    // Amount added for each subsequent attempt.
    increment: Duration,
    // Upper bound on the computed delay.
    max_delay: Duration,
}

impl LinearBackoff {
    /// Creates a linear strategy starting at `initial_delay`. The increment
    /// defaults to `initial_delay` and the cap defaults to 30 seconds.
    #[must_use]
    pub fn new(initial_delay: Duration) -> Self {
        LinearBackoff {
            initial_delay,
            increment: initial_delay,
            max_delay: Duration::from_secs(30),
        }
    }

    /// Sets the amount added to the delay on each attempt.
    #[must_use]
    pub fn with_increment(self, increment: Duration) -> Self {
        Self { increment, ..self }
    }

    /// Sets the upper bound on the computed delay.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }
}
132
impl Default for LinearBackoff {
    /// Defaults to 100 ms initial delay (increment 100 ms, cap 30 s via `new`).
    fn default() -> Self {
        Self::new(Duration::from_millis(100))
    }
}
138
139impl BackoffStrategy for LinearBackoff {
140 fn delay(&self, attempt: u32) -> Duration {
141 let delay = self.initial_delay + self.increment * attempt;
142 delay.min(self.max_delay)
143 }
144}
145
/// Backoff strategy whose delay grows geometrically per attempt, with a cap
/// and optional jitter.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialBackoff {
    // Delay for attempt 0.
    initial_delay: Duration,
    // Upper bound applied to the computed delay.
    max_delay: Duration,
    // Growth factor applied once per attempt.
    multiplier: f64,
    // Whether to add a jitter component to the computed delay.
    jitter: bool,
}

impl ExponentialBackoff {
    /// Creates an exponential strategy starting at `initial_delay`, doubling
    /// each attempt, capped at 30 seconds, with jitter enabled.
    #[must_use]
    pub fn new(initial_delay: Duration) -> Self {
        ExponentialBackoff {
            initial_delay,
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            jitter: true,
        }
    }

    /// Sets the cap on the computed delay.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }

    /// Sets the per-attempt growth factor.
    #[must_use]
    pub fn with_multiplier(self, multiplier: f64) -> Self {
        Self { multiplier, ..self }
    }

    /// Enables or disables jitter.
    #[must_use]
    pub fn with_jitter(self, jitter: bool) -> Self {
        Self { jitter, ..self }
    }
}
194
impl Default for ExponentialBackoff {
    /// Defaults to 100 ms initial delay (×2 growth, 30 s cap, jitter on, via `new`).
    fn default() -> Self {
        Self::new(Duration::from_millis(100))
    }
}
200
201impl BackoffStrategy for ExponentialBackoff {
202 fn delay(&self, attempt: u32) -> Duration {
203 let base_delay =
204 self.initial_delay.as_millis() as f64 * self.multiplier.powi(attempt as i32);
205 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
206
207 let final_delay = if self.jitter {
208 let jitter_range = capped_delay * 0.25;
210 let jitter = (attempt as f64 * 0.1).sin().abs() * jitter_range;
212 capped_delay + jitter
213 } else {
214 capped_delay
215 };
216
217 Duration::from_millis(final_delay as u64)
218 }
219}
220
/// Policy deciding whether a failed call should be retried, based on its
/// gRPC status code.
pub trait RetryPolicy: Clone + Send + Sync + 'static {
    /// Returns `true` if a call that failed with `code` should be retried.
    fn should_retry(&self, code: tonic::Code) -> bool;
}
230
/// Default retry policy: retries the status codes this crate treats as
/// transient, and nothing else.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultRetryPolicy;

impl RetryPolicy for DefaultRetryPolicy {
    fn should_retry(&self, code: tonic::Code) -> bool {
        // Everything outside this set (e.g. InvalidArgument, NotFound,
        // PermissionDenied) is considered permanent and not retried.
        matches!(
            code,
            tonic::Code::Unavailable
                | tonic::Code::Unknown
                | tonic::Code::DeadlineExceeded
                | tonic::Code::ResourceExhausted
                | tonic::Code::Aborted
        )
    }
}
247
/// Retry policy that never retries any error.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRetryPolicy;

impl RetryPolicy for NoRetryPolicy {
    fn should_retry(&self, _code: tonic::Code) -> bool {
        false
    }
}
257
/// Retry policy driven by a caller-supplied list of retryable codes.
#[derive(Debug, Clone)]
pub struct CustomRetryPolicy {
    // Codes considered retryable. Membership is a linear scan, which is
    // fine for the handful of codes this is meant to hold.
    retry_codes: Vec<tonic::Code>,
}

impl CustomRetryPolicy {
    /// Creates a policy that retries exactly the given codes.
    #[must_use]
    pub fn new(retry_codes: Vec<tonic::Code>) -> Self {
        Self { retry_codes }
    }

    /// Preset retrying only network-style failures
    /// (`Unavailable` and `Unknown`).
    #[must_use]
    pub fn network_errors() -> Self {
        Self::new(vec![tonic::Code::Unavailable, tonic::Code::Unknown])
    }
}

impl RetryPolicy for CustomRetryPolicy {
    fn should_retry(&self, code: tonic::Code) -> bool {
        self.retry_codes.contains(&code)
    }
}
283
/// Retry configuration combining a retry policy, a backoff strategy, and
/// retry/time limits. Defaults to [`DefaultRetryPolicy`] with
/// [`ExponentialBackoff`].
#[derive(Debug, Clone)]
pub struct RetryConfig<P: RetryPolicy = DefaultRetryPolicy, B: BackoffStrategy = ExponentialBackoff>
{
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Decides which error codes are retryable.
    pub policy: P,
    /// Computes the sleep between attempts.
    pub backoff: B,
    /// Overall time budget across all attempts; `None` disables the check.
    pub total_timeout: Option<Duration>,
}
301
impl Default for RetryConfig {
    /// Defaults: 3 retries, default policy, default exponential backoff,
    /// 30 second total timeout.
    fn default() -> Self {
        Self {
            max_retries: 3,
            policy: DefaultRetryPolicy,
            backoff: ExponentialBackoff::default(),
            total_timeout: Some(Duration::from_secs(30)),
        }
    }
}
312
impl RetryConfig {
    /// Creates a config with the default policy, backoff, and limits.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts a builder preloaded with the defaults.
    #[must_use]
    pub fn builder() -> RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
        RetryConfigBuilder::new()
    }

    /// Creates a config that never retries and never sleeps.
    #[must_use]
    pub fn disabled() -> RetryConfig<NoRetryPolicy, NoBackoff> {
        RetryConfig {
            max_retries: 0,
            policy: NoRetryPolicy,
            backoff: NoBackoff,
            total_timeout: None,
        }
    }
}
337
impl<P: RetryPolicy, B: BackoffStrategy> RetryConfig<P, B> {
    /// Runs `operation` until it succeeds or retries are exhausted, returning
    /// the final `Ok` value or the last error.
    ///
    /// After each failure the error's gRPC code (via [`AsGrpcStatus`]) is
    /// checked against the policy; a non-retryable code, reaching
    /// `max_retries`, or exceeding `total_timeout` each return the error
    /// unchanged. Otherwise the backoff delay for the current attempt is
    /// slept (via tokio) and the operation is invoked again.
    ///
    /// NOTE(review): the timeout is checked before the sleep, so total wall
    /// time can overshoot `total_timeout` by one backoff delay plus one
    /// operation — confirm this slack is acceptable.
    pub async fn execute<T, E, F, Fut>(&self, mut operation: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        E: AsGrpcStatus,
    {
        let start = std::time::Instant::now();
        // Number of retries already performed; also the backoff input.
        let mut attempt = 0;

        loop {
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    let code = e.grpc_code();

                    // Permanent errors are surfaced immediately.
                    if !self.policy.should_retry(code) {
                        return Err(e);
                    }

                    if attempt >= self.max_retries {
                        return Err(e);
                    }

                    if let Some(timeout) = self.total_timeout {
                        if start.elapsed() >= timeout {
                            return Err(e);
                        }
                    }

                    let delay = self.backoff.delay(attempt);
                    tokio::time::sleep(delay).await;

                    attempt += 1;
                }
            }
        }
    }
}
382
/// Builder for [`RetryConfig`]. The `policy` and `backoff` setters may
/// change the builder's type parameters mid-chain.
#[derive(Debug, Clone)]
pub struct RetryConfigBuilder<P: RetryPolicy, B: BackoffStrategy> {
    max_retries: u32,
    policy: P,
    backoff: B,
    total_timeout: Option<Duration>,
}
391
impl RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
    /// Creates a builder seeded with the same defaults as
    /// `RetryConfig::default()`.
    // NOTE(review): these values are duplicated from `RetryConfig::default`
    // and must be kept in sync manually.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_retries: 3,
            policy: DefaultRetryPolicy,
            backoff: ExponentialBackoff::default(),
            total_timeout: Some(Duration::from_secs(30)),
        }
    }
}
404
impl Default for RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
    /// Same as [`RetryConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
410
impl<P: RetryPolicy, B: BackoffStrategy> RetryConfigBuilder<P, B> {
    /// Sets the maximum number of retries after the initial attempt.
    #[must_use]
    pub fn max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Swaps the retry policy, changing the builder's policy type parameter.
    // Field-by-field move: struct-update syntax cannot change generic
    // parameters, so `..self` is not usable here.
    #[must_use]
    pub fn policy<P2: RetryPolicy>(self, policy: P2) -> RetryConfigBuilder<P2, B> {
        RetryConfigBuilder {
            max_retries: self.max_retries,
            policy,
            backoff: self.backoff,
            total_timeout: self.total_timeout,
        }
    }

    /// Swaps the backoff strategy, changing the builder's backoff type
    /// parameter.
    #[must_use]
    pub fn backoff<B2: BackoffStrategy>(self, backoff: B2) -> RetryConfigBuilder<P, B2> {
        RetryConfigBuilder {
            max_retries: self.max_retries,
            policy: self.policy,
            backoff,
            total_timeout: self.total_timeout,
        }
    }

    /// Sets the overall time budget across all attempts.
    #[must_use]
    pub fn total_timeout(mut self, timeout: Duration) -> Self {
        self.total_timeout = Some(timeout);
        self
    }

    /// Removes the overall time budget.
    #[must_use]
    pub fn no_total_timeout(mut self) -> Self {
        self.total_timeout = None;
        self
    }

    /// Finalizes the builder into a [`RetryConfig`].
    #[must_use]
    pub fn build(self) -> RetryConfig<P, B> {
        RetryConfig {
            max_retries: self.max_retries,
            policy: self.policy,
            backoff: self.backoff,
            total_timeout: self.total_timeout,
        }
    }
}
466
/// Maps an error type onto the gRPC status code used for retry decisions.
pub trait AsGrpcStatus {
    /// Returns the gRPC code representing this error.
    fn grpc_code(&self) -> tonic::Code;
}
472
impl AsGrpcStatus for tonic::Status {
    /// A `tonic::Status` carries its code directly.
    fn grpc_code(&self) -> tonic::Code {
        self.code()
    }
}
478
// NOTE(review): mapping `Ok(_)` to `Code::Ok` lets a whole `Result` act as
// the error type in `execute`; no caller visible here uses it that way —
// confirm this impl is intentional.
impl<T> AsGrpcStatus for Result<T, tonic::Status> {
    fn grpc_code(&self) -> tonic::Code {
        match self {
            Ok(_) => tonic::Code::Ok,
            Err(e) => e.code(),
        }
    }
}
487
impl AsGrpcStatus for crate::error::TalosError {
    /// Maps each `TalosError` variant to its closest gRPC code. Transport,
    /// connection, and open-circuit failures map to `Unavailable`, so the
    /// default policy treats them as retryable.
    fn grpc_code(&self) -> tonic::Code {
        match self {
            crate::error::TalosError::Api(status) => status.code(),
            crate::error::TalosError::Transport(_) => tonic::Code::Unavailable,
            crate::error::TalosError::Config(_) => tonic::Code::InvalidArgument,
            crate::error::TalosError::Validation(_) => tonic::Code::InvalidArgument,
            crate::error::TalosError::Connection(_) => tonic::Code::Unavailable,
            crate::error::TalosError::CircuitOpen(_) => tonic::Code::Unavailable,
            crate::error::TalosError::Unknown(_) => tonic::Code::Internal,
        }
    }
}
502
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backoff strategies: delay values for representative attempts. ---

    #[test]
    fn test_no_backoff() {
        let backoff = NoBackoff::new();
        assert_eq!(backoff.delay(0), Duration::ZERO);
        assert_eq!(backoff.delay(5), Duration::ZERO);
        assert_eq!(backoff.delay(100), Duration::ZERO);
    }

    #[test]
    fn test_fixed_backoff() {
        let backoff = FixedBackoff::from_millis(100);
        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(5), Duration::from_millis(100));
        assert_eq!(backoff.delay(100), Duration::from_millis(100));
    }

    #[test]
    fn test_linear_backoff() {
        let backoff = LinearBackoff::new(Duration::from_millis(100))
            .with_increment(Duration::from_millis(50))
            .with_max_delay(Duration::from_millis(500));

        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(1), Duration::from_millis(150));
        assert_eq!(backoff.delay(2), Duration::from_millis(200));
        // 100 + 50*10 = 600 ms, clamped to the 500 ms cap.
        assert_eq!(backoff.delay(10), Duration::from_millis(500)); }

    #[test]
    fn test_exponential_backoff() {
        // Jitter disabled so the doubling sequence is exact.
        let backoff = ExponentialBackoff::new(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(1), Duration::from_millis(200));
        assert_eq!(backoff.delay(2), Duration::from_millis(400));
        assert_eq!(backoff.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn test_exponential_backoff_cap() {
        let backoff = ExponentialBackoff::new(Duration::from_millis(100))
            .with_max_delay(Duration::from_millis(500))
            .with_jitter(false);

        // 100 * 2^5 = 3200 ms, clamped to the 500 ms cap.
        assert_eq!(backoff.delay(5), Duration::from_millis(500)); }

    // --- Retry policies: retryable vs permanent code classification. ---

    #[test]
    fn test_default_retry_policy() {
        let policy = DefaultRetryPolicy;

        assert!(policy.should_retry(tonic::Code::Unavailable));
        assert!(policy.should_retry(tonic::Code::DeadlineExceeded));
        assert!(policy.should_retry(tonic::Code::ResourceExhausted));
        assert!(policy.should_retry(tonic::Code::Aborted));

        assert!(!policy.should_retry(tonic::Code::InvalidArgument));
        assert!(!policy.should_retry(tonic::Code::NotFound));
        assert!(!policy.should_retry(tonic::Code::PermissionDenied));
        assert!(!policy.should_retry(tonic::Code::AlreadyExists));
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = NoRetryPolicy;

        assert!(!policy.should_retry(tonic::Code::Unavailable));
        assert!(!policy.should_retry(tonic::Code::Unknown));
    }

    #[test]
    fn test_custom_retry_policy() {
        let policy = CustomRetryPolicy::network_errors();

        assert!(policy.should_retry(tonic::Code::Unavailable));
        assert!(policy.should_retry(tonic::Code::Unknown));
        assert!(!policy.should_retry(tonic::Code::DeadlineExceeded));
    }

    // --- RetryConfig construction and execution. ---

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .backoff(FixedBackoff::from_millis(200))
            .total_timeout(Duration::from_secs(60))
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.total_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_retry_config_disabled() {
        let config = RetryConfig::disabled();

        assert_eq!(config.max_retries, 0);
        assert_eq!(config.total_timeout, None);
    }

    #[tokio::test]
    async fn test_retry_execute_success() {
        let config = RetryConfig::default();

        let result: Result<i32, tonic::Status> = config.execute(|| async { Ok(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_execute_transient_failure() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .backoff(NoBackoff::new())
            .build();

        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        // Fails twice with a retryable code, then succeeds: 3 calls total.
        let result: Result<i32, tonic::Status> = config
            .execute(|| {
                let count = call_count_clone.clone();
                async move {
                    let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    if n < 2 {
                        Err(tonic::Status::unavailable("transient"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_execute_permanent_failure() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .backoff(NoBackoff::new())
            .build();

        // InvalidArgument is not retryable: the error returns immediately.
        let result: Result<i32, tonic::Status> = config
            .execute(|| async { Err(tonic::Status::invalid_argument("bad input")) })
            .await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::InvalidArgument);
    }
}