// talos_api_rs/runtime/retry.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Retry policies and backoff strategies for resilient API calls.
4//!
5//! This module provides configurable retry behavior for handling transient failures
6//! in gRPC communication with Talos nodes.
7//!
8//! # Example
9//!
10//! ```
11//! use talos_api_rs::runtime::{RetryConfig, ExponentialBackoff};
12//! use std::time::Duration;
13//!
14//! let retry = RetryConfig::builder()
15//!     .max_retries(3)
16//!     .backoff(ExponentialBackoff::new(Duration::from_millis(100)))
17//!     .build();
18//! ```
19
20use std::time::Duration;
21
/// Defines a backoff strategy for retry delays.
///
/// Implementors must be `Clone + Send + Sync + 'static` so a strategy can be
/// embedded in a retry configuration that is cloned and shared across tasks.
pub trait BackoffStrategy: Clone + Send + Sync + 'static {
    /// Calculate the delay before the next retry attempt.
    ///
    /// # Arguments
    /// * `attempt` - The current attempt number (0-indexed)
    fn delay(&self, attempt: u32) -> Duration;
}
30
31// =============================================================================
32// No Backoff
33// =============================================================================
34
/// Backoff strategy that waits zero time between attempts.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoBackoff;

impl NoBackoff {
    /// Construct the zero-delay strategy.
    #[must_use]
    pub fn new() -> Self {
        NoBackoff
    }
}
46
47impl BackoffStrategy for NoBackoff {
48    fn delay(&self, _attempt: u32) -> Duration {
49        Duration::ZERO
50    }
51}
52
53// =============================================================================
54// Fixed Backoff
55// =============================================================================
56
/// Backoff strategy with a constant delay between attempts.
#[derive(Debug, Clone, Copy)]
pub struct FixedBackoff {
    delay: Duration,
}

impl FixedBackoff {
    /// Construct a strategy that always waits `delay`.
    #[must_use]
    pub fn new(delay: Duration) -> Self {
        FixedBackoff { delay }
    }

    /// Convenience constructor taking the delay in milliseconds.
    #[must_use]
    pub fn from_millis(millis: u64) -> Self {
        FixedBackoff {
            delay: Duration::from_millis(millis),
        }
    }

    /// Convenience constructor taking the delay in seconds.
    #[must_use]
    pub fn from_secs(secs: u64) -> Self {
        FixedBackoff {
            delay: Duration::from_secs(secs),
        }
    }
}

impl Default for FixedBackoff {
    /// Defaults to a 100 ms delay.
    fn default() -> Self {
        FixedBackoff::from_millis(100)
    }
}
88
89impl BackoffStrategy for FixedBackoff {
90    fn delay(&self, _attempt: u32) -> Duration {
91        self.delay
92    }
93}
94
95// =============================================================================
96// Linear Backoff
97// =============================================================================
98
/// Backoff whose delay grows by a fixed increment on every attempt.
#[derive(Debug, Clone, Copy)]
pub struct LinearBackoff {
    initial_delay: Duration,
    increment: Duration,
    max_delay: Duration,
}

impl LinearBackoff {
    /// Construct a linear backoff.
    ///
    /// The increment defaults to `initial_delay` and the cap to 30 seconds.
    #[must_use]
    pub fn new(initial_delay: Duration) -> Self {
        LinearBackoff {
            initial_delay,
            increment: initial_delay,
            max_delay: Duration::from_secs(30),
        }
    }

    /// Override the per-attempt increment.
    #[must_use]
    pub fn with_increment(self, increment: Duration) -> Self {
        LinearBackoff { increment, ..self }
    }

    /// Override the maximum delay cap.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        LinearBackoff { max_delay, ..self }
    }
}

