truthlinked_sdk/
retry.rs

1use crate::error::{Result, TruthlinkedError};
2use std::time::Duration;
3use tokio::time::sleep;
4
/// Retry configuration with exponential backoff and random jitter.
///
/// The delay before retry `k` (zero-based) is `initial_delay * backoff_multiplier^k`,
/// capped at `max_delay`, then randomly perturbed by up to `±jitter_factor` of itself.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation is attempted, including the first call.
    /// `1` disables retries (see `RetryConfig::none`).
    pub max_attempts: u32,
    /// Delay before the first retry; later delays grow by `backoff_multiplier`.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay, applied before jitter is added.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each failed attempt (typically 2.0).
    pub backoff_multiplier: f64,
    /// Fraction of the delay used as a symmetric random jitter range, expected in
    /// `0.0..=1.0`; a value of `0.0` (or less) disables jitter.
    pub jitter_factor: f64,
}
19
20impl Default for RetryConfig {
21    fn default() -> Self {
22        Self {
23            max_attempts: 3,
24            initial_delay: Duration::from_secs(1),
25            max_delay: Duration::from_secs(30),
26            backoff_multiplier: 2.0,
27            jitter_factor: 0.1,
28        }
29    }
30}
31
32impl RetryConfig {
33    /// Create retry config for production use
34    pub fn production() -> Self {
35        Self {
36            max_attempts: 3,
37            initial_delay: Duration::from_millis(500),
38            max_delay: Duration::from_secs(10),
39            backoff_multiplier: 2.0,
40            jitter_factor: 0.1,
41        }
42    }
43    
44    /// Create retry config for aggressive retries
45    pub fn aggressive() -> Self {
46        Self {
47            max_attempts: 5,
48            initial_delay: Duration::from_millis(100),
49            max_delay: Duration::from_secs(5),
50            backoff_multiplier: 1.5,
51            jitter_factor: 0.2,
52        }
53    }
54    
55    /// Create retry config with no retries
56    pub fn none() -> Self {
57        Self {
58            max_attempts: 1,
59            initial_delay: Duration::from_secs(0),
60            max_delay: Duration::from_secs(0),
61            backoff_multiplier: 1.0,
62            jitter_factor: 0.0,
63        }
64    }
65}
66
67/// Retry executor with exponential backoff and jitter
68pub struct RetryExecutor {
69    config: RetryConfig,
70}
71
72impl RetryExecutor {
73    pub fn new(config: RetryConfig) -> Self {
74        Self { config }
75    }
76    
77    /// Execute operation with retries
78    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
79    where
80        F: FnMut() -> Fut,
81        Fut: std::future::Future<Output = Result<T>>,
82    {
83        let mut last_error = None;
84        
85        for attempt in 0..self.config.max_attempts {
86            match operation().await {
87                Ok(result) => return Ok(result),
88                Err(e) => {
89                    // Don't retry certain errors
90                    if !self.should_retry(&e) {
91                        return Err(e);
92                    }
93                    
94                    last_error = Some(e);
95                    
96                    // Don't sleep after the last attempt
97                    if attempt + 1 < self.config.max_attempts {
98                        let delay = self.calculate_delay(attempt);
99                        sleep(delay).await;
100                    }
101                }
102            }
103        }
104        
105        Err(last_error.unwrap_or(TruthlinkedError::Network("Max retries exceeded".to_string())))
106    }
107    
108    /// Determine if error should be retried
109    fn should_retry(&self, error: &TruthlinkedError) -> bool {
110        match error {
111            // Retry network errors
112            TruthlinkedError::Network(_) => true,
113            // Retry server errors
114            TruthlinkedError::ServerError => true,
115            // Don't retry auth errors
116            TruthlinkedError::Unauthorized => false,
117            TruthlinkedError::Forbidden => false,
118            // Don't retry client errors
119            TruthlinkedError::InvalidRequest(_) => false,
120            // Don't retry rate limits (handle separately)
121            TruthlinkedError::RateLimitExceeded(_) => false,
122            // Don't retry other errors
123            _ => false,
124        }
125    }
126    
127    /// Calculate delay with exponential backoff and jitter
128    fn calculate_delay(&self, attempt: u32) -> Duration {
129        let base_delay = self.config.initial_delay.as_millis() as f64;
130        let exponential_delay = base_delay * self.config.backoff_multiplier.powi(attempt as i32);
131        let capped_delay = exponential_delay.min(self.config.max_delay.as_millis() as f64);
132        
133        // Add jitter to prevent thundering herd
134        let jitter = if self.config.jitter_factor > 0.0 {
135            use rand::Rng;
136            let jitter_amount = capped_delay * self.config.jitter_factor;
137            
138            rand::thread_rng().gen_range(-jitter_amount..=jitter_amount)
139        } else {
140            0.0
141        };
142        
143        let final_delay = (capped_delay + jitter).max(0.0) as u64;
144        Duration::from_millis(final_delay)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Config that never sleeps between attempts, so multi-attempt tests run instantly.
    fn instant_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        }
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };

        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<&str> = executor
            .execute(|| {
                let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    if count == 0 {
                        Err(TruthlinkedError::Network("Connection failed".to_string()))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_auth_error() {
        // Allow several attempts: with a single-attempt config this test would pass
        // even if `Unauthorized` were (wrongly) treated as retryable.
        let executor = RetryExecutor::new(instant_config(3));
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<&str> = executor
            .execute(|| {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                async move { Err(TruthlinkedError::Unauthorized) }
            })
            .await;

        assert!(matches!(result, Err(TruthlinkedError::Unauthorized)));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_returns_last_error_after_exhausting_attempts() {
        let executor = RetryExecutor::new(instant_config(3));
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();

        let result: Result<()> = executor
            .execute(|| {
                let n = attempt_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TruthlinkedError::Network(format!("failure {}", n))) }
            })
            .await;

        assert!(matches!(
            result,
            Err(TruthlinkedError::Network(ref msg)) if msg == "failure 3"
        ));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_delay_grows_exponentially_and_is_capped() {
        let executor = RetryExecutor::new(RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(300),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        });

        assert_eq!(executor.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(executor.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(executor.calculate_delay(2), Duration::from_millis(300));
        assert_eq!(executor.calculate_delay(20), Duration::from_millis(300));
    }

    #[test]
    fn test_jitter_stays_within_factor() {
        let executor = RetryExecutor::new(RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter_factor: 0.5,
        });

        for _ in 0..100 {
            let delay = executor.calculate_delay(0);
            assert!(delay >= Duration::from_millis(50));
            assert!(delay <= Duration::from_millis(150));
        }
    }
}