1use std::time::Duration;
2use syn::LitStr;
3
/// Parameters controlling how often, and how long apart, a request is retried.
///
/// Delays are computed by [`RetryPolicy::calculate_delay`] as
/// `base_delay_ms * exponential_base^(attempt - 1)`, capped at `max_delay_ms`,
/// plus optional random jitter.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of attempts. Not read by the delay computation itself;
    /// whether it includes the initial request is up to the caller — TODO confirm.
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    pub base_delay_ms: u64,
    /// Cap (in milliseconds) applied to the backoff before jitter is added.
    pub max_delay_ms: u64,
    /// Multiplier applied per additional attempt; `1.0` yields a constant delay.
    pub exponential_base: f64,
    /// Fraction of the capped delay used as the upper bound of the random
    /// jitter; `0.0` (or less) disables jitter.
    pub jitter_ratio: f64,
    /// Presumably restricts retries to idempotent methods (see
    /// `RetryPolicy::is_idempotent_method`); enforcement is not visible here — confirm.
    pub idempotent_only: bool,
}
20
21impl Default for RetryPolicy {
22 fn default() -> Self {
23 Self {
24 max_attempts: 3,
25 base_delay_ms: 100,
26 max_delay_ms: 30000, exponential_base: 2.0,
28 jitter_ratio: 0.1,
29 idempotent_only: true,
30 }
31 }
32}
33
34impl RetryPolicy {
35 pub fn exponential(max_attempts: u32, base_delay_ms: u64) -> Self {
37 Self {
38 max_attempts,
39 base_delay_ms,
40 ..Default::default()
41 }
42 }
43
44 pub fn fixed(max_attempts: u32, delay_ms: u64) -> Self {
46 Self {
47 max_attempts,
48 base_delay_ms: delay_ms,
49 max_delay_ms: delay_ms,
50 exponential_base: 1.0,
51 jitter_ratio: 0.0,
52 idempotent_only: true,
53 }
54 }
55
56 pub fn calculate_delay(&self, attempt: u32) -> Duration {
58 if attempt == 0 {
59 return Duration::from_millis(0);
60 }
61
62 let exponential_delay = self.base_delay_ms as f64
64 * self.exponential_base.powi((attempt - 1) as i32);
65
66 let capped_delay = exponential_delay.min(self.max_delay_ms as f64);
68
69 let jitter = if self.jitter_ratio > 0.0 {
71 let max_jitter = capped_delay * self.jitter_ratio;
72 fastrand::f64() * max_jitter
73 } else {
74 0.0
75 };
76
77 Duration::from_millis((capped_delay + jitter) as u64)
78 }
79
80 pub fn should_retry_status(&self, status: u16) -> bool {
82 match status {
83 500..=599 => true,
85 429 => true,
87 408 => true,
89 _ => false,
91 }
92 }
93
94 pub fn is_idempotent_method(method: &crate::types::http::HttpMethod) -> bool {
96 use crate::types::http::HttpMethod;
97 match method {
98 HttpMethod::Get | HttpMethod::Put | HttpMethod::Delete => true,
99 HttpMethod::Post => false,
100 }
101 }
102}
103
104#[derive(Clone)]
106pub struct RetryConfig {
107 pub policy: RetryPolicy,
108 pub raw_config: LitStr,
109}
110
111impl std::fmt::Debug for RetryConfig {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("RetryConfig")
114 .field("policy", &self.policy)
115 .field("raw_config", &self.raw_config.value())
116 .finish()
117 }
118}
119
120impl RetryConfig {
121 pub fn parse(config_str: &LitStr) -> Result<Self, syn::Error> {
128 let config_value = config_str.value();
129 let policy = Self::parse_policy_string(&config_value)
130 .map_err(|msg| syn::Error::new(config_str.span(), msg))?;
131
132 Ok(RetryConfig {
133 policy,
134 raw_config: config_str.clone(),
135 })
136 }
137
138 fn parse_policy_string(config: &str) -> Result<RetryPolicy, String> {
139 let config = config.trim();
140
141 if config.starts_with("exponential(") && config.ends_with(')') {
142 Self::parse_exponential_config(&config[12..config.len()-1])
143 } else if config.starts_with("fixed(") && config.ends_with(')') {
144 Self::parse_fixed_config(&config[6..config.len()-1])
145 } else {
146 Err(format!("Unsupported retry config format: {}", config))
147 }
148 }
149
150 fn parse_exponential_config(params: &str) -> Result<RetryPolicy, String> {
151 let mut policy = RetryPolicy::default();
152
153 for param in params.split(',') {
155 let param = param.trim();
156 if param.is_empty() { continue; }
157
158 if let Some((key, value)) = param.split_once('=') {
159 let key = key.trim();
160 let value = value.trim();
161
162 match key {
163 "max_attempts" => {
164 policy.max_attempts = value.parse()
165 .map_err(|_| format!("Invalid max_attempts: {}", value))?;
166 }
167 "base_delay" => {
168 policy.base_delay_ms = Self::parse_duration(value)?;
169 }
170 "max_delay" => {
171 policy.max_delay_ms = Self::parse_duration(value)?;
172 }
173 "exponential_base" => {
174 policy.exponential_base = value.parse()
175 .map_err(|_| format!("Invalid exponential_base: {}", value))?;
176 }
177 "jitter_ratio" => {
178 policy.jitter_ratio = value.parse()
179 .map_err(|_| format!("Invalid jitter_ratio: {}", value))?;
180 }
181 "idempotent_only" => {
182 policy.idempotent_only = value.parse()
183 .map_err(|_| format!("Invalid idempotent_only: {}", value))?;
184 }
185 _ => return Err(format!("Unknown parameter: {}", key)),
186 }
187 } else {
188 let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
190 if parts.len() >= 1 {
191 policy.max_attempts = parts[0].parse()
192 .map_err(|_| format!("Invalid max_attempts: {}", parts[0]))?;
193 }
194 if parts.len() >= 2 {
195 policy.base_delay_ms = Self::parse_duration(parts[1])?;
196 }
197 break;
198 }
199 }
200
201 Ok(policy)
202 }
203
204 fn parse_fixed_config(params: &str) -> Result<RetryPolicy, String> {
205 let mut policy = RetryPolicy::default();
206 policy.exponential_base = 1.0; for param in params.split(',') {
209 let param = param.trim();
210 if param.is_empty() { continue; }
211
212 if let Some((key, value)) = param.split_once('=') {
213 let key = key.trim();
214 let value = value.trim();
215
216 match key {
217 "max_attempts" => {
218 policy.max_attempts = value.parse()
219 .map_err(|_| format!("Invalid max_attempts: {}", value))?;
220 }
221 "delay" => {
222 let delay = Self::parse_duration(value)?;
223 policy.base_delay_ms = delay;
224 policy.max_delay_ms = delay;
225 }
226 _ => return Err(format!("Unknown parameter: {}", key)),
227 }
228 }
229 }
230
231 Ok(policy)
232 }
233
234 fn parse_duration(duration_str: &str) -> Result<u64, String> {
235 let duration_str = duration_str.trim();
236
237 if duration_str.ends_with("ms") {
238 duration_str[..duration_str.len()-2].parse()
239 .map_err(|_| format!("Invalid milliseconds: {}", duration_str))
240 } else if duration_str.ends_with('s') {
241 let seconds: u64 = duration_str[..duration_str.len()-1].parse()
242 .map_err(|_| format!("Invalid seconds: {}", duration_str))?;
243 Ok(seconds * 1000)
244 } else {
245 duration_str.parse()
247 .map_err(|_| format!("Invalid duration (expected ms or s suffix): {}", duration_str))
248 }
249 }
250}
251
#[cfg(test)]
mod basic_tests {
    use super::*;
    use syn::{parse_quote, LitStr};

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.base_delay_ms, 100);
        assert_eq!(policy.max_delay_ms, 30000);
        assert_eq!(policy.exponential_base, 2.0);
        assert_eq!(policy.jitter_ratio, 0.1);
        assert!(policy.idempotent_only);
    }

    #[test]
    fn test_calculate_delay() {
        let policy = RetryPolicy::exponential(3, 100);

        assert_eq!(policy.calculate_delay(0), Duration::from_millis(0));
        // Default jitter is 10% of the capped delay, so attempt n lands in
        // [100 * 2^(n-1), 1.1 * 100 * 2^(n-1)]. The upper bound is inclusive
        // because float rounding can push the sum onto it.
        for &(attempt, floor) in &[(1u32, 100u128), (2, 200), (3, 400)] {
            let millis = policy.calculate_delay(attempt).as_millis();
            assert!(
                millis >= floor && millis <= floor + floor / 10,
                "attempt {} produced {} ms",
                attempt,
                millis
            );
        }
    }

    #[test]
    fn test_calculate_delay_is_capped_without_jitter() {
        let policy = RetryPolicy {
            max_delay_ms: 1000,
            jitter_ratio: 0.0,
            ..RetryPolicy::exponential(10, 100)
        };

        assert_eq!(policy.calculate_delay(4), Duration::from_millis(800));
        assert_eq!(policy.calculate_delay(10), Duration::from_millis(1000));
    }

    #[test]
    fn test_fixed_policy_has_constant_delay() {
        let policy = RetryPolicy::fixed(4, 250);
        for attempt in 1..=4 {
            assert_eq!(policy.calculate_delay(attempt), Duration::from_millis(250));
        }
    }

    #[test]
    fn test_should_retry_status() {
        let policy = RetryPolicy::default();

        assert!(policy.should_retry_status(500));
        assert!(policy.should_retry_status(502));
        assert!(policy.should_retry_status(599));
        assert!(policy.should_retry_status(429));
        assert!(policy.should_retry_status(408));

        assert!(!policy.should_retry_status(200));
        assert!(!policy.should_retry_status(400));
        assert!(!policy.should_retry_status(404));
        assert!(!policy.should_retry_status(499));
        assert!(!policy.should_retry_status(600));
    }

    #[test]
    fn test_parse_exponential_simple() {
        let config: LitStr = parse_quote! { "exponential(3, 100ms)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 3);
        assert_eq!(result.policy.base_delay_ms, 100);
        assert_eq!(result.policy.exponential_base, 2.0);
        assert_eq!(result.raw_config.value(), "exponential(3, 100ms)");
    }

    #[test]
    fn test_parse_exponential_detailed() {
        let config: LitStr =
            parse_quote! { "exponential(max_attempts=5, base_delay=200ms, max_delay=10s)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 5);
        assert_eq!(result.policy.base_delay_ms, 200);
        assert_eq!(result.policy.max_delay_ms, 10000);
    }

    #[test]
    fn test_parse_fixed() {
        let config: LitStr = parse_quote! { "fixed(max_attempts=3, delay=500ms)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 3);
        assert_eq!(result.policy.base_delay_ms, 500);
        assert_eq!(result.policy.max_delay_ms, 500);
        assert_eq!(result.policy.exponential_base, 1.0);
    }

    #[test]
    fn test_parse_rejects_unsupported_format() {
        let config: LitStr = parse_quote! { "linear(3, 100ms)" };
        assert!(RetryConfig::parse(&config).is_err());
    }

    #[test]
    fn test_parse_rejects_unknown_parameter() {
        let config: LitStr = parse_quote! { "exponential(max_attempts=3, backoff=2)" };
        assert!(RetryConfig::parse(&config).is_err());
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(RetryConfig::parse_duration("100ms").unwrap(), 100);
        assert_eq!(RetryConfig::parse_duration("2s").unwrap(), 2000);
        assert_eq!(RetryConfig::parse_duration("500").unwrap(), 500);

        assert!(RetryConfig::parse_duration("abc").is_err());
        assert!(RetryConfig::parse_duration("1.5s").is_err());
    }
}
326