impl Default for LinearBackoff {
    /// Defaults to a 100 ms initial delay (and thus a 100 ms increment).
    fn default() -> Self {
        LinearBackoff::new(Duration::from_millis(100))
    }
}
138
139impl BackoffStrategy for LinearBackoff {
140    fn delay(&self, attempt: u32) -> Duration {
141        let delay = self.initial_delay + self.increment * attempt;
142        delay.min(self.max_delay)
143    }
144}
145
146// =============================================================================
147// Exponential Backoff
148// =============================================================================
149
/// Backoff whose delay grows geometrically with each attempt.
///
/// Jitter (enabled by default) spreads out retries from multiple callers
/// to prevent a thundering herd.
#[derive(Debug, Clone, Copy)]
pub struct ExponentialBackoff {
    initial_delay: Duration,
    max_delay: Duration,
    multiplier: f64,
    jitter: bool,
}

impl ExponentialBackoff {
    /// Construct an exponential backoff.
    ///
    /// Defaults: 30 s cap, 2.0 multiplier, jitter enabled.
    #[must_use]
    pub fn new(initial_delay: Duration) -> Self {
        ExponentialBackoff {
            initial_delay,
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            jitter: true,
        }
    }

    /// Override the maximum delay cap.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        ExponentialBackoff { max_delay, ..self }
    }

    /// Override the growth multiplier.
    #[must_use]
    pub fn with_multiplier(self, multiplier: f64) -> Self {
        ExponentialBackoff { multiplier, ..self }
    }

    /// Turn jitter on or off.
    #[must_use]
    pub fn with_jitter(self, jitter: bool) -> Self {
        ExponentialBackoff { jitter, ..self }
    }
}

