1use rand::Rng;
4use std::time::Duration;
5
/// Configuration for retrying a failed operation with exponential backoff
/// and optional jitter.
///
/// Attempt `0` is the initial try; attempts `1..=max_retries` are retries.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retries allowed after the initial attempt.
    pub max_retries: u32,
    /// Delay before the first retry; later delays grow from this value.
    pub base_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each successive retry.
    pub multiplier: f64,
    /// When `true`, each delay is scaled by a random factor in `[0.5, 1.0)`.
    pub jitter: bool,
}
20
21impl RetryPolicy {
22 pub fn none() -> Self {
24 Self {
25 max_retries: 0,
26 base_delay: Duration::from_secs(1),
27 max_delay: Duration::from_secs(1),
28 multiplier: 2.0,
29 jitter: false,
30 }
31 }
32
33 pub fn standard() -> Self {
35 Self {
36 max_retries: 3,
37 base_delay: Duration::from_secs(2),
38 max_delay: Duration::from_secs(30),
39 multiplier: 2.0,
40 jitter: true,
41 }
42 }
43
44 pub fn aggressive() -> Self {
46 Self {
47 max_retries: 5,
48 base_delay: Duration::from_secs(1),
49 max_delay: Duration::from_secs(15),
50 multiplier: 1.5,
51 jitter: true,
52 }
53 }
54
55 pub fn linear(max_retries: u32, delay: Duration) -> Self {
57 Self {
58 max_retries,
59 base_delay: delay,
60 max_delay: delay,
61 multiplier: 1.0,
62 jitter: false,
63 }
64 }
65
66 pub fn patient() -> Self {
68 Self {
69 max_retries: 10,
70 base_delay: Duration::from_secs(5),
71 max_delay: Duration::from_secs(120),
72 multiplier: 2.0,
73 jitter: true,
74 }
75 }
76
77 pub fn from_max_retries(max_retries: u32) -> Self {
79 if max_retries == 0 {
80 Self::none()
81 } else {
82 Self {
83 max_retries,
84 ..Self::standard()
85 }
86 }
87 }
88
89 pub fn from_preset(name: &str) -> Option<Self> {
91 match name.to_lowercase().as_str() {
92 "none" => Some(Self::none()),
93 "standard" => Some(Self::standard()),
94 "aggressive" => Some(Self::aggressive()),
95 "patient" => Some(Self::patient()),
96 _ => None,
97 }
98 }
99
100 pub fn should_retry(&self, attempt: u32) -> bool {
102 attempt <= self.max_retries
103 }
104
105 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
107 if attempt == 0 {
108 return Duration::ZERO;
109 }
110
111 let base = self.base_delay.as_secs_f64();
112 let delay = base * self.multiplier.powi(attempt.saturating_sub(1) as i32);
113 let capped = delay.min(self.max_delay.as_secs_f64());
114
115 if self.jitter {
116 let mut rng = rand::thread_rng();
117 let jittered = capped * rng.gen_range(0.5..1.0);
118 Duration::from_secs_f64(jittered)
119 } else {
120 Duration::from_secs_f64(capped)
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard policy with jitter disabled, so delays are deterministic.
    fn deterministic_standard() -> RetryPolicy {
        RetryPolicy {
            jitter: false,
            ..RetryPolicy::standard()
        }
    }

    #[test]
    fn test_none_policy() {
        let policy = RetryPolicy::none();
        assert!(policy.should_retry(0));
        assert!(!policy.should_retry(1));
    }

    #[test]
    fn test_standard_policy() {
        let policy = RetryPolicy::standard();
        assert!(policy.should_retry(1));
        assert!(policy.should_retry(3));
        assert!(!policy.should_retry(4));
    }

    #[test]
    fn test_initial_attempt_has_no_delay() {
        assert_eq!(RetryPolicy::standard().delay_for_attempt(0), Duration::ZERO);
        assert_eq!(deterministic_standard().delay_for_attempt(0), Duration::ZERO);
    }

    #[test]
    fn test_delay_exponential() {
        // standard: base 2s, multiplier 2.0 -> 2s, 4s, 8s, ...
        let policy = deterministic_standard();
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(8));
    }

    #[test]
    fn test_delay_capped() {
        let policy = RetryPolicy {
            jitter: false,
            max_delay: Duration::from_secs(10),
            ..RetryPolicy::standard()
        };
        assert_eq!(policy.delay_for_attempt(4), Duration::from_secs(10));
        assert_eq!(policy.delay_for_attempt(100), Duration::from_secs(10));
    }

    #[test]
    fn test_jitter_within_bounds() {
        // Un-jittered delay for attempt 1 is 2s; jitter scales it into [1s, 2s].
        let policy = RetryPolicy::standard();
        for _ in 0..100 {
            let d = policy.delay_for_attempt(1);
            assert!(
                d >= Duration::from_secs(1) && d <= Duration::from_secs(2),
                "jittered delay out of range: {:?}",
                d
            );
        }
    }

    #[test]
    fn test_linear_policy() {
        let policy = RetryPolicy::linear(5, Duration::from_secs(3));
        let d1 = policy.delay_for_attempt(1);
        let d2 = policy.delay_for_attempt(2);
        let d5 = policy.delay_for_attempt(5);
        assert_eq!(d1, Duration::from_secs(3));
        assert_eq!(d2, d1);
        assert_eq!(d5, d1);
    }

    #[test]
    fn test_from_preset() {
        assert!(RetryPolicy::from_preset("standard").is_some());
        assert!(RetryPolicy::from_preset("none").is_some());
        assert!(RetryPolicy::from_preset("aggressive").is_some());
        assert!(RetryPolicy::from_preset("patient").is_some());
        assert!(RetryPolicy::from_preset("unknown").is_none());
    }

    #[test]
    fn test_from_preset_is_case_insensitive() {
        let p = RetryPolicy::from_preset("Standard").expect("mixed-case preset");
        assert_eq!(p.max_retries, RetryPolicy::standard().max_retries);
        let p = RetryPolicy::from_preset("PATIENT").expect("upper-case preset");
        assert_eq!(p.max_retries, RetryPolicy::patient().max_retries);
    }

    #[test]
    fn test_from_max_retries() {
        let p = RetryPolicy::from_max_retries(0);
        assert!(!p.should_retry(1));

        let p = RetryPolicy::from_max_retries(5);
        assert!(p.should_retry(5));
        assert!(!p.should_retry(6));
        // Everything except the retry count comes from the standard preset.
        assert_eq!(p.base_delay, RetryPolicy::standard().base_delay);
        assert_eq!(p.max_delay, RetryPolicy::standard().max_delay);
    }
}