1use crate::error::{Result, YfCommonError};
4use rand::Rng;
5use std::future::Future;
6use std::time::Duration;
7use tokio::time::sleep;
8use tracing::warn;
9
/// Settings that control how a failed operation is retried.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Base delay, in milliseconds, fed into the backoff calculation.
    pub base_delay_ms: u64,
    /// Inclusive upper bound, in milliseconds, of the random jitter added to each delay.
    /// A value of `0` disables jitter.
    pub max_jitter_ms: u64,
    /// Whether rate-limit errors are considered retryable.
    pub retry_on_rate_limit: bool,
    /// Whether server errors are considered retryable.
    pub retry_on_server_error: bool,
}

impl Default for RetryConfig {
    /// Three retries, a 1000 ms base delay, up to 500 ms of jitter, and retrying
    /// enabled for both rate-limit and server errors.
    fn default() -> Self {
        RetryConfig {
            max_retries: 3,
            base_delay_ms: 1_000,
            max_jitter_ms: 500,
            retry_on_rate_limit: true,
            retry_on_server_error: true,
        }
    }
}

impl RetryConfig {
    /// Creates a configuration with the given retry count and default values
    /// for every other field.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Self::default()
        }
    }

    /// Creates a configuration that never retries (`max_retries == 0`).
    pub fn no_retry() -> Self {
        Self::new(0)
    }

    /// Returns the configuration with its base delay replaced by `ms` milliseconds.
    pub fn with_base_delay(self, ms: u64) -> Self {
        Self {
            base_delay_ms: ms,
            ..self
        }
    }

    /// Returns the configuration with its maximum jitter replaced by `ms` milliseconds.
    pub fn with_max_jitter(self, ms: u64) -> Self {
        Self {
            max_jitter_ms: ms,
            ..self
        }
    }
}
51
/// How the delay between attempts grows with the attempt number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackoffStrategy {
    /// The base delay doubles with every attempt. This is the default strategy.
    #[default]
    Exponential,
    /// The delay grows by one base delay per attempt.
    Linear,
    /// The delay is always the base delay.
    Constant,
}
60
61pub fn calculate_backoff(attempt: u32, config: &RetryConfig, strategy: BackoffStrategy) -> Duration {
63 let base = config.base_delay_ms;
64 let delay_ms = match strategy {
65 BackoffStrategy::Exponential => base * 2u64.pow(attempt),
66 BackoffStrategy::Linear => base * (attempt as u64 + 1),
67 BackoffStrategy::Constant => base,
68 };
69 let jitter = if config.max_jitter_ms > 0 {
70 rand::thread_rng().gen_range(0..=config.max_jitter_ms)
71 } else {
72 0
73 };
74 Duration::from_millis(delay_ms + jitter)
75}
76
77pub fn is_retryable(error: &YfCommonError, config: &RetryConfig) -> bool {
79 match error {
80 YfCommonError::RateLimitExceeded(_) => config.retry_on_rate_limit,
81 YfCommonError::ServerError(_, _) => config.retry_on_server_error,
82 YfCommonError::TimeoutError(_) => true,
83 YfCommonError::RequestError(e) => e.is_timeout() || e.is_connect(),
84 _ => false,
85 }
86}
87
88pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
90where
91 F: Fn() -> Fut,
92 Fut: Future<Output = Result<T>>,
93{
94 let mut last_error = None;
95
96 for attempt in 0..=config.max_retries {
97 match operation().await {
98 Ok(result) => return Ok(result),
99 Err(e) => {
100 if attempt == config.max_retries || !is_retryable(&e, config) {
101 return Err(e);
102 }
103 let delay = calculate_backoff(attempt, config, BackoffStrategy::Exponential);
104 warn!("Attempt {} failed: {}. Retrying in {:?}", attempt + 1, e, delay);
105 sleep(delay).await;
106 last_error = Some(e);
107 }
108 }
109 }
110
111 Err(last_error.unwrap_or(YfCommonError::MaxRetriesExceeded(config.max_retries)))
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with jitter disabled so delays are deterministic.
    fn jitterless_config() -> RetryConfig {
        RetryConfig::default().with_max_jitter(0)
    }

    /// Checks the delays produced for attempts 0, 1 and 2 under `strategy`.
    fn assert_delays(strategy: BackoffStrategy, expected_ms: [u64; 3]) {
        let config = jitterless_config();
        for (attempt, ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(
                calculate_backoff(attempt, &config, strategy),
                Duration::from_millis(*ms)
            );
        }
    }

    #[test]
    fn test_default_config() {
        let RetryConfig {
            max_retries,
            base_delay_ms,
            ..
        } = RetryConfig::default();
        assert_eq!(max_retries, 3);
        assert_eq!(base_delay_ms, 1000);
    }

    #[test]
    fn test_exponential_backoff() {
        assert_delays(BackoffStrategy::Exponential, [1000, 2000, 4000]);
    }

    #[test]
    fn test_linear_backoff() {
        assert_delays(BackoffStrategy::Linear, [1000, 2000, 3000]);
    }

    #[test]
    fn test_constant_backoff() {
        assert_delays(BackoffStrategy::Constant, [1000, 1000, 1000]);
    }

    #[tokio::test]
    async fn test_retry_immediate_success() {
        let config = RetryConfig::default();
        let outcome: Result<i32> = retry_with_backoff(|| async { Ok(42) }, &config).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[test]
    fn test_is_retryable_timeout() {
        let timeout = YfCommonError::TimeoutError("timeout".to_string());
        assert!(is_retryable(&timeout, &RetryConfig::default()));
    }

    #[test]
    fn test_is_retryable_server_error() {
        let server_error = YfCommonError::ServerError(500, "error".to_string());
        assert!(is_retryable(&server_error, &RetryConfig::default()));
    }
}