1use crate::types::{Layer2Result, SessionId};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::sleep;
10
/// Coarse classification of an error, derived from its message text by
/// [`ErrorCategory::from_error_message`].
///
/// Only `Transient` and `Resource` are considered retryable
/// (see [`ErrorCategory::is_retryable`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Message mentions "timeout", "network" or "rate limit".
    Transient,
    /// Message mentions "memory", "disk" or "resource".
    Resource,
    /// Message mentions "api key", "config" or "auth".
    Configuration,
    /// Message mentions "invalid", "parameter" or "argument".
    Logic,
    /// Fallback category for messages matching none of the other keyword groups.
    System,
    /// Message mentions "interrupt", "cancel" or "abort".
    UserInterrupt,
}
27
28impl ErrorCategory {
29 pub fn from_error_message(msg: &str) -> Self {
31 let msg_lower = msg.to_lowercase();
32
33 if msg_lower.contains("timeout")
34 || msg_lower.contains("network")
35 || msg_lower.contains("rate limit")
36 {
37 ErrorCategory::Transient
38 } else if msg_lower.contains("memory")
39 || msg_lower.contains("disk")
40 || msg_lower.contains("resource")
41 {
42 ErrorCategory::Resource
43 } else if msg_lower.contains("api key")
44 || msg_lower.contains("config")
45 || msg_lower.contains("auth")
46 {
47 ErrorCategory::Configuration
48 } else if msg_lower.contains("invalid")
49 || msg_lower.contains("parameter")
50 || msg_lower.contains("argument")
51 {
52 ErrorCategory::Logic
53 } else if msg_lower.contains("interrupt")
54 || msg_lower.contains("cancel")
55 || msg_lower.contains("abort")
56 {
57 ErrorCategory::UserInterrupt
58 } else {
59 ErrorCategory::System
60 }
61 }
62
63 pub fn is_retryable(&self) -> bool {
65 matches!(self, ErrorCategory::Transient | ErrorCategory::Resource)
66 }
67}
68
/// Exponential-backoff settings for the automatic retry layer of `ErrorRecovery`.
///
/// The wait before the next try is computed by [`RetryPolicy::delay_for_attempt`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries after the initial attempt (the operation runs at most
    /// `max_retries + 1` times in the retry layer).
    pub max_retries: usize,
    /// Base delay in milliseconds used for the first retry.
    pub initial_delay_ms: u64,
    /// Cap in milliseconds applied to the exponential delay *before* jitter, so the
    /// final delay may exceed it by up to `jitter * max_delay_ms`.
    pub max_delay_ms: u64,
    /// Growth factor applied per attempt (`initial_delay_ms * multiplier^attempt`).
    pub multiplier: f64,
    /// Jitter amplitude as a fraction of the capped delay (e.g. `0.1` = ±10 %).
    pub jitter: f64,
}
83
84impl Default for RetryPolicy {
85 fn default() -> Self {
86 Self {
87 max_retries: 3,
88 initial_delay_ms: 1000,
89 max_delay_ms: 30000,
90 multiplier: 2.0,
91 jitter: 0.1,
92 }
93 }
94}
95
96impl RetryPolicy {
97 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
99 let base_delay = self.initial_delay_ms as f64 * self.multiplier.powi(attempt as i32);
100 let capped_delay = base_delay.min(self.max_delay_ms as f64);
101
102 let jitter_range = capped_delay * self.jitter;
104 let jitter_offset = ((attempt as f64 * 0.3).fract() - 0.5) * 2.0 * jitter_range;
105 let final_delay = (capped_delay + jitter_offset).max(0.0) as u64;
106
107 Duration::from_millis(final_delay)
108 }
109}
110
/// Second-layer strategy consulted by `ErrorRecovery` after automatic retries fail.
///
/// NOTE(review): in this file `try_with_fallback` only reports the outcome implied by
/// the variant; it does not itself call a backup endpoint, read a cache, or re-run
/// the operation.
#[derive(Debug, Clone)]
pub enum FallbackStrategy {
    /// No fallback; the fallback layer always reports failure.
    None,
    /// Report a switch to the given backup endpoint.
    BackupService { endpoint: String },
    /// Report use of cached data no older than `max_age_seconds`.
    UseCache { max_age_seconds: u64 },
    /// Report a switch to the named simplified mode.
    Simplified { mode: String },
    /// Report the operation as skipped (counted as a successful recovery).
    Skip,
}
125
/// Outcome of a recovery attempt, as produced by one of the `ErrorRecovery` layers.
#[derive(Debug, Clone)]
pub struct RecoveryResult {
    /// Whether the layer considers the operation recovered (or skipped).
    pub success: bool,
    /// Which layer produced this result.
    pub layer_used: RecoveryLayer,
    /// Number of attempts reported by that layer; 0 when the layer could not run at
    /// all (e.g. no fallback strategy or no user callback configured).
    pub attempts: usize,
    /// Error text when the layer failed, if any.
    pub error_message: Option<String>,
    /// Human-readable note about what was done, or a hint for the user.
    pub user_action: Option<String>,
}
140
/// The three escalation layers of `ErrorRecovery`, in the order they are tried.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryLayer {
    /// Automatic retries with backoff (see [`RetryPolicy`]).
    Automatic,
    /// The configured [`FallbackStrategy`].
    Fallback,
    /// A decision returned by the user confirmation callback.
    UserIntervention,
}
151
/// Action a user confirmation callback can choose.
///
/// `ErrorRecovery` offers `Retry`, `Skip` and `Abort`; the remaining variants are not
/// offered there and are reported as an unknown action if returned.
#[derive(Debug, Clone)]
pub enum RecoveryAction {
    /// Try the operation again.
    Retry,
    /// Give up on the operation but treat it as handled.
    Skip,
    /// Give up on the operation and report failure.
    Abort,
    /// Change a configuration entry.
    ModifyConfig { key: String, value: String },
    /// Switch to a named backup service.
    SwitchBackup { service: String },
}
166
/// Callback used by the user-intervention layer of `ErrorRecovery`.
///
/// It receives a prompt message and the list of offered actions, and returns the
/// chosen action. It is a plain synchronous `Fn` called from async code, so a callback
/// that blocks (e.g. waiting on stdin) blocks the task running the recovery.
pub type UserConfirmationCallback =
    Arc<dyn Fn(&str, Vec<RecoveryAction>) -> RecoveryAction + Send + Sync>;
