1use crate::error::{Result, TruthlinkedError};
2use std::time::Duration;
3use tokio::time::sleep;
4
/// Tuning knobs for [`RetryExecutor`]'s exponential-backoff retry loop.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of times the operation may be attempted, counting the first call.
    pub max_attempts: u32,
    /// Delay used before the first retry; later delays grow from this value.
    pub initial_delay: Duration,
    /// Upper bound on the backoff delay, applied before jitter is added.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each additional failed attempt.
    pub backoff_multiplier: f64,
    /// Fraction of the capped delay used as a symmetric random jitter range;
    /// `0.0` (or any non-positive value) disables jitter.
    pub jitter_factor: f64,
}
19
20impl Default for RetryConfig {
21 fn default() -> Self {
22 Self {
23 max_attempts: 3,
24 initial_delay: Duration::from_secs(1),
25 max_delay: Duration::from_secs(30),
26 backoff_multiplier: 2.0,
27 jitter_factor: 0.1,
28 }
29 }
30}
31
32impl RetryConfig {
33 pub fn production() -> Self {
35 Self {
36 max_attempts: 3,
37 initial_delay: Duration::from_millis(500),
38 max_delay: Duration::from_secs(10),
39 backoff_multiplier: 2.0,
40 jitter_factor: 0.1,
41 }
42 }
43
44 pub fn aggressive() -> Self {
46 Self {
47 max_attempts: 5,
48 initial_delay: Duration::from_millis(100),
49 max_delay: Duration::from_secs(5),
50 backoff_multiplier: 1.5,
51 jitter_factor: 0.2,
52 }
53 }
54
55 pub fn none() -> Self {
57 Self {
58 max_attempts: 1,
59 initial_delay: Duration::from_secs(0),
60 max_delay: Duration::from_secs(0),
61 backoff_multiplier: 1.0,
62 jitter_factor: 0.0,
63 }
64 }
65}
66
67pub struct RetryExecutor {
69 config: RetryConfig,
70}
71
72impl RetryExecutor {
73 pub fn new(config: RetryConfig) -> Self {
74 Self { config }
75 }
76
77 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
79 where
80 F: FnMut() -> Fut,
81 Fut: std::future::Future<Output = Result<T>>,
82 {
83 let mut last_error = None;
84
85 for attempt in 0..self.config.max_attempts {
86 match operation().await {
87 Ok(result) => return Ok(result),
88 Err(e) => {
89 if !self.should_retry(&e) {
91 return Err(e);
92 }
93
94 last_error = Some(e);
95
96 if attempt + 1 < self.config.max_attempts {
98 let delay = self.calculate_delay(attempt);
99 sleep(delay).await;
100 }
101 }
102 }
103 }
104
105 Err(last_error.unwrap_or(TruthlinkedError::Network("Max retries exceeded".to_string())))
106 }
107
108 fn should_retry(&self, error: &TruthlinkedError) -> bool {
110 match error {
111 TruthlinkedError::Network(_) => true,
113 TruthlinkedError::ServerError => true,
115 TruthlinkedError::Unauthorized => false,
117 TruthlinkedError::Forbidden => false,
118 TruthlinkedError::InvalidRequest(_) => false,
120 TruthlinkedError::RateLimitExceeded(_) => false,
122 _ => false,
124 }
125 }
126
127 fn calculate_delay(&self, attempt: u32) -> Duration {
129 let base_delay = self.config.initial_delay.as_millis() as f64;
130 let exponential_delay = base_delay * self.config.backoff_multiplier.powi(attempt as i32);
131 let capped_delay = exponential_delay.min(self.config.max_delay.as_millis() as f64);
132
133 let jitter = if self.config.jitter_factor > 0.0 {
135 use rand::Rng;
136 let jitter_amount = capped_delay * self.config.jitter_factor;
137
138 rand::thread_rng().gen_range(-jitter_amount..=jitter_amount)
139 } else {
140 0.0
141 };
142
143 let final_delay = (capped_delay + jitter).max(0.0) as u64;
144 Duration::from_millis(final_delay)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Config with zero delays and no jitter so retry-path tests run instantly.
    fn instant_config(max_attempts: u32) -> RetryConfig {
        RetryConfig {
            max_attempts,
            initial_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        }
    }

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };

        let executor = RetryExecutor::new(config);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str> = executor
            .execute(|| {
                let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    if count == 0 {
                        Err(TruthlinkedError::Network("Connection failed".to_string()))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_no_retry_on_auth_error() {
        // A multi-attempt budget is required for this test to mean anything:
        // `RetryConfig::none()` only ever makes one attempt, so it would pass
        // even if auth errors were (wrongly) retried.
        let executor = RetryExecutor::new(instant_config(3));
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str> = executor
            .execute(|| {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                async move { Err(TruthlinkedError::Unauthorized) }
            })
            .await;

        assert!(matches!(result, Err(TruthlinkedError::Unauthorized)));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_gives_up_after_max_attempts_with_last_error() {
        let executor = RetryExecutor::new(instant_config(3));
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str> = executor
            .execute(|| {
                let n = attempt_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
                async move { Err(TruthlinkedError::Network(format!("attempt {}", n))) }
            })
            .await;

        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
        // The error from the final attempt, not an earlier one, is surfaced.
        assert!(matches!(
            result,
            Err(TruthlinkedError::Network(ref msg)) if msg == "attempt 3"
        ));
    }

    #[tokio::test]
    async fn test_none_config_makes_single_attempt() {
        let executor = RetryExecutor::new(RetryConfig::none());
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str> = executor
            .execute(|| {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                async move { Err(TruthlinkedError::Network("down".to_string())) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }
}