impl Default for ExponentialBackoff {
    /// Defaults to a 100 ms initial delay.
    fn default() -> Self {
        ExponentialBackoff::new(Duration::from_millis(100))
    }
}
200
201impl BackoffStrategy for ExponentialBackoff {
202    fn delay(&self, attempt: u32) -> Duration {
203        let base_delay =
204            self.initial_delay.as_millis() as f64 * self.multiplier.powi(attempt as i32);
205        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
206
207        let final_delay = if self.jitter {
208            // Add up to 25% jitter
209            let jitter_range = capped_delay * 0.25;
210            // Simple deterministic jitter based on attempt number
211            let jitter = (attempt as f64 * 0.1).sin().abs() * jitter_range;
212            capped_delay + jitter
213        } else {
214            capped_delay
215        };
216
217        Duration::from_millis(final_delay as u64)
218    }
219}
220
221// =============================================================================
222// Retry Policy
223// =============================================================================
224
/// Determines whether a gRPC error should be retried.
///
/// Implementors must be `Clone + Send + Sync + 'static` so a policy can be
/// embedded in a retry configuration that is cloned and shared across tasks.
pub trait RetryPolicy: Clone + Send + Sync + 'static {
    /// Returns `true` if the operation should be retried for this error.
    fn should_retry(&self, code: tonic::Code) -> bool;
}
230
231/// Default retry policy - retries on transient errors.
232#[derive(Debug, Clone, Copy, Default)]
233pub struct DefaultRetryPolicy;
234
235impl RetryPolicy for DefaultRetryPolicy {
236    fn should_retry(&self, code: tonic::Code) -> bool {
237        matches!(
238            code,
239            tonic::Code::Unavailable
240                | tonic::Code::Unknown
241                | tonic::Code::DeadlineExceeded
242                | tonic::Code::ResourceExhausted
243                | tonic::Code::Aborted
244        )
245    }
246}
247
248/// Never retry - fail immediately.
249#[derive(Debug, Clone, Copy, Default)]
250pub struct NoRetryPolicy;
251
252impl RetryPolicy for NoRetryPolicy {
253    fn should_retry(&self, _code: tonic::Code) -> bool {
254        false
255    }
256}
257
258/// Custom retry policy based on a list of codes.
259#[derive(Debug, Clone)]
260pub struct CustomRetryPolicy {
261    retry_codes: Vec<tonic::Code>,
262}
263
264impl CustomRetryPolicy {
265    /// Create a policy that retries on specific codes.
266    #[must_use]
267    pub fn new(retry_codes: Vec<tonic::Code>) -> Self {
268        Self { retry_codes }
269    }
270
271    /// Create a policy for network-level errors only.
272    #[must_use]
273    pub fn network_errors() -> Self {
274        Self::new(vec![tonic::Code::Unavailable, tonic::Code::Unknown])
275    }
276}
277
278impl RetryPolicy for CustomRetryPolicy {
279    fn should_retry(&self, code: tonic::Code) -> bool {
280        self.retry_codes.contains(&code)
281    }
282}
283
284// =============================================================================
285// Retry Configuration
286// =============================================================================
287
/// Complete retry configuration combining policy and backoff.
///
/// Used via [`RetryConfig::execute`] to wrap an async operation with
/// retry-on-transient-error semantics.
#[derive(Debug, Clone)]
pub struct RetryConfig<P: RetryPolicy = DefaultRetryPolicy, B: BackoffStrategy = ExponentialBackoff>
{
    /// Maximum number of retry attempts (the initial call is not counted).
    pub max_retries: u32,
    /// Policy determining which errors to retry.
    pub policy: P,
    /// Backoff strategy for calculating delays.
    pub backoff: B,
    /// Maximum total time for all retries; `None` disables the bound.
    pub total_timeout: Option<Duration>,
}
301
302impl Default for RetryConfig {
303    fn default() -> Self {
304        Self {
305            max_retries: 3,
306            policy: DefaultRetryPolicy,
307            backoff: ExponentialBackoff::default(),
308            total_timeout: Some(Duration::from_secs(30)),
309        }
310    }
311}
312
313impl RetryConfig {
314    /// Create a new retry configuration with defaults.
315    #[must_use]
316    pub fn new() -> Self {
317        Self::default()
318    }
319
320    /// Create a configuration builder.
321    #[must_use]
322    pub fn builder() -> RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
323        RetryConfigBuilder::new()
324    }
325
326    /// Disable retries.
327    #[must_use]
328    pub fn disabled() -> RetryConfig<NoRetryPolicy, NoBackoff> {
329        RetryConfig {
330            max_retries: 0,
331            policy: NoRetryPolicy,
332            backoff: NoBackoff,
333            total_timeout: None,
334        }
335    }
336}
337
338impl<P: RetryPolicy, B: BackoffStrategy> RetryConfig<P, B> {
339    /// Execute an async operation with retry logic.
340    pub async fn execute<T, E, F, Fut>(&self, mut operation: F) -> Result<T, E>
341    where
342        F: FnMut() -> Fut,
343        Fut: std::future::Future<Output = Result<T, E>>,
344        E: AsGrpcStatus,
345    {
346        let start = std::time::Instant::now();
347        let mut attempt = 0;
348
349        loop {
350            match operation().await {
351                Ok(result) => return Ok(result),
352                Err(e) => {
353                    let code = e.grpc_code();
354
355                    // Check if we should retry
356                    if !self.policy.should_retry(code) {
357                        return Err(e);
358                    }
359
360                    // Check if we've exceeded max retries
361                    if attempt >= self.max_retries {
362                        return Err(e);
363                    }
364
365                    // Check total timeout
366                    if let Some(timeout) = self.total_timeout {
367                        if start.elapsed() >= timeout {
368                            return Err(e);
369                        }
370                    }
371
372                    // Calculate delay and sleep
373                    let delay = self.backoff.delay(attempt);
374                    tokio::time::sleep(delay).await;
375
376                    attempt += 1;
377                }
378            }
379        }
380    }
381}
382
/// Builder for `RetryConfig`.
///
/// Created via [`RetryConfig::builder`] or [`RetryConfigBuilder::new`] and
/// finished with [`RetryConfigBuilder::build`]. The `policy`/`backoff`
/// setters may change the builder's type parameters.
#[derive(Debug, Clone)]
pub struct RetryConfigBuilder<P: RetryPolicy, B: BackoffStrategy> {
    max_retries: u32,
    policy: P,
    backoff: B,
    total_timeout: Option<Duration>,
}
391
392impl RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
393    /// Create a new builder with defaults.
394    #[must_use]
395    pub fn new() -> Self {
396        Self {
397            max_retries: 3,
398            policy: DefaultRetryPolicy,
399            backoff: ExponentialBackoff::default(),
400            total_timeout: Some(Duration::from_secs(30)),
401        }
402    }
403}
404
405impl Default for RetryConfigBuilder<DefaultRetryPolicy, ExponentialBackoff> {
406    fn default() -> Self {
407        Self::new()
408    }
409}
410
411impl<P: RetryPolicy, B: BackoffStrategy> RetryConfigBuilder<P, B> {
412    /// Set maximum retry attempts.
413    #[must_use]
414    pub fn max_retries(mut self, max: u32) -> Self {
415        self.max_retries = max;
416        self
417    }
418
419    /// Set the retry policy.
420    #[must_use]
421    pub fn policy<P2: RetryPolicy>(self, policy: P2) -> RetryConfigBuilder<P2, B> {
422        RetryConfigBuilder {
423            max_retries: self.max_retries,
424            policy,
425            backoff: self.backoff,
426            total_timeout: self.total_timeout,
427        }
428    }
429
430    /// Set the backoff strategy.
431    #[must_use]
432    pub fn backoff<B2: BackoffStrategy>(self, backoff: B2) -> RetryConfigBuilder<P, B2> {
433        RetryConfigBuilder {
434            max_retries: self.max_retries,
435            policy: self.policy,
436            backoff,
437            total_timeout: self.total_timeout,
438        }
439    }
440
441    /// Set the total timeout for all retries.
442    #[must_use]
443    pub fn total_timeout(mut self, timeout: Duration) -> Self {
444        self.total_timeout = Some(timeout);
445        self
446    }
447
448    /// Disable total timeout.
449    #[must_use]
450    pub fn no_total_timeout(mut self) -> Self {
451        self.total_timeout = None;
452        self
453    }
454
455    /// Build the configuration.
456    #[must_use]
457    pub fn build(self) -> RetryConfig<P, B> {
458        RetryConfig {
459            max_retries: self.max_retries,
460            policy: self.policy,
461            backoff: self.backoff,
462            total_timeout: self.total_timeout,
463        }
464    }
465}
466
/// Trait for extracting gRPC status codes from errors.
///
/// This is what lets [`RetryConfig::execute`] inspect arbitrary error types
/// through a uniform interface when deciding whether to retry.
pub trait AsGrpcStatus {
    /// Extract the gRPC status code.
    fn grpc_code(&self) -> tonic::Code;
}
472
473impl AsGrpcStatus for tonic::Status {
474    fn grpc_code(&self) -> tonic::Code {
475        self.code()
476    }
477}
478
479impl<T> AsGrpcStatus for Result<T, tonic::Status> {
480    fn grpc_code(&self) -> tonic::Code {
481        match self {
482            Ok(_) => tonic::Code::Ok,
483            Err(e) => e.code(),
484        }
485    }
486}
487
488// Implement for our error type
489impl AsGrpcStatus for crate::error::TalosError {
490    fn grpc_code(&self) -> tonic::Code {
491        match self {
492            crate::error::TalosError::Api(status) => status.code(),
493            crate::error::TalosError::Transport(_) => tonic::Code::Unavailable,
494            crate::error::TalosError::Config(_) => tonic::Code::InvalidArgument,
495            crate::error::TalosError::Validation(_) => tonic::Code::InvalidArgument,
496            crate::error::TalosError::Connection(_) => tonic::Code::Unavailable,
497            crate::error::TalosError::CircuitOpen(_) => tonic::Code::Unavailable,
498            crate::error::TalosError::Unknown(_) => tonic::Code::Internal,
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    // --- Backoff strategies ---

    #[test]
    fn test_no_backoff() {
        let backoff = NoBackoff::new();
        assert_eq!(backoff.delay(0), Duration::ZERO);
        assert_eq!(backoff.delay(5), Duration::ZERO);
        assert_eq!(backoff.delay(100), Duration::ZERO);
    }

    #[test]
    fn test_fixed_backoff() {
        let backoff = FixedBackoff::from_millis(100);
        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(5), Duration::from_millis(100));
        assert_eq!(backoff.delay(100), Duration::from_millis(100));
    }

    #[test]
    fn test_linear_backoff() {
        let backoff = LinearBackoff::new(Duration::from_millis(100))
            .with_increment(Duration::from_millis(50))
            .with_max_delay(Duration::from_millis(500));

        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(1), Duration::from_millis(150));
        assert_eq!(backoff.delay(2), Duration::from_millis(200));
        assert_eq!(backoff.delay(10), Duration::from_millis(500)); // Capped
    }

    #[test]
    fn test_exponential_backoff() {
        let backoff = ExponentialBackoff::new(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(backoff.delay(0), Duration::from_millis(100));
        assert_eq!(backoff.delay(1), Duration::from_millis(200));
        assert_eq!(backoff.delay(2), Duration::from_millis(400));
        assert_eq!(backoff.delay(3), Duration::from_millis(800));
    }

    #[test]
    fn test_exponential_backoff_cap() {
        let backoff = ExponentialBackoff::new(Duration::from_millis(100))
            .with_max_delay(Duration::from_millis(500))
            .with_jitter(false);

        assert_eq!(backoff.delay(5), Duration::from_millis(500)); // Capped at 500ms
    }

    // --- Retry policies ---

    #[test]
    fn test_default_retry_policy() {
        let policy = DefaultRetryPolicy;

        assert!(policy.should_retry(tonic::Code::Unavailable));
        assert!(policy.should_retry(tonic::Code::DeadlineExceeded));
        assert!(policy.should_retry(tonic::Code::ResourceExhausted));
        assert!(policy.should_retry(tonic::Code::Aborted));

        assert!(!policy.should_retry(tonic::Code::InvalidArgument));
        assert!(!policy.should_retry(tonic::Code::NotFound));
        assert!(!policy.should_retry(tonic::Code::PermissionDenied));
        assert!(!policy.should_retry(tonic::Code::AlreadyExists));
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = NoRetryPolicy;

        assert!(!policy.should_retry(tonic::Code::Unavailable));
        assert!(!policy.should_retry(tonic::Code::Unknown));
    }

    #[test]
    fn test_custom_retry_policy() {
        let policy = CustomRetryPolicy::network_errors();

        assert!(policy.should_retry(tonic::Code::Unavailable));
        assert!(policy.should_retry(tonic::Code::Unknown));
        assert!(!policy.should_retry(tonic::Code::DeadlineExceeded));
    }

    // --- Configuration ---

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::builder()
            .max_retries(5)
            .backoff(FixedBackoff::from_millis(200))
            .total_timeout(Duration::from_secs(60))
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.total_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_retry_config_disabled() {
        let config = RetryConfig::disabled();

        assert_eq!(config.max_retries, 0);
        assert_eq!(config.total_timeout, None);
    }

    // --- End-to-end execute() behavior ---

    #[tokio::test]
    async fn test_retry_execute_success() {
        let config = RetryConfig::default();

        let result: Result<i32, tonic::Status> = config.execute(|| async { Ok(42) }).await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_execute_transient_failure() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .backoff(NoBackoff::new())
            .build();

        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let result: Result<i32, tonic::Status> = config
            .execute(|| {
                let count = call_count_clone.clone();
                async move {
                    let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    if n < 2 {
                        Err(tonic::Status::unavailable("transient"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_execute_permanent_failure() {
        let config = RetryConfig::builder()
            .max_retries(3)
            .backoff(NoBackoff::new())
            .build();

        let result: Result<i32, tonic::Status> = config
            .execute(|| async { Err(tonic::Status::invalid_argument("bad input")) })
            .await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), tonic::Code::InvalidArgument);
    }
}
659}