simple_agents_router/
retry.rs1use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
6use std::future::Future;
7use std::time::Duration;
8
/// Settings that control how a failed operation is retried.
#[derive(Clone, Copy, Debug)]
pub struct RetryPolicy {
    /// Upper bound on how many times the operation is invoked, counting the
    /// first call.
    pub max_attempts: u32,
    /// Delay used after the first failed attempt; later delays grow from it.
    pub initial_backoff: Duration,
    /// Ceiling applied to every computed delay.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by for each additional failed attempt.
    pub backoff_multiplier: f32,
    /// When `true`, each delay is scaled by a random factor so retries do not
    /// all wait for exactly the same time.
    pub jitter: bool,
}
23
24impl Default for RetryPolicy {
25 fn default() -> Self {
26 Self {
27 max_attempts: 3,
28 initial_backoff: Duration::from_millis(100),
29 max_backoff: Duration::from_secs(10),
30 backoff_multiplier: 2.0,
31 jitter: true,
32 }
33 }
34}
35
36impl RetryPolicy {
37 fn backoff(&self, attempt: u32) -> Duration {
38 let base =
39 self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
40 let capped = base.min(self.max_backoff.as_millis() as f32);
41
42 let duration_ms = if self.jitter {
43 let jitter_factor = 0.5 + (random_f32() * 0.5);
44 capped * jitter_factor
45 } else {
46 capped
47 };
48
49 Duration::from_millis(duration_ms as u64).min(self.max_backoff)
50 }
51}
52
53pub async fn execute_with_retry<F, Fut, T>(
57 policy: RetryPolicy,
58 operation: F,
59) -> Result<T, SimpleAgentsError>
60where
61 F: Fn() -> Fut,
62 Fut: Future<Output = Result<T, SimpleAgentsError>>,
63{
64 let mut last_error: Option<SimpleAgentsError> = None;
65
66 for attempt in 0..policy.max_attempts {
67 match operation().await {
68 Ok(result) => return Ok(result),
69 Err(error) => {
70 if !is_retryable(&error) {
71 return Err(error);
72 }
73
74 if attempt >= policy.max_attempts - 1 {
75 last_error = Some(error);
76 break;
77 }
78
79 tokio::time::sleep(policy.backoff(attempt)).await;
80 last_error = Some(error);
81 }
82 }
83 }
84
85 Err(last_error.unwrap())
86}
87
88fn is_retryable(error: &SimpleAgentsError) -> bool {
89 matches!(
90 error,
91 SimpleAgentsError::Provider(
92 ProviderError::RateLimit { .. }
93 | ProviderError::Timeout(_)
94 | ProviderError::ServerError(_)
95 ) | SimpleAgentsError::Network(_)
96 )
97}
98
99fn random_f32() -> f32 {
100 use rand::Rng;
101 rand::thread_rng().gen()
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a jitter-free policy with millisecond-scale delays so the tests
    /// stay fast and deterministic.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    /// An operation that succeeds straight away yields its value.
    #[tokio::test]
    async fn succeeds_without_retry() {
        let outcome =
            execute_with_retry(fast_policy(3), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;

        assert_eq!(outcome.unwrap(), "ok");
    }

    /// A retryable failure on the first call is followed by a second call.
    #[tokio::test]
    async fn retries_on_retryable_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let outcome = execute_with_retry(fast_policy(2), move || {
            let counter = Arc::clone(&counter);
            async move {
                // Only the very first invocation fails.
                match counter.fetch_add(1, Ordering::Relaxed) {
                    0 => Err(SimpleAgentsError::Provider(ProviderError::Timeout(
                        Duration::from_secs(1),
                    ))),
                    _ => Ok("ok"),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), "ok");
        assert_eq!(calls.load(Ordering::Relaxed), 2);
    }

    /// A non-retryable failure is returned after a single call.
    #[tokio::test]
    async fn fails_on_non_retryable_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        let outcome = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<&str, _>(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
            }
        })
        .await;

        assert!(matches!(
            outcome,
            Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
        ));
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }
}