170
/// Layered error-recovery driver: automatic retries, then a fallback strategy, then
/// user intervention (see `execute_with_recovery`).
pub struct ErrorRecovery {
    /// Backoff settings for the automatic retry layer.
    retry_policy: RetryPolicy,
    /// Strategy consulted once automatic retries have failed.
    fallback_strategy: FallbackStrategy,
    /// Optional callback for the user-intervention layer; behind a lock so it can be
    /// installed through `&self` via `set_user_callback`.
    user_callback: RwLock<Option<UserConfirmationCallback>>,
    /// Running recovery counters, readable via `get_stats`.
    stats: RwLock<RecoveryStats>,
}
182
/// Counters maintained by `ErrorRecovery::execute_with_recovery`.
///
/// Every operation counted in `total_errors` is also counted in exactly one of the
/// four outcome fields below.
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
    /// Operations counted as errors.
    pub total_errors: usize,
    /// Recovered by the automatic retry layer.
    pub auto_recovered: usize,
    /// Recovered by the fallback layer.
    pub fallback_recovered: usize,
    /// Recovered through a user-callback decision.
    pub user_interventions: usize,
    /// Not recovered by any layer.
    pub unrecovered: usize,
}
192
193impl Default for ErrorRecovery {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199impl ErrorRecovery {
200 pub fn new() -> Self {
202 Self {
203 retry_policy: RetryPolicy::default(),
204 fallback_strategy: FallbackStrategy::None,
205 user_callback: RwLock::new(None),
206 stats: RwLock::new(RecoveryStats::default()),
207 }
208 }
209
210 pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
212 self.retry_policy = policy;
213 self
214 }
215
216 pub fn with_fallback(mut self, strategy: FallbackStrategy) -> Self {
218 self.fallback_strategy = strategy;
219 self
220 }
221
222 pub async fn set_user_callback(&self, callback: UserConfirmationCallback) {
224 *self.user_callback.write().await = Some(callback);
225 }
226
227 pub async fn execute_with_recovery<F, Fut, T>(&self, operation: F) -> RecoveryResult
229 where
230 F: Fn() -> Fut + Send + Sync,
231 Fut: std::future::Future<Output = Layer2Result<T>> + Send,
232 T: Send,
233 {
234 let mut stats = self.stats.write().await;
235 stats.total_errors += 1;
236 drop(stats);
237
238 let retry_result = self.try_with_retry(&operation).await;
240
241 if retry_result.success {
242 let mut stats = self.stats.write().await;
243 stats.auto_recovered += 1;
244 return retry_result;
245 }
246
247 let fallback_result = self.try_with_fallback(&operation).await;
249
250 if fallback_result.success {
251 let mut stats = self.stats.write().await;
252 stats.fallback_recovered += 1;
253 return fallback_result;
254 }
255
256 let user_result = self.try_with_user_intervention(&operation).await;
258
259 if user_result.success {
260 let mut stats = self.stats.write().await;
261 stats.user_interventions += 1;
262 } else {
263 let mut stats = self.stats.write().await;
264 stats.unrecovered += 1;
265 }
266
267 user_result
268 }
269
270 async fn try_with_retry<F, Fut, T>(&self, operation: &F) -> RecoveryResult
272 where
273 F: Fn() -> Fut + Send + Sync,
274 Fut: std::future::Future<Output = Layer2Result<T>> + Send,
275 T: Send,
276 {
277 let mut last_error: Option<String> = None;
278
279 for attempt in 0..=self.retry_policy.max_retries {
280 match operation().await {
281 Ok(_) => {
282 return RecoveryResult {
283 success: true,
284 layer_used: RecoveryLayer::Automatic,
285 attempts: attempt,
286 error_message: None,
287 user_action: None,
288 };
289 }
290 Err(e) => {
291 let error_msg = e.to_string();
292 let category = ErrorCategory::from_error_message(&error_msg);
293
294 if !category.is_retryable() {
295 return RecoveryResult {
296 success: false,
297 layer_used: RecoveryLayer::Automatic,
298 attempts: attempt,
299 error_message: Some(error_msg.clone()),
300 user_action: Some(self.get_user_hint(&category)),
301 };
302 }
303
304 last_error = Some(error_msg);
305
306 if attempt < self.retry_policy.max_retries {
307 let delay = self.retry_policy.delay_for_attempt(attempt);
308 sleep(delay).await;
309 }
310 }
311 }
312 }
313
314 RecoveryResult {
315 success: false,
316 layer_used: RecoveryLayer::Automatic,
317 attempts: self.retry_policy.max_retries + 1,
318 error_message: last_error,
319 user_action: None,
320 }
321 }
322
323 async fn try_with_fallback<F, Fut, T>(&self, _operation: &F) -> RecoveryResult
325 where
326 F: Fn() -> Fut + Send + Sync,
327 Fut: std::future::Future<Output = Layer2Result<T>> + Send,
328 T: Send,
329 {
330 match &self.fallback_strategy {
331 FallbackStrategy::None => RecoveryResult {
332 success: false,
333 layer_used: RecoveryLayer::Fallback,
334 attempts: 0,
335 error_message: Some("No fallback strategy configured".to_string()),
336 user_action: None,
337 },
338 FallbackStrategy::Skip => RecoveryResult {
339 success: true,
340 layer_used: RecoveryLayer::Fallback,
341 attempts: 1,
342 error_message: None,
343 user_action: Some("Operation skipped due to fallback policy".to_string()),
344 },
345 FallbackStrategy::BackupService { endpoint } => {
346 RecoveryResult {
348 success: true,
349 layer_used: RecoveryLayer::Fallback,
350 attempts: 1,
351 error_message: None,
352 user_action: Some(format!("Switched to backup: {}", endpoint)),
353 }
354 }
355 FallbackStrategy::UseCache { max_age_seconds } => RecoveryResult {
356 success: true,
357 layer_used: RecoveryLayer::Fallback,
358 attempts: 1,
359 error_message: None,
360 user_action: Some(format!("Using cached data (max {}s old)", max_age_seconds)),
361 },
362 FallbackStrategy::Simplified { mode } => RecoveryResult {
363 success: true,
364 layer_used: RecoveryLayer::Fallback,
365 attempts: 1,
366 error_message: None,
367 user_action: Some(format!("Using simplified mode: {}", mode)),
368 },
369 }
370 }
371
372 async fn try_with_user_intervention<F, Fut, T>(&self, _operation: &F) -> RecoveryResult
374 where
375 F: Fn() -> Fut + Send + Sync,
376 Fut: std::future::Future<Output = Layer2Result<T>> + Send,
377 T: Send,
378 {
379 let callback = self.user_callback.read().await;
380
381 if let Some(cb) = callback.as_ref() {
382 let actions = vec![
383 RecoveryAction::Retry,
384 RecoveryAction::Skip,
385 RecoveryAction::Abort,
386 ];
387
388 let action = cb("Operation failed. Choose action:", actions);
389
390 match action {
391 RecoveryAction::Retry => RecoveryResult {
392 success: false, layer_used: RecoveryLayer::UserIntervention,
394 attempts: 1,
395 error_message: None,
396 user_action: Some("User requested retry".to_string()),
397 },
398 RecoveryAction::Skip => RecoveryResult {
399 success: true,
400 layer_used: RecoveryLayer::UserIntervention,
401 attempts: 1,
402 error_message: None,
403 user_action: Some("User chose to skip".to_string()),
404 },
405 RecoveryAction::Abort => RecoveryResult {
406 success: false,
407 layer_used: RecoveryLayer::UserIntervention,
408 attempts: 1,
409 error_message: Some("User aborted operation".to_string()),
410 user_action: Some("User aborted".to_string()),
411 },
412 _ => RecoveryResult {
413 success: false,
414 layer_used: RecoveryLayer::UserIntervention,
415 attempts: 1,
416 error_message: Some("Unknown action".to_string()),
417 user_action: None,
418 },
419 }
420 } else {
421 RecoveryResult {
422 success: false,
423 layer_used: RecoveryLayer::UserIntervention,
424 attempts: 0,
425 error_message: Some("No user callback set".to_string()),
426 user_action: Some("Please configure user callback for intervention".to_string()),
427 }
428 }
429 }
430
431 fn get_user_hint(&self, category: &ErrorCategory) -> String {
433 match category {
434 ErrorCategory::Configuration => "Check your API key and configuration".to_string(),
435 ErrorCategory::Logic => "Verify your input parameters".to_string(),
436 ErrorCategory::UserInterrupt => "Operation was cancelled".to_string(),
437 ErrorCategory::Transient => "Temporary issue, will retry automatically".to_string(),
438 ErrorCategory::Resource => {
439 "System resource issue, consider freeing up memory/disk".to_string()
440 }
441 ErrorCategory::System => "Unknown error occurred".to_string(),
442 }
443 }
444
445 pub async fn get_stats(&self) -> RecoveryStats {
447 self.stats.read().await.clone()
448 }
449}
450
/// Finds sessions that were left in the `Running` state, e.g. after a crash.
pub struct SessionRecovery {
    /// Directory whose sub-directories each hold one session's `state.json`.
    storage_path: std::path::PathBuf,
}
456
457impl SessionRecovery {
458 pub fn new(storage_path: impl AsRef<std::path::Path>) -> Self {
460 Self {
461 storage_path: storage_path.as_ref().to_path_buf(),
462 }
463 }
464
465 pub fn detect_interrupted_sessions(&self) -> Layer2Result<Vec<InterruptedSession>> {
467 let mut interrupted = Vec::new();
468
469 if !self.storage_path.exists() {
470 return Ok(interrupted);
471 }
472
473 for entry in std::fs::read_dir(&self.storage_path)? {
474 let entry = entry?;
475 let session_dir = entry.path();
476
477 if !session_dir.is_dir() {
478 continue;
479 }
480
481 let state_file = session_dir.join("state.json");
482 if state_file.exists() {
483 if let Ok(content) = std::fs::read_to_string(&state_file) {
484 if let Ok(state) = serde_json::from_str::<SessionState>(&content) {
485 if state.status == SessionStatus::Running && !state.completed {
486 interrupted.push(InterruptedSession {
487 session_id: state.session_id,
488 last_iteration: state.iteration,
489 last_activity: state.last_updated,
490 task_description: state.task_description,
491 });
492 }
493 }
494 }
495 }
496 }
497
498 interrupted.sort_by_key(|b| std::cmp::Reverse(b.last_activity));
500
501 Ok(interrupted)
502 }
503
504 pub fn render_interrupted(&self) -> String {
506 match self.detect_interrupted_sessions() {
507 Ok(sessions) => {
508 if sessions.is_empty() {
509 "No interrupted sessions found.".to_string()
510 } else {
511 let mut output =
512 format!("Found {} interrupted session(s):\n\n", sessions.len());
513 for (i, session) in sessions.iter().enumerate() {
514 output.push_str(&format!(
515 "{}. Session: {}\n Task: {}\n Iteration: {}\n Last activity: {}\n\n",
516 i + 1,
517 session.session_id,
518 session.task_description.as_deref().unwrap_or("Unknown"),
519 session.last_iteration,
520 session.last_activity.format("%Y-%m-%d %H:%M:%S")
521 ));
522 }
523 output.push_str("Use 'continuum session resume <id>' to continue.");
524 output
525 }
526 }
527 Err(e) => format!("Error detecting sessions: {}", e),
528 }
529 }
530}
531
/// Summary of a session found in the `Running` state by `SessionRecovery`.
#[derive(Debug, Clone)]
pub struct InterruptedSession {
    /// Identifier of the interrupted session.
    pub session_id: SessionId,
    /// The `iteration` value recorded in the session's state file.
    pub last_iteration: i32,
    /// The `last_updated` timestamp (UTC) recorded in the session's state file.
    pub last_activity: chrono::DateTime<chrono::Utc>,
    /// Task description from the state file, if one was recorded.
    pub task_description: Option<String>,
}
540
/// On-disk contents of a session's `state.json`, as read by `SessionRecovery`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SessionState {
    session_id: SessionId,
    status: SessionStatus,
    /// A `Running` session with this flag set is not treated as interrupted.
    completed: bool,
    iteration: i32,
    last_updated: chrono::DateTime<chrono::Utc>,
    task_description: Option<String>,
}
551
/// Lifecycle status stored in `state.json`.
///
/// There are no `#[serde(rename…)]` attributes, so serde's default representation
/// applies: variants are stored as their names (e.g. `"Running"`). Only `Running` is
/// inspected in this file.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
enum SessionStatus {
    Running,
    Paused,
    Completed,
    Error,
}
560
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_category_analysis() {
        let cat = ErrorCategory::from_error_message("network timeout");
        assert_eq!(cat, ErrorCategory::Transient);

        let cat = ErrorCategory::from_error_message("invalid parameter");
        assert_eq!(cat, ErrorCategory::Logic);
    }

    #[test]
    fn test_error_category_retryability() {
        assert!(ErrorCategory::Transient.is_retryable());
        assert!(ErrorCategory::Resource.is_retryable());
        assert!(!ErrorCategory::Configuration.is_retryable());
        assert!(!ErrorCategory::Logic.is_retryable());
        assert!(!ErrorCategory::System.is_retryable());
        assert!(!ErrorCategory::UserInterrupt.is_retryable());
    }

    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy::default();
        let delay = policy.delay_for_attempt(0);
        // Default base delay is 1000 ms with a jitter fraction of 0.1.
        assert!(delay.as_millis() >= 900);
        assert!(delay.as_millis() <= 1100);
    }

    #[test]
    fn test_retry_policy_max_delay() {
        let policy = RetryPolicy {
            max_delay_ms: 5000,
            ..Default::default()
        };
        let delay = policy.delay_for_attempt(10);
        // Jitter is applied after capping, so allow 10% above max_delay_ms.
        assert!(delay.as_millis() <= 5500);
    }

    #[tokio::test]
    async fn test_error_recovery_creation() {
        let recovery = ErrorRecovery::new();
        let stats = recovery.get_stats().await;
        assert_eq!(stats.total_errors, 0);
    }

    #[test]
    fn test_fallback_strategy() {
        let strategy = FallbackStrategy::Skip;
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(strategy, FallbackStrategy::Skip));
    }
}