Skip to main content

sh_layer2/checkpoint_system/
retry.rs

//! # Error recovery system
//!
//! Implements a three-layer retry mechanism: automatic → fallback → user intervention
4
5use crate::types::{Layer2Result, SessionId};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::sleep;
10
/// Classification of an error into a small set of categories.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Transient failure (network fluctuation, API rate limiting).
    Transient,
    /// Resource failure (out of memory, disk full).
    Resource,
    /// Configuration failure (invalid API key, missing configuration).
    Configuration,
    /// Logic failure (bad parameters, tool failure).
    Logic,
    /// System failure (unknown error).
    System,
    /// The operation was interrupted by the user.
    UserInterrupt,
}
27
28impl ErrorCategory {
29    /// 从错误消息分析类型
30    pub fn from_error_message(msg: &str) -> Self {
31        let msg_lower = msg.to_lowercase();
32
33        if msg_lower.contains("timeout")
34            || msg_lower.contains("network")
35            || msg_lower.contains("rate limit")
36        {
37            ErrorCategory::Transient
38        } else if msg_lower.contains("memory")
39            || msg_lower.contains("disk")
40            || msg_lower.contains("resource")
41        {
42            ErrorCategory::Resource
43        } else if msg_lower.contains("api key")
44            || msg_lower.contains("config")
45            || msg_lower.contains("auth")
46        {
47            ErrorCategory::Configuration
48        } else if msg_lower.contains("invalid")
49            || msg_lower.contains("parameter")
50            || msg_lower.contains("argument")
51        {
52            ErrorCategory::Logic
53        } else if msg_lower.contains("interrupt")
54            || msg_lower.contains("cancel")
55            || msg_lower.contains("abort")
56        {
57            ErrorCategory::UserInterrupt
58        } else {
59            ErrorCategory::System
60        }
61    }
62
63    /// 是否可重试
64    pub fn is_retryable(&self) -> bool {
65        matches!(self, ErrorCategory::Transient | ErrorCategory::Resource)
66    }
67}
68
/// Retry policy: how often to retry and how long to wait between attempts.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries.
    pub max_retries: usize,
    /// Initial delay, in milliseconds.
    pub initial_delay_ms: u64,
    /// Maximum delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Delay multiplier (exponential backoff).
    pub multiplier: f64,
    /// Jitter factor (0.0-1.0).
    pub jitter: f64,
}
83
84impl Default for RetryPolicy {
85    fn default() -> Self {
86        Self {
87            max_retries: 3,
88            initial_delay_ms: 1000,
89            max_delay_ms: 30000,
90            multiplier: 2.0,
91            jitter: 0.1,
92        }
93    }
94}
95
96impl RetryPolicy {
97    /// 计算第 n 次重试的延迟
98    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
99        let base_delay = self.initial_delay_ms as f64 * self.multiplier.powi(attempt as i32);
100        let capped_delay = base_delay.min(self.max_delay_ms as f64);
101
102        // 添加确定性抖动(基于 attempt 避免需要 rand 依赖)
103        let jitter_range = capped_delay * self.jitter;
104        let jitter_offset = ((attempt as f64 * 0.3).fract() - 0.5) * 2.0 * jitter_range;
105        let final_delay = (capped_delay + jitter_offset).max(0.0) as u64;
106
107        Duration::from_millis(final_delay)
108    }
109}
110
/// Fallback (degradation) strategy.
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    /// No fallback.
    None,
    /// Use a backup service.
    BackupService { endpoint: String },
    /// Use cached data.
    UseCache { max_age_seconds: u64 },
    /// Run with simplified functionality.
    Simplified { mode: String },
    /// Skip the operation.
    Skip,
}
125
126/// 恢复层执行结果
127#[derive(Debug, Clone)]
128pub struct RecoveryResult {
129    /// 是否成功
130    pub success: bool,
131    /// 使用的恢复层
132    pub layer_used: RecoveryLayer,
133    /// 重试次数
134    pub attempts: usize,
135    /// 最终错误消息
136    pub error_message: Option<String>,
137    /// 用户操作建议
138    pub user_action: Option<String>,
139}
140
/// Recovery layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryLayer {
    /// Layer 1: automatic retry.
    Automatic,
    /// Layer 2: degraded (fallback) execution.
    Fallback,
    /// Layer 3: user intervention.
    UserIntervention,
}
151
/// Recovery action chosen during user intervention.
#[derive(Debug, Clone)]
pub enum RecoveryAction {
    /// Retry the operation.
    Retry,
    /// Skip the operation.
    Skip,
    /// Abort the session.
    Abort,
    /// Modify a configuration entry, then retry.
    ModifyConfig { key: String, value: String },
    /// Switch to a backup service.
    SwitchBackup { service: String },
}
166
167/// 用户确认回调
168pub type UserConfirmationCallback =
169    Arc<dyn Fn(&str, Vec<RecoveryAction>) -> RecoveryAction + Send + Sync>;
170
171/// 错误恢复管理器
172pub struct ErrorRecovery {
173    /// 重试策略
174    retry_policy: RetryPolicy,
175    /// 降级策略
176    fallback_strategy: FallbackStrategy,
177    /// 用户确认回调
178    user_callback: RwLock<Option<UserConfirmationCallback>>,
179    /// 恢复统计
180    stats: RwLock<RecoveryStats>,
181}
182
/// Recovery statistics (all counters start at zero).
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
    /// Number of errors recorded.
    pub total_errors: usize,
    /// Number of recoveries credited to automatic retry.
    pub auto_recovered: usize,
    /// Number of recoveries credited to the fallback layer.
    pub fallback_recovered: usize,
    /// Number of recoveries credited to user intervention.
    pub user_interventions: usize,
    /// Number of operations that could not be recovered.
    pub unrecovered: usize,
}
192
193impl Default for ErrorRecovery {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199impl ErrorRecovery {
200    /// 创建新的恢复管理器
201    pub fn new() -> Self {
202        Self {
203            retry_policy: RetryPolicy::default(),
204            fallback_strategy: FallbackStrategy::None,
205            user_callback: RwLock::new(None),
206            stats: RwLock::new(RecoveryStats::default()),
207        }
208    }
209
210    /// 设置重试策略
211    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
212        self.retry_policy = policy;
213        self
214    }
215
216    /// 设置降级策略
217    pub fn with_fallback(mut self, strategy: FallbackStrategy) -> Self {
218        self.fallback_strategy = strategy;
219        self
220    }
221
222    /// 设置用户确认回调
223    pub async fn set_user_callback(&self, callback: UserConfirmationCallback) {
224        *self.user_callback.write().await = Some(callback);
225    }
226
227    /// 执行带恢复的操作
228    pub async fn execute_with_recovery<F, Fut, T>(&self, operation: F) -> RecoveryResult
229    where
230        F: Fn() -> Fut + Send + Sync,
231        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
232        T: Send,
233    {
234        let mut stats = self.stats.write().await;
235        stats.total_errors += 1;
236        drop(stats);
237
238        // 第一层:自动重试
239        let retry_result = self.try_with_retry(&operation).await;
240
241        if retry_result.success {
242            let mut stats = self.stats.write().await;
243            stats.auto_recovered += 1;
244            return retry_result;
245        }
246
247        // 第二层:降级执行
248        let fallback_result = self.try_with_fallback(&operation).await;
249
250        if fallback_result.success {
251            let mut stats = self.stats.write().await;
252            stats.fallback_recovered += 1;
253            return fallback_result;
254        }
255
256        // 第三层:用户介入
257        let user_result = self.try_with_user_intervention(&operation).await;
258
259        if user_result.success {
260            let mut stats = self.stats.write().await;
261            stats.user_interventions += 1;
262        } else {
263            let mut stats = self.stats.write().await;
264            stats.unrecovered += 1;
265        }
266
267        user_result
268    }
269
270    /// 第一层:自动重试
271    async fn try_with_retry<F, Fut, T>(&self, operation: &F) -> RecoveryResult
272    where
273        F: Fn() -> Fut + Send + Sync,
274        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
275        T: Send,
276    {
277        let mut last_error: Option<String> = None;
278
279        for attempt in 0..=self.retry_policy.max_retries {
280            match operation().await {
281                Ok(_) => {
282                    return RecoveryResult {
283                        success: true,
284                        layer_used: RecoveryLayer::Automatic,
285                        attempts: attempt,
286                        error_message: None,
287                        user_action: None,
288                    };
289                }
290                Err(e) => {
291                    let error_msg = e.to_string();
292                    let category = ErrorCategory::from_error_message(&error_msg);
293
294                    if !category.is_retryable() {
295                        return RecoveryResult {
296                            success: false,
297                            layer_used: RecoveryLayer::Automatic,
298                            attempts: attempt,
299                            error_message: Some(error_msg.clone()),
300                            user_action: Some(self.get_user_hint(&category)),
301                        };
302                    }
303
304                    last_error = Some(error_msg);
305
306                    if attempt < self.retry_policy.max_retries {
307                        let delay = self.retry_policy.delay_for_attempt(attempt);
308                        sleep(delay).await;
309                    }
310                }
311            }
312        }
313
314        RecoveryResult {
315            success: false,
316            layer_used: RecoveryLayer::Automatic,
317            attempts: self.retry_policy.max_retries + 1,
318            error_message: last_error,
319            user_action: None,
320        }
321    }
322
323    /// 第二层:降级执行
324    async fn try_with_fallback<F, Fut, T>(&self, _operation: &F) -> RecoveryResult
325    where
326        F: Fn() -> Fut + Send + Sync,
327        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
328        T: Send,
329    {
330        match &self.fallback_strategy {
331            FallbackStrategy::None => RecoveryResult {
332                success: false,
333                layer_used: RecoveryLayer::Fallback,
334                attempts: 0,
335                error_message: Some("No fallback strategy configured".to_string()),
336                user_action: None,
337            },
338            FallbackStrategy::Skip => RecoveryResult {
339                success: true,
340                layer_used: RecoveryLayer::Fallback,
341                attempts: 1,
342                error_message: None,
343                user_action: Some("Operation skipped due to fallback policy".to_string()),
344            },
345            FallbackStrategy::BackupService { endpoint } => {
346                // 简化实现:返回成功(实际实现需要切换服务端点)
347                RecoveryResult {
348                    success: true,
349                    layer_used: RecoveryLayer::Fallback,
350                    attempts: 1,
351                    error_message: None,
352                    user_action: Some(format!("Switched to backup: {}", endpoint)),
353                }
354            }
355            FallbackStrategy::UseCache { max_age_seconds } => RecoveryResult {
356                success: true,
357                layer_used: RecoveryLayer::Fallback,
358                attempts: 1,
359                error_message: None,
360                user_action: Some(format!("Using cached data (max {}s old)", max_age_seconds)),
361            },
362            FallbackStrategy::Simplified { mode } => RecoveryResult {
363                success: true,
364                layer_used: RecoveryLayer::Fallback,
365                attempts: 1,
366                error_message: None,
367                user_action: Some(format!("Using simplified mode: {}", mode)),
368            },
369        }
370    }
371
372    /// 第三层:用户介入
373    async fn try_with_user_intervention<F, Fut, T>(&self, _operation: &F) -> RecoveryResult
374    where
375        F: Fn() -> Fut + Send + Sync,
376        Fut: std::future::Future<Output = Layer2Result<T>> + Send,
377        T: Send,
378    {
379        let callback = self.user_callback.read().await;
380
381        if let Some(cb) = callback.as_ref() {
382            let actions = vec![
383                RecoveryAction::Retry,
384                RecoveryAction::Skip,
385                RecoveryAction::Abort,
386            ];
387
388            let action = cb("Operation failed. Choose action:", actions);
389
390            match action {
391                RecoveryAction::Retry => RecoveryResult {
392                    success: false, // 实际实现需要重新尝试
393                    layer_used: RecoveryLayer::UserIntervention,
394                    attempts: 1,
395                    error_message: None,
396                    user_action: Some("User requested retry".to_string()),
397                },
398                RecoveryAction::Skip => RecoveryResult {
399                    success: true,
400                    layer_used: RecoveryLayer::UserIntervention,
401                    attempts: 1,
402                    error_message: None,
403                    user_action: Some("User chose to skip".to_string()),
404                },
405                RecoveryAction::Abort => RecoveryResult {
406                    success: false,
407                    layer_used: RecoveryLayer::UserIntervention,
408                    attempts: 1,
409                    error_message: Some("User aborted operation".to_string()),
410                    user_action: Some("User aborted".to_string()),
411                },
412                _ => RecoveryResult {
413                    success: false,
414                    layer_used: RecoveryLayer::UserIntervention,
415                    attempts: 1,
416                    error_message: Some("Unknown action".to_string()),
417                    user_action: None,
418                },
419            }
420        } else {
421            RecoveryResult {
422                success: false,
423                layer_used: RecoveryLayer::UserIntervention,
424                attempts: 0,
425                error_message: Some("No user callback set".to_string()),
426                user_action: Some("Please configure user callback for intervention".to_string()),
427            }
428        }
429    }
430
431    /// 获取用户提示
432    fn get_user_hint(&self, category: &ErrorCategory) -> String {
433        match category {
434            ErrorCategory::Configuration => "Check your API key and configuration".to_string(),
435            ErrorCategory::Logic => "Verify your input parameters".to_string(),
436            ErrorCategory::UserInterrupt => "Operation was cancelled".to_string(),
437            ErrorCategory::Transient => "Temporary issue, will retry automatically".to_string(),
438            ErrorCategory::Resource => {
439                "System resource issue, consider freeing up memory/disk".to_string()
440            }
441            ErrorCategory::System => "Unknown error occurred".to_string(),
442        }
443    }
444
445    /// 获取统计信息
446    pub async fn get_stats(&self) -> RecoveryStats {
447        self.stats.read().await.clone()
448    }
449}
450
/// Detector for sessions that can be recovered.
pub struct SessionRecovery {
    /// Storage path that is scanned for session directories.
    storage_path: std::path::PathBuf,
}
456
457impl SessionRecovery {
458    /// 创建新的会话恢复器
459    pub fn new(storage_path: impl AsRef<std::path::Path>) -> Self {
460        Self {
461            storage_path: storage_path.as_ref().to_path_buf(),
462        }
463    }
464
465    /// 检测是否有中断的会话
466    pub fn detect_interrupted_sessions(&self) -> Layer2Result<Vec<InterruptedSession>> {
467        let mut interrupted = Vec::new();
468
469        if !self.storage_path.exists() {
470            return Ok(interrupted);
471        }
472
473        for entry in std::fs::read_dir(&self.storage_path)? {
474            let entry = entry?;
475            let session_dir = entry.path();
476
477            if !session_dir.is_dir() {
478                continue;
479            }
480
481            let state_file = session_dir.join("state.json");
482            if state_file.exists() {
483                if let Ok(content) = std::fs::read_to_string(&state_file) {
484                    if let Ok(state) = serde_json::from_str::<SessionState>(&content) {
485                        if state.status == SessionStatus::Running && !state.completed {
486                            interrupted.push(InterruptedSession {
487                                session_id: state.session_id,
488                                last_iteration: state.iteration,
489                                last_activity: state.last_updated,
490                                task_description: state.task_description,
491                            });
492                        }
493                    }
494                }
495            }
496        }
497
498        // 按时间排序(最近的中断在前)
499        interrupted.sort_by_key(|b| std::cmp::Reverse(b.last_activity));
500
501        Ok(interrupted)
502    }
503
504    /// 渲染中断会话列表
505    pub fn render_interrupted(&self) -> String {
506        match self.detect_interrupted_sessions() {
507            Ok(sessions) => {
508                if sessions.is_empty() {
509                    "No interrupted sessions found.".to_string()
510                } else {
511                    let mut output =
512                        format!("Found {} interrupted session(s):\n\n", sessions.len());
513                    for (i, session) in sessions.iter().enumerate() {
514                        output.push_str(&format!(
515                            "{}. Session: {}\n   Task: {}\n   Iteration: {}\n   Last activity: {}\n\n",
516                            i + 1,
517                            session.session_id,
518                            session.task_description.as_deref().unwrap_or("Unknown"),
519                            session.last_iteration,
520                            session.last_activity.format("%Y-%m-%d %H:%M:%S")
521                        ));
522                    }
523                    output.push_str("Use 'continuum session resume <id>' to continue.");
524                    output
525                }
526            }
527            Err(e) => format!("Error detecting sessions: {}", e),
528        }
529    }
530}
531
532/// 中断的会话信息
533#[derive(Debug, Clone)]
534pub struct InterruptedSession {
535    pub session_id: SessionId,
536    pub last_iteration: i32,
537    pub last_activity: chrono::DateTime<chrono::Utc>,
538    pub task_description: Option<String>,
539}
540
541/// 会话状态
542#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
543struct SessionState {
544    session_id: SessionId,
545    status: SessionStatus,
546    completed: bool,
547    iteration: i32,
548    last_updated: chrono::DateTime<chrono::Utc>,
549    task_description: Option<String>,
550}
551
552/// 会话状态枚举
553#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
554enum SessionStatus {
555    Running,
556    Paused,
557    Completed,
558    Error,
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Keyword-based classification picks the expected category.
    #[test]
    fn test_error_category_analysis() {
        let cat = ErrorCategory::from_error_message("network timeout");
        assert_eq!(cat, ErrorCategory::Transient);

        let cat = ErrorCategory::from_error_message("invalid parameter");
        assert_eq!(cat, ErrorCategory::Logic);
    }

    /// The first delay stays within the jitter band around the initial delay.
    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy::default();
        let delay = policy.delay_for_attempt(0);
        assert!(delay.as_millis() >= 900); // allow for jitter
        assert!(delay.as_millis() <= 1100);
    }

    /// Late attempts are capped at the maximum delay (plus jitter).
    #[test]
    fn test_retry_policy_max_delay() {
        let policy = RetryPolicy {
            max_delay_ms: 5000,
            ..Default::default()
        };
        let delay = policy.delay_for_attempt(10);
        assert!(delay.as_millis() <= 5500); // allow for jitter
    }

    /// A freshly created manager starts with zeroed statistics.
    #[tokio::test]
    async fn test_error_recovery_creation() {
        let recovery = ErrorRecovery::new();
        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_errors, 0);
    }

    /// The strategy keeps the variant it was constructed with.
    #[test]
    fn test_fallback_strategy() {
        let strategy = FallbackStrategy::Skip;
        // `matches!` only evaluates to a bool; without `assert!` the test
        // could never fail.
        assert!(matches!(strategy, FallbackStrategy::Skip));
    }
}
604}