// swan_common/types/retry.rs

1use std::time::Duration;
2use syn::LitStr;
3
/// Retry policy configuration.
///
/// Describes how many times a request may be retried and how the delay
/// between attempts grows.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts.
    pub max_attempts: u32,
    /// Base delay in milliseconds.
    pub base_delay_ms: u64,
    /// Cap applied to the exponential delay, in milliseconds (jitter is added
    /// after the cap, so the final delay can exceed it by the jitter amount).
    pub max_delay_ms: u64,
    /// Base of the exponential growth.
    pub exponential_base: f64,
    /// Fraction of the capped delay that may be added as random jitter
    /// (documented range 0.0-1.0).
    pub jitter_ratio: f64,
    /// Retry only idempotent methods.
    pub idempotent_only: bool,
}

21impl Default for RetryPolicy {
22    fn default() -> Self {
23        Self {
24            max_attempts: 3,
25            base_delay_ms: 100,
26            max_delay_ms: 30000, // 30秒
27            exponential_base: 2.0,
28            jitter_ratio: 0.1,
29            idempotent_only: true,
30        }
31    }
32}
33
34impl RetryPolicy {
35    /// 创建指数重试策略
36    pub fn exponential(max_attempts: u32, base_delay_ms: u64) -> Self {
37        Self {
38            max_attempts,
39            base_delay_ms,
40            ..Default::default()
41        }
42    }
43
44    /// 创建固定延迟重试策略
45    pub fn fixed(max_attempts: u32, delay_ms: u64) -> Self {
46        Self {
47            max_attempts,
48            base_delay_ms: delay_ms,
49            max_delay_ms: delay_ms,
50            exponential_base: 1.0,
51            jitter_ratio: 0.0,
52            idempotent_only: true,
53        }
54    }
55
56    /// 计算重试延迟时间
57    pub fn calculate_delay(&self, attempt: u32) -> Duration {
58        if attempt == 0 {
59            return Duration::from_millis(0);
60        }
61
62        // 指数退避: base_delay * exponential_base^(attempt-1)
63        let exponential_delay = self.base_delay_ms as f64 
64            * self.exponential_base.powi((attempt - 1) as i32);
65        
66        // 应用最大延迟限制
67        let capped_delay = exponential_delay.min(self.max_delay_ms as f64);
68        
69        // 添加随机抖动
70        let jitter = if self.jitter_ratio > 0.0 {
71            let max_jitter = capped_delay * self.jitter_ratio;
72            fastrand::f64() * max_jitter
73        } else {
74            0.0
75        };
76
77        Duration::from_millis((capped_delay + jitter) as u64)
78    }
79
80    /// 判断HTTP状态码是否应该重试
81    pub fn should_retry_status(&self, status: u16) -> bool {
82        match status {
83            // 5xx 服务器错误 - 应该重试
84            500..=599 => true,
85            // 429 限流 - 应该重试
86            429 => true,
87            // 408 请求超时 - 应该重试
88            408 => true,
89            // 其他状态码不重试
90            _ => false,
91        }
92    }
93
94    /// 判断HTTP方法是否幂等
95    pub fn is_idempotent_method(method: &crate::types::http::HttpMethod) -> bool {
96        use crate::types::http::HttpMethod;
97        match method {
98            HttpMethod::Get | HttpMethod::Put | HttpMethod::Delete => true,
99            HttpMethod::Post => false,
100        }
101    }
102}
103
104/// 重试配置解析结果
105#[derive(Clone)]
106pub struct RetryConfig {
107    pub policy: RetryPolicy,
108    pub raw_config: LitStr,
109}
110
111impl std::fmt::Debug for RetryConfig {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("RetryConfig")
114            .field("policy", &self.policy)
115            .field("raw_config", &self.raw_config.value())
116            .finish()
117    }
118}
119
120impl RetryConfig {
121    /// 从字符串解析重试配置
122    /// 
123    /// 支持格式:
124    /// - "exponential(max_attempts=3, base_delay=100ms)"
125    /// - "fixed(max_attempts=5, delay=200ms)"
126    /// - "exponential(3, 100ms)" // 简化格式
127    pub fn parse(config_str: &LitStr) -> Result<Self, syn::Error> {
128        let config_value = config_str.value();
129        let policy = Self::parse_policy_string(&config_value)
130            .map_err(|msg| syn::Error::new(config_str.span(), msg))?;
131
132        Ok(RetryConfig {
133            policy,
134            raw_config: config_str.clone(),
135        })
136    }
137
138    fn parse_policy_string(config: &str) -> Result<RetryPolicy, String> {
139        let config = config.trim();
140        
141        if config.starts_with("exponential(") && config.ends_with(')') {
142            Self::parse_exponential_config(&config[12..config.len()-1])
143        } else if config.starts_with("fixed(") && config.ends_with(')') {
144            Self::parse_fixed_config(&config[6..config.len()-1])
145        } else {
146            Err(format!("Unsupported retry config format: {}", config))
147        }
148    }
149
150    fn parse_exponential_config(params: &str) -> Result<RetryPolicy, String> {
151        let mut policy = RetryPolicy::default();
152        
153        // 解析参数
154        for param in params.split(',') {
155            let param = param.trim();
156            if param.is_empty() { continue; }
157            
158            if let Some((key, value)) = param.split_once('=') {
159                let key = key.trim();
160                let value = value.trim();
161                
162                match key {
163                    "max_attempts" => {
164                        policy.max_attempts = value.parse()
165                            .map_err(|_| format!("Invalid max_attempts: {}", value))?;
166                    }
167                    "base_delay" => {
168                        policy.base_delay_ms = Self::parse_duration(value)?;
169                    }
170                    "max_delay" => {
171                        policy.max_delay_ms = Self::parse_duration(value)?;
172                    }
173                    "exponential_base" => {
174                        policy.exponential_base = value.parse()
175                            .map_err(|_| format!("Invalid exponential_base: {}", value))?;
176                    }
177                    "jitter_ratio" => {
178                        policy.jitter_ratio = value.parse()
179                            .map_err(|_| format!("Invalid jitter_ratio: {}", value))?;
180                    }
181                    "idempotent_only" => {
182                        policy.idempotent_only = value.parse()
183                            .map_err(|_| format!("Invalid idempotent_only: {}", value))?;
184                    }
185                    _ => return Err(format!("Unknown parameter: {}", key)),
186                }
187            } else {
188                // 简化格式:exponential(3, 100ms)
189                let parts: Vec<&str> = params.split(',').map(|s| s.trim()).collect();
190                if parts.len() >= 1 {
191                    policy.max_attempts = parts[0].parse()
192                        .map_err(|_| format!("Invalid max_attempts: {}", parts[0]))?;
193                }
194                if parts.len() >= 2 {
195                    policy.base_delay_ms = Self::parse_duration(parts[1])?;
196                }
197                break;
198            }
199        }
200        
201        Ok(policy)
202    }
203
204    fn parse_fixed_config(params: &str) -> Result<RetryPolicy, String> {
205        let mut policy = RetryPolicy::default();
206        policy.exponential_base = 1.0; // 固定延迟
207
208        for param in params.split(',') {
209            let param = param.trim();
210            if param.is_empty() { continue; }
211            
212            if let Some((key, value)) = param.split_once('=') {
213                let key = key.trim();
214                let value = value.trim();
215                
216                match key {
217                    "max_attempts" => {
218                        policy.max_attempts = value.parse()
219                            .map_err(|_| format!("Invalid max_attempts: {}", value))?;
220                    }
221                    "delay" => {
222                        let delay = Self::parse_duration(value)?;
223                        policy.base_delay_ms = delay;
224                        policy.max_delay_ms = delay;
225                    }
226                    _ => return Err(format!("Unknown parameter: {}", key)),
227                }
228            }
229        }
230        
231        Ok(policy)
232    }
233
234    fn parse_duration(duration_str: &str) -> Result<u64, String> {
235        let duration_str = duration_str.trim();
236        
237        if duration_str.ends_with("ms") {
238            duration_str[..duration_str.len()-2].parse()
239                .map_err(|_| format!("Invalid milliseconds: {}", duration_str))
240        } else if duration_str.ends_with('s') {
241            let seconds: u64 = duration_str[..duration_str.len()-1].parse()
242                .map_err(|_| format!("Invalid seconds: {}", duration_str))?;
243            Ok(seconds * 1000)
244        } else {
245            // 默认按毫秒处理
246            duration_str.parse()
247                .map_err(|_| format!("Invalid duration (expected ms or s suffix): {}", duration_str))
248        }
249    }
250}
251
#[cfg(test)]
mod basic_tests {
    use super::*;
    use syn::{parse_quote, LitStr};

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.base_delay_ms, 100);
        assert_eq!(policy.max_delay_ms, 30000);
        assert_eq!(policy.exponential_base, 2.0);
        assert_eq!(policy.jitter_ratio, 0.1);
        assert!(policy.idempotent_only);
    }

    #[test]
    fn test_calculate_delay() {
        let policy = RetryPolicy::exponential(3, 100);

        assert_eq!(policy.calculate_delay(0), Duration::from_millis(0));
        // The default 10% jitter is added on top of the exponential delay, so
        // every attempt must land in [delay, delay + delay / 10].
        for (attempt, expected) in [(1u32, 100u128), (2, 200), (3, 400)] {
            let actual = policy.calculate_delay(attempt).as_millis();
            assert!(
                actual >= expected && actual <= expected + expected / 10,
                "attempt {}: {} ms not in [{}, {}]",
                attempt,
                actual,
                expected,
                expected + expected / 10
            );
        }
    }

    #[test]
    fn test_calculate_delay_respects_cap() {
        let policy = RetryPolicy::exponential(20, 100);
        // 100 ms * 2^19 is far above the 30 s cap; only jitter may exceed it.
        let actual = policy.calculate_delay(20).as_millis();
        assert!((30_000..=33_000).contains(&actual), "{} ms", actual);
    }

    #[test]
    fn test_fixed_policy_is_constant() {
        let policy = RetryPolicy::fixed(3, 250);
        for attempt in 1..=3 {
            assert_eq!(policy.calculate_delay(attempt), Duration::from_millis(250));
        }
    }

    #[test]
    fn test_should_retry_status() {
        let policy = RetryPolicy::default();

        for status in [500, 502, 503, 599, 429, 408] {
            assert!(policy.should_retry_status(status), "{}", status);
        }
        for status in [200, 400, 404, 499, 600] {
            assert!(!policy.should_retry_status(status), "{}", status);
        }
    }

    #[test]
    fn test_is_idempotent_method() {
        use crate::types::http::HttpMethod;

        assert!(RetryPolicy::is_idempotent_method(&HttpMethod::Get));
        assert!(RetryPolicy::is_idempotent_method(&HttpMethod::Put));
        assert!(RetryPolicy::is_idempotent_method(&HttpMethod::Delete));
        assert!(!RetryPolicy::is_idempotent_method(&HttpMethod::Post));
    }

    #[test]
    fn test_parse_exponential_simple() {
        let config: LitStr = parse_quote! { "exponential(3, 100ms)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 3);
        assert_eq!(result.policy.base_delay_ms, 100);
        assert_eq!(result.policy.exponential_base, 2.0);
    }

    #[test]
    fn test_parse_exponential_detailed() {
        let config: LitStr =
            parse_quote! { "exponential(max_attempts=5, base_delay=200ms, max_delay=10s)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 5);
        assert_eq!(result.policy.base_delay_ms, 200);
        assert_eq!(result.policy.max_delay_ms, 10000);
    }

    #[test]
    fn test_parse_fixed() {
        let config: LitStr = parse_quote! { "fixed(max_attempts=3, delay=500ms)" };
        let result = RetryConfig::parse(&config).unwrap();

        assert_eq!(result.policy.max_attempts, 3);
        assert_eq!(result.policy.base_delay_ms, 500);
        assert_eq!(result.policy.max_delay_ms, 500);
        assert_eq!(result.policy.exponential_base, 1.0);
    }

    #[test]
    fn test_parse_rejects_invalid_configs() {
        let configs: [LitStr; 5] = [
            parse_quote! { "linear(3)" },
            parse_quote! { "exponential(foo=1)" },
            parse_quote! { "exponential(max_attempts=abc)" },
            parse_quote! { "exponential(max_attempts=2, 10ms)" },
            parse_quote! { "fixed(delay=1.5s)" },
        ];
        for config in &configs {
            assert!(RetryConfig::parse(config).is_err(), "{}", config.value());
        }

        let err = RetryConfig::parse(&parse_quote! { "exponential(foo=1)" }).unwrap_err();
        assert_eq!(err.to_string(), "Unknown parameter: foo");
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(RetryConfig::parse_duration("100ms").unwrap(), 100);
        assert_eq!(RetryConfig::parse_duration("2s").unwrap(), 2000);
        assert_eq!(RetryConfig::parse_duration("500").unwrap(), 500);
        assert!(RetryConfig::parse_duration("abc").is_err());
        assert!(RetryConfig::parse_duration("1.5s").is_err());
        assert!(RetryConfig::parse_duration("ms").is_err());
    }
}
