Skip to main content

trustformers_models/
error_recovery.rs

1//! # Comprehensive Error Recovery Framework for TrustformeRS Models
2//!
3//! This module provides advanced error recovery mechanisms to ensure robust operation
4//! of transformer models under various failure conditions.
5//!
6//! ## Features
7//!
8//! - **Automatic Retry Strategies**: Configurable retry mechanisms with exponential backoff
9//! - **Fallback Execution**: Graceful degradation to simpler model variants
10//! - **State Persistence**: Save and restore model state during errors
11//! - **Memory Recovery**: Intelligent memory cleanup and reallocation
12//! - **Error Classification**: Smart categorization of errors for appropriate response
13//! - **Circuit Breaker Pattern**: Prevent cascade failures
14//! - **Checkpoint Management**: Automatic model checkpointing for recovery
15//! - **Performance Monitoring**: Track recovery effectiveness
16//!
17//! ## Usage
18//!
//! ```rust,ignore
20//! use trustformers_models::error_recovery::{
21//!     ErrorRecoveryManager, RecoveryConfig, RecoveryStrategy
22//! };
23//!
24//! let config = RecoveryConfig::default()
25//!     .with_max_retries(3)
26//!     .with_fallback_enabled(true);
27//!
28//! let mut manager = ErrorRecoveryManager::new(config);
29//!
30//! // Execute with automatic recovery
31//! let result = manager.execute_with_recovery(|| {
32//!     // Your model operation here
33//!     model.forward(&input)
34//! })?;
35//! ```
36
37use anyhow::{Error, Result};
38use serde::{Deserialize, Serialize};
39use std::collections::{HashMap, VecDeque};
40use std::sync::{Arc, Mutex};
41use std::time::{Duration, Instant, SystemTime};
42use uuid::Uuid;
43
/// Configuration for error recovery behavior.
///
/// Construct via [`RecoveryConfig::default`] and refine with the `with_*`
/// builder methods. All durations are plain integers (ms/s) so the config
/// stays trivially serializable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryConfig {
    /// Maximum number of retry attempts
    pub max_retries: usize,
    /// Base delay for exponential backoff (milliseconds)
    pub base_delay_ms: u64,
    /// Maximum delay between retries (milliseconds); caps the backoff curve
    pub max_delay_ms: u64,
    /// Exponential backoff multiplier applied per attempt
    pub backoff_multiplier: f64,
    /// Whether to enable fallback strategies
    pub enable_fallback: bool,
    /// Whether to enable automatic checkpointing
    pub enable_checkpointing: bool,
    /// Memory pressure threshold for cleanup (MB)
    pub memory_pressure_threshold_mb: f64,
    /// Circuit breaker failure threshold (consecutive failures before opening)
    pub circuit_breaker_threshold: usize,
    /// Circuit breaker timeout (seconds) before a half-open probe is allowed
    pub circuit_breaker_timeout_s: u64,
    /// Whether to enable performance monitoring
    pub enable_monitoring: bool,
    /// Maximum number of error history entries to keep
    pub max_error_history: usize,
}
70
71impl Default for RecoveryConfig {
72    fn default() -> Self {
73        Self {
74            max_retries: 3,
75            base_delay_ms: 100,
76            max_delay_ms: 30000,
77            backoff_multiplier: 2.0,
78            enable_fallback: true,
79            enable_checkpointing: true,
80            memory_pressure_threshold_mb: 1024.0,
81            circuit_breaker_threshold: 5,
82            circuit_breaker_timeout_s: 60,
83            enable_monitoring: true,
84            max_error_history: 1000,
85        }
86    }
87}
88
89impl RecoveryConfig {
90    /// Enable maximum retries
91    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
92        self.max_retries = max_retries;
93        self
94    }
95
96    /// Enable fallback strategies
97    pub fn with_fallback_enabled(mut self, enabled: bool) -> Self {
98        self.enable_fallback = enabled;
99        self
100    }
101
102    /// Set memory pressure threshold
103    pub fn with_memory_threshold(mut self, threshold_mb: f64) -> Self {
104        self.memory_pressure_threshold_mb = threshold_mb;
105        self
106    }
107}
108
/// Types of errors that can be recovered from.
///
/// Serves as the lookup key for per-category recovery strategies. Categories
/// are assigned by substring matching on the error message (see
/// `ErrorRecoveryManager::classify_error`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorCategory {
    /// Memory-related errors (OOM, allocation failures)
    Memory,
    /// Compute-related errors (CUDA, device failures)
    Compute,
    /// Network-related errors (distributed training)
    Network,
    /// Model-related errors (dimension mismatches, invalid states)
    Model,
    /// Data-related errors (corrupted inputs, invalid tensors)
    Data,
    /// Temporary resource unavailability
    Resource,
    /// Configuration or setup errors
    Configuration,
    /// Unknown or unclassified errors
    Unknown,
}
129
/// Recovery strategies for different error types.
///
/// Strategies are tried in order per category until one reports success
/// (see `ErrorRecoveryManager::attempt_recovery`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    /// Retry with exponential backoff
    Retry {
        max_attempts: usize,
        base_delay_ms: u64,
    },
    /// Fallback to alternative implementation (named, e.g. "cpu" or "local")
    Fallback { fallback_implementation: String },
    /// Reduce resource usage and retry; `reduction_factor` is the scale applied
    ResourceReduction { reduction_factor: f64 },
    /// Restart subsystem identified by `component`
    Restart { component: String },
    /// Clean memory and retry
    MemoryCleanup,
    /// Load from checkpoint ("latest" refers to the most recent one)
    CheckpointRestore { checkpoint_id: String },
    /// Graceful degradation into the named mode
    Degrade { degraded_mode: String },
    /// No recovery possible
    NoRecovery,
}
153
/// Record of a single error-recovery attempt, stored in the manager's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryAttempt {
    /// Unique identifier for this attempt.
    pub attempt_id: Uuid,
    /// Wall-clock time at which the attempt was recorded.
    pub timestamp: SystemTime,
    /// Category assigned to the triggering error.
    pub error_category: ErrorCategory,
    /// Strategy that was executed (or `NoRecovery` if none succeeded).
    pub strategy: RecoveryStrategy,
    /// Whether the strategy reported success.
    pub success: bool,
    /// Duration of the attempt in milliseconds (currently always recorded as 0).
    pub duration_ms: u64,
    /// Stringified error that triggered the attempt.
    pub error_message: String,
    /// Free-form key/value context (currently unused; always empty).
    pub context: HashMap<String, String>,
}
166
/// Circuit breaker state machine: Closed (normal) -> Open (rejecting) ->
/// HalfOpen (probing) -> Closed/Open.
#[derive(Debug, Clone, PartialEq)]
enum CircuitBreakerState {
    Closed,
    Open,
    HalfOpen,
}

/// Circuit breaker for preventing cascade failures.
///
/// After `failure_threshold` consecutive failures the breaker opens and
/// rejects calls until `timeout` has elapsed, at which point a single probe
/// call is allowed (half-open). A success closes it again.
#[derive(Debug)]
struct CircuitBreaker {
    state: CircuitBreakerState,
    failure_count: usize,
    last_failure_time: Option<Instant>,
    failure_threshold: usize,
    timeout: Duration,
}

impl CircuitBreaker {
    /// Create a closed breaker with the given threshold and open-state timeout.
    fn new(failure_threshold: usize, timeout: Duration) -> Self {
        Self {
            state: CircuitBreakerState::Closed,
            failure_count: 0,
            last_failure_time: None,
            failure_threshold,
            timeout,
        }
    }

    /// Returns whether a call may proceed, transitioning Open -> HalfOpen
    /// once the timeout since the last failure has elapsed.
    fn can_execute(&mut self) -> bool {
        match self.state {
            CircuitBreakerState::Closed | CircuitBreakerState::HalfOpen => true,
            CircuitBreakerState::Open => {
                // With no recorded failure time we cannot enforce the timeout,
                // so allow the call through (matches original behavior).
                let timed_out = self
                    .last_failure_time
                    .map_or(true, |t| t.elapsed() >= self.timeout);
                if timed_out && self.last_failure_time.is_some() {
                    self.state = CircuitBreakerState::HalfOpen;
                }
                timed_out
            },
        }
    }

    /// Record a success: clear the failure count and close the breaker.
    fn on_success(&mut self) {
        self.failure_count = 0;
        self.state = CircuitBreakerState::Closed;
    }

    /// Record a failure; opens the breaker once the threshold is reached.
    fn on_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());
        if self.failure_count >= self.failure_threshold {
            self.state = CircuitBreakerState::Open;
        }
    }
}
229
/// Model checkpoint for recovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCheckpoint {
    /// UUID string identifying this checkpoint.
    pub checkpoint_id: String,
    /// Creation time of the checkpoint.
    pub timestamp: SystemTime,
    /// Serialized tensors keyed by parameter name.
    pub model_state: HashMap<String, Vec<u8>>, // Serialized tensors
    /// Free-form metadata (e.g. step counters, model version).
    pub metadata: HashMap<String, String>,
    /// Total size of all serialized tensors, in bytes.
    pub size_bytes: usize,
}
239
240impl ModelCheckpoint {
241    /// Create a new checkpoint
242    pub fn new(model_state: HashMap<String, Vec<u8>>, metadata: HashMap<String, String>) -> Self {
243        let size_bytes = model_state.values().map(|v| v.len()).sum();
244
245        Self {
246            checkpoint_id: Uuid::new_v4().to_string(),
247            timestamp: SystemTime::now(),
248            model_state,
249            metadata,
250            size_bytes,
251        }
252    }
253}
254
/// Recovery performance metrics, shared behind the manager's mutex.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryMetrics {
    /// Total number of recovery attempts recorded.
    pub total_errors: usize,
    /// Attempts where a strategy reported success.
    pub successful_recoveries: usize,
    /// Attempts where no strategy succeeded.
    pub failed_recoveries: usize,
    /// Running average of recovery durations in milliseconds.
    pub average_recovery_time_ms: f64,
    /// successful_recoveries / total_errors.
    pub recovery_rate: f64,
    /// Errors per second over the manager's lifetime.
    pub error_frequency: f64,
    /// Attempt counts per error category.
    pub most_common_errors: HashMap<ErrorCategory, usize>,
    /// Effectiveness score per strategy name.
    /// NOTE(review): never populated by the current code — TODO wire up.
    pub most_effective_strategies: HashMap<String, f64>,
}
267
/// Main error recovery manager.
///
/// Owns the retry/recovery policy, per-operation circuit breakers, the
/// checkpoint store, and shared metrics.
pub struct ErrorRecoveryManager {
    // Policy knobs (retries, backoff, thresholds).
    config: RecoveryConfig,
    // Bounded FIFO of past recovery attempts (capped at config.max_error_history).
    error_history: VecDeque<RecoveryAttempt>,
    // One breaker per named operation.
    circuit_breakers: HashMap<String, CircuitBreaker>,
    // Checkpoints by id; "latest" aliases the most recent one.
    checkpoints: HashMap<String, ModelCheckpoint>,
    // Ordered strategy lists per error category.
    recovery_strategies: HashMap<ErrorCategory, Vec<RecoveryStrategy>>,
    // Shared so metrics could be read from other threads/handles.
    metrics: Arc<Mutex<RecoveryMetrics>>,
    // Creation time; used for uptime and error-frequency calculations.
    start_time: Instant,
}
278
279impl ErrorRecoveryManager {
280    /// Create a new error recovery manager
281    pub fn new(config: RecoveryConfig) -> Self {
282        let mut recovery_strategies = HashMap::new();
283
284        // Define default recovery strategies for each error category
285        recovery_strategies.insert(
286            ErrorCategory::Memory,
287            vec![
288                RecoveryStrategy::MemoryCleanup,
289                RecoveryStrategy::ResourceReduction {
290                    reduction_factor: 0.5,
291                },
292                RecoveryStrategy::CheckpointRestore {
293                    checkpoint_id: "latest".to_string(),
294                },
295            ],
296        );
297
298        recovery_strategies.insert(
299            ErrorCategory::Compute,
300            vec![
301                RecoveryStrategy::Retry {
302                    max_attempts: 3,
303                    base_delay_ms: 1000,
304                },
305                RecoveryStrategy::Fallback {
306                    fallback_implementation: "cpu".to_string(),
307                },
308                RecoveryStrategy::Restart {
309                    component: "compute_engine".to_string(),
310                },
311            ],
312        );
313
314        recovery_strategies.insert(
315            ErrorCategory::Network,
316            vec![
317                RecoveryStrategy::Retry {
318                    max_attempts: 5,
319                    base_delay_ms: 2000,
320                },
321                RecoveryStrategy::Fallback {
322                    fallback_implementation: "local".to_string(),
323                },
324            ],
325        );
326
327        recovery_strategies.insert(
328            ErrorCategory::Model,
329            vec![
330                RecoveryStrategy::CheckpointRestore {
331                    checkpoint_id: "latest".to_string(),
332                },
333                RecoveryStrategy::Degrade {
334                    degraded_mode: "simple".to_string(),
335                },
336                RecoveryStrategy::Restart {
337                    component: "model".to_string(),
338                },
339            ],
340        );
341
342        recovery_strategies.insert(
343            ErrorCategory::Data,
344            vec![
345                RecoveryStrategy::Retry {
346                    max_attempts: 2,
347                    base_delay_ms: 100,
348                },
349                RecoveryStrategy::Fallback {
350                    fallback_implementation: "default_data".to_string(),
351                },
352            ],
353        );
354
355        recovery_strategies.insert(
356            ErrorCategory::Resource,
357            vec![
358                RecoveryStrategy::Retry {
359                    max_attempts: 3,
360                    base_delay_ms: 5000,
361                },
362                RecoveryStrategy::ResourceReduction {
363                    reduction_factor: 0.7,
364                },
365            ],
366        );
367
368        Self {
369            config,
370            error_history: VecDeque::new(),
371            circuit_breakers: HashMap::new(),
372            checkpoints: HashMap::new(),
373            recovery_strategies,
374            metrics: Arc::new(Mutex::new(RecoveryMetrics {
375                total_errors: 0,
376                successful_recoveries: 0,
377                failed_recoveries: 0,
378                average_recovery_time_ms: 0.0,
379                recovery_rate: 0.0,
380                error_frequency: 0.0,
381                most_common_errors: HashMap::new(),
382                most_effective_strategies: HashMap::new(),
383            })),
384            start_time: Instant::now(),
385        }
386    }
387
388    /// Execute a function with automatic error recovery
389    pub fn execute_with_recovery<T, F>(&mut self, operation: F) -> Result<T>
390    where
391        F: Fn() -> Result<T>,
392    {
393        let operation_name = "default_operation";
394
395        // Check circuit breaker
396        if !self.get_or_create_circuit_breaker(operation_name).can_execute() {
397            return Err(anyhow::anyhow!(
398                "Circuit breaker is open for operation: {}",
399                operation_name
400            ));
401        }
402
403        let mut last_error = None;
404
405        for attempt in 0..=self.config.max_retries {
406            let start_time = Instant::now();
407
408            match operation() {
409                Ok(result) => {
410                    // Success - update circuit breaker and metrics
411                    self.get_or_create_circuit_breaker(operation_name).on_success();
412
413                    if attempt > 0 {
414                        // Record successful recovery
415                        self.record_successful_recovery(attempt, start_time);
416                    }
417
418                    return Ok(result);
419                },
420                Err(error) => {
421                    last_error = Some(anyhow::anyhow!(error.to_string()));
422
423                    // Classify error and attempt recovery
424                    let error_category = self.classify_error(&error);
425                    let recovery_success = self
426                        .attempt_recovery(&error, error_category.clone(), attempt)
427                        .unwrap_or(false);
428
429                    if !recovery_success && attempt == self.config.max_retries {
430                        // All recovery attempts failed
431                        self.get_or_create_circuit_breaker(operation_name).on_failure();
432                        self.record_failed_recovery(error_category, start_time, &error);
433                        break;
434                    }
435
436                    // Wait before retrying (exponential backoff)
437                    if attempt < self.config.max_retries {
438                        let delay = self.calculate_backoff_delay(attempt);
439                        std::thread::sleep(delay);
440                    }
441                },
442            }
443        }
444
445        // Return the last error if all attempts failed
446        Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Unknown error occurred")))
447    }
448
449    /// Classify an error into a category
450    fn classify_error(&self, error: &Error) -> ErrorCategory {
451        let error_string = error.to_string().to_lowercase();
452
453        if error_string.contains("memory")
454            || error_string.contains("oom")
455            || error_string.contains("allocation")
456        {
457            ErrorCategory::Memory
458        } else if error_string.contains("cuda")
459            || error_string.contains("gpu")
460            || error_string.contains("device")
461        {
462            ErrorCategory::Compute
463        } else if error_string.contains("network")
464            || error_string.contains("connection")
465            || error_string.contains("timeout")
466        {
467            ErrorCategory::Network
468        } else if error_string.contains("dimension")
469            || error_string.contains("shape")
470            || error_string.contains("tensor")
471        {
472            ErrorCategory::Model
473        } else if error_string.contains("data")
474            || error_string.contains("input")
475            || error_string.contains("corrupted")
476        {
477            ErrorCategory::Data
478        } else if error_string.contains("resource")
479            || error_string.contains("unavailable")
480            || error_string.contains("busy")
481        {
482            ErrorCategory::Resource
483        } else if error_string.contains("config")
484            || error_string.contains("setup")
485            || error_string.contains("initialization")
486        {
487            ErrorCategory::Configuration
488        } else {
489            ErrorCategory::Unknown
490        }
491    }
492
493    /// Attempt to recover from an error
494    fn attempt_recovery(
495        &mut self,
496        error: &Error,
497        category: ErrorCategory,
498        _attempt: usize,
499    ) -> Result<bool> {
500        let strategies = self.recovery_strategies.get(&category).cloned().unwrap_or_else(|| {
501            vec![RecoveryStrategy::Retry {
502                max_attempts: 1,
503                base_delay_ms: 1000,
504            }]
505        });
506
507        for strategy in strategies {
508            if self.execute_recovery_strategy(&strategy, error, &category)? {
509                self.record_recovery_attempt(category.clone(), strategy, true, error);
510                return Ok(true);
511            }
512        }
513
514        self.record_recovery_attempt(category, RecoveryStrategy::NoRecovery, false, error);
515        Ok(false)
516    }
517
518    /// Execute a specific recovery strategy
519    fn execute_recovery_strategy(
520        &mut self,
521        strategy: &RecoveryStrategy,
522        _error: &Error,
523        _category: &ErrorCategory,
524    ) -> Result<bool> {
525        match strategy {
526            RecoveryStrategy::Retry {
527                max_attempts: _,
528                base_delay_ms,
529            } => {
530                // Basic retry is handled by the main loop, just wait
531                std::thread::sleep(Duration::from_millis(*base_delay_ms));
532                Ok(true)
533            },
534
535            RecoveryStrategy::MemoryCleanup => {
536                self.perform_memory_cleanup()?;
537                Ok(true)
538            },
539
540            RecoveryStrategy::ResourceReduction { reduction_factor } => {
541                self.reduce_resource_usage(*reduction_factor)?;
542                Ok(true)
543            },
544
545            RecoveryStrategy::CheckpointRestore { checkpoint_id } => {
546                self.restore_from_checkpoint(checkpoint_id)
547            },
548
549            RecoveryStrategy::Fallback {
550                fallback_implementation,
551            } => {
552                self.switch_to_fallback(fallback_implementation)?;
553                Ok(true)
554            },
555
556            RecoveryStrategy::Restart { component } => {
557                self.restart_component(component)?;
558                Ok(true)
559            },
560
561            RecoveryStrategy::Degrade { degraded_mode } => {
562                self.enable_degraded_mode(degraded_mode)?;
563                Ok(true)
564            },
565
566            RecoveryStrategy::NoRecovery => Ok(false),
567        }
568    }
569
570    /// Perform memory cleanup
571    fn perform_memory_cleanup(&self) -> Result<()> {
572        // Force garbage collection if available
573        // Clear caches
574        // Compact memory
575        println!("[INFO] Performing memory cleanup");
576
577        // In a real implementation, this would:
578        // - Clear tensor caches
579        // - Force garbage collection
580        // - Compact memory allocations
581        // - Clear intermediate computations
582
583        Ok(())
584    }
585
586    /// Reduce resource usage
587    fn reduce_resource_usage(&self, reduction_factor: f64) -> Result<()> {
588        println!(
589            "[INFO] Reducing resource usage by factor: {}",
590            reduction_factor
591        );
592
593        // In a real implementation, this would:
594        // - Reduce batch sizes
595        // - Decrease model precision
596        // - Limit concurrent operations
597        // - Reduce cache sizes
598
599        Ok(())
600    }
601
602    /// Switch to fallback implementation
603    fn switch_to_fallback(&self, fallback: &str) -> Result<()> {
604        println!("[INFO] Switching to fallback implementation: {}", fallback);
605
606        // In a real implementation, this would:
607        // - Switch to CPU from GPU
608        // - Use simpler model architecture
609        // - Use alternative algorithms
610
611        Ok(())
612    }
613
614    /// Restart a component
615    fn restart_component(&self, component: &str) -> Result<()> {
616        println!("[INFO] Restarting component: {}", component);
617
618        // In a real implementation, this would:
619        // - Reinitialize the specified component
620        // - Clear component state
621        // - Reload configurations
622
623        Ok(())
624    }
625
626    /// Enable degraded mode
627    fn enable_degraded_mode(&self, mode: &str) -> Result<()> {
628        println!("[INFO] Enabling degraded mode: {}", mode);
629
630        // In a real implementation, this would:
631        // - Reduce functionality
632        // - Use simpler algorithms
633        // - Lower quality outputs
634
635        Ok(())
636    }
637
638    /// Restore from checkpoint
639    fn restore_from_checkpoint(&self, checkpoint_id: &str) -> Result<bool> {
640        if let Some(_checkpoint) = self.checkpoints.get(checkpoint_id) {
641            println!("[INFO] Restoring from checkpoint: {}", checkpoint_id);
642
643            // In a real implementation, this would:
644            // - Restore model weights from checkpoint
645            // - Restore optimizer state
646            // - Restore training state
647
648            Ok(true)
649        } else {
650            println!("[WARN] Checkpoint not found: {}", checkpoint_id);
651            Ok(false)
652        }
653    }
654
655    /// Create a model checkpoint
656    pub fn create_checkpoint(
657        &mut self,
658        model_state: HashMap<String, Vec<u8>>,
659        metadata: HashMap<String, String>,
660    ) -> String {
661        let checkpoint = ModelCheckpoint::new(model_state, metadata);
662        let checkpoint_id = checkpoint.checkpoint_id.clone();
663
664        self.checkpoints.insert(checkpoint_id.clone(), checkpoint);
665        self.checkpoints.insert(
666            "latest".to_string(),
667            self.checkpoints[&checkpoint_id].clone(),
668        );
669
670        // Limit number of checkpoints
671        if self.checkpoints.len() > 10 {
672            // Remove oldest checkpoints (simplified)
673            let keys_to_remove: Vec<String> = self.checkpoints.keys()
674                .filter(|k| *k != "latest")
675                .skip(9) // Keep 9 + "latest"
676                .cloned()
677                .collect();
678
679            for key in keys_to_remove {
680                self.checkpoints.remove(&key);
681            }
682        }
683
684        println!("[INFO] Created checkpoint: {}", checkpoint_id);
685        checkpoint_id
686    }
687
688    /// Calculate exponential backoff delay
689    fn calculate_backoff_delay(&self, attempt: usize) -> Duration {
690        let delay_ms =
691            self.config.base_delay_ms as f64 * self.config.backoff_multiplier.powi(attempt as i32);
692        let delay_ms = delay_ms.min(self.config.max_delay_ms as f64) as u64;
693        Duration::from_millis(delay_ms)
694    }
695
696    /// Get or create circuit breaker for an operation
697    fn get_or_create_circuit_breaker(&mut self, operation: &str) -> &mut CircuitBreaker {
698        self.circuit_breakers.entry(operation.to_string()).or_insert_with(|| {
699            CircuitBreaker::new(
700                self.config.circuit_breaker_threshold,
701                Duration::from_secs(self.config.circuit_breaker_timeout_s),
702            )
703        })
704    }
705
706    /// Record a recovery attempt
707    fn record_recovery_attempt(
708        &mut self,
709        category: ErrorCategory,
710        strategy: RecoveryStrategy,
711        success: bool,
712        error: &Error,
713    ) {
714        let attempt = RecoveryAttempt {
715            attempt_id: Uuid::new_v4(),
716            timestamp: SystemTime::now(),
717            error_category: category.clone(),
718            strategy: strategy.clone(),
719            success,
720            duration_ms: 0, // Would be calculated in real implementation
721            error_message: error.to_string(),
722            context: HashMap::new(),
723        };
724
725        self.error_history.push_back(attempt);
726
727        // Limit history size
728        while self.error_history.len() > self.config.max_error_history {
729            self.error_history.pop_front();
730        }
731
732        // Update metrics
733        if let Ok(mut metrics) = self.metrics.lock() {
734            metrics.total_errors += 1;
735            if success {
736                metrics.successful_recoveries += 1;
737            } else {
738                metrics.failed_recoveries += 1;
739            }
740
741            metrics.recovery_rate =
742                metrics.successful_recoveries as f64 / metrics.total_errors as f64;
743
744            let count = metrics.most_common_errors.entry(category).or_insert(0);
745            *count += 1;
746        }
747    }
748
749    /// Record successful recovery
750    fn record_successful_recovery(&mut self, _attempts: usize, start_time: Instant) {
751        if let Ok(mut metrics) = self.metrics.lock() {
752            let duration = start_time.elapsed().as_millis() as f64;
753            let total_recoveries = metrics.successful_recoveries + metrics.failed_recoveries;
754
755            if total_recoveries > 0 {
756                metrics.average_recovery_time_ms =
757                    (metrics.average_recovery_time_ms * total_recoveries as f64 + duration)
758                        / (total_recoveries + 1) as f64;
759            } else {
760                metrics.average_recovery_time_ms = duration;
761            }
762        }
763    }
764
765    /// Record failed recovery
766    fn record_failed_recovery(
767        &mut self,
768        category: ErrorCategory,
769        _start_time: Instant,
770        error: &Error,
771    ) {
772        self.record_recovery_attempt(category, RecoveryStrategy::NoRecovery, false, error);
773    }
774
775    /// Get current recovery metrics
776    pub fn get_metrics(&self) -> RecoveryMetrics {
777        self.metrics.lock().expect("operation failed").clone()
778    }
779
780    /// Generate recovery report
781    pub fn generate_recovery_report(&self) -> RecoveryReport {
782        let metrics = self.get_metrics();
783        let uptime = self.start_time.elapsed();
784
785        let recent_errors: Vec<_> = self.error_history.iter().rev().take(10).cloned().collect();
786
787        let error_trends = self.analyze_error_trends();
788        let recommendations = self.generate_recommendations(&metrics, &error_trends);
789
790        RecoveryReport {
791            timestamp: SystemTime::now(),
792            uptime,
793            metrics,
794            recent_errors,
795            error_trends,
796            recommendations,
797            circuit_breaker_states: self.get_circuit_breaker_states(),
798            checkpoint_count: self.checkpoints.len(),
799        }
800    }
801
802    /// Analyze error trends
803    fn analyze_error_trends(&self) -> ErrorTrends {
804        let now = SystemTime::now();
805        let one_hour_ago = now.checked_sub(Duration::from_secs(3600)).unwrap_or(now);
806
807        let recent_errors: Vec<_> = self
808            .error_history
809            .iter()
810            .filter(|attempt| attempt.timestamp >= one_hour_ago)
811            .collect();
812
813        let error_rate = recent_errors.len() as f64 / 3600.0; // errors per second
814        let recovery_success_rate = if !recent_errors.is_empty() {
815            recent_errors.iter().filter(|a| a.success).count() as f64 / recent_errors.len() as f64
816        } else {
817            1.0
818        };
819
820        let trending_up = recent_errors.len() > self.error_history.len() / 2;
821
822        ErrorTrends {
823            error_rate,
824            recovery_success_rate,
825            trending_up,
826            most_frequent_category: self.get_most_frequent_error_category(&recent_errors),
827        }
828    }
829
830    /// Get most frequent error category
831    fn get_most_frequent_error_category(
832        &self,
833        errors: &[&RecoveryAttempt],
834    ) -> Option<ErrorCategory> {
835        let mut category_counts = HashMap::new();
836
837        for error in errors {
838            let count = category_counts.entry(error.error_category.clone()).or_insert(0);
839            *count += 1;
840        }
841
842        category_counts
843            .into_iter()
844            .max_by_key(|(_, count)| *count)
845            .map(|(category, _)| category)
846    }
847
848    /// Generate recommendations based on metrics and trends
849    fn generate_recommendations(
850        &self,
851        metrics: &RecoveryMetrics,
852        trends: &ErrorTrends,
853    ) -> Vec<String> {
854        let mut recommendations = Vec::new();
855
856        if metrics.recovery_rate < 0.8 {
857            recommendations
858                .push("Recovery rate is low. Consider reviewing recovery strategies.".to_string());
859        }
860
861        if trends.error_rate > 0.1 {
862            recommendations.push("High error rate detected. Investigate root causes.".to_string());
863        }
864
865        if trends.trending_up {
866            recommendations
867                .push("Error frequency is increasing. Monitor system closely.".to_string());
868        }
869
870        if metrics.average_recovery_time_ms > 5000.0 {
871            recommendations
872                .push("Recovery time is high. Optimize recovery strategies.".to_string());
873        }
874
875        if let Some(category) = &trends.most_frequent_category {
876            recommendations.push(format!(
877                "Most frequent error category: {:?}. Focus optimization efforts here.",
878                category
879            ));
880        }
881
882        if recommendations.is_empty() {
883            recommendations.push("Error recovery system is operating normally.".to_string());
884        }
885
886        recommendations
887    }
888
889    /// Get circuit breaker states
890    fn get_circuit_breaker_states(&self) -> HashMap<String, String> {
891        self.circuit_breakers
892            .iter()
893            .map(|(name, breaker)| {
894                let state = match breaker.state {
895                    CircuitBreakerState::Closed => "CLOSED",
896                    CircuitBreakerState::Open => "OPEN",
897                    CircuitBreakerState::HalfOpen => "HALF_OPEN",
898                };
899                (name.clone(), state.to_string())
900            })
901            .collect()
902    }
903}
904
/// Error trend analysis over the last hour of recorded history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorTrends {
    /// Recent errors per second (count over the last hour / 3600).
    pub error_rate: f64,
    /// Fraction of recent attempts that succeeded (1.0 when there are none).
    pub recovery_success_rate: f64,
    /// Heuristic flag: more than half of all history fell in the last hour.
    pub trending_up: bool,
    /// Dominant error category among recent attempts, if any.
    pub most_frequent_category: Option<ErrorCategory>,
}
913
/// Comprehensive recovery report produced by
/// `ErrorRecoveryManager::generate_recovery_report`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RecoveryReport {
    /// When the report was generated.
    pub timestamp: SystemTime,
    /// Time since the manager was created.
    pub uptime: Duration,
    /// Snapshot of aggregate recovery metrics.
    pub metrics: RecoveryMetrics,
    /// Up to the 10 most recent recovery attempts, newest first.
    pub recent_errors: Vec<RecoveryAttempt>,
    /// Trend analysis over the last hour.
    pub error_trends: ErrorTrends,
    /// Human-readable suggestions derived from metrics and trends.
    pub recommendations: Vec<String>,
    /// Circuit breaker state per operation name ("CLOSED"/"OPEN"/"HALF_OPEN").
    pub circuit_breaker_states: HashMap<String, String>,
    /// Number of stored checkpoints (including the "latest" alias).
    pub checkpoint_count: usize,
}
926
/// Convenience trait for adding recovery capabilities to any operation.
///
/// Blanket-implemented for every `Fn() -> Result<T>`, so any such closure can
/// be written as `op.with_recovery(&mut manager)`.
pub trait RecoverableOperation<T> {
    /// Run `self` under the manager's retry/recovery policy.
    fn with_recovery(self, manager: &mut ErrorRecoveryManager) -> Result<T>;
}

impl<T, F> RecoverableOperation<T> for F
where
    F: Fn() -> Result<T>,
{
    // Thin adapter: delegates directly to the manager.
    fn with_recovery(self, manager: &mut ErrorRecoveryManager) -> Result<T> {
        manager.execute_with_recovery(self)
    }
}
940
#[cfg(test)]
mod tests {
    use super::*;

    // Keyword-based classification maps messages to the expected categories.
    #[test]
    fn test_error_classification() {
        let manager = ErrorRecoveryManager::new(RecoveryConfig::default());

        let memory_error = anyhow::anyhow!("Out of memory error occurred");
        assert_eq!(manager.classify_error(&memory_error), ErrorCategory::Memory);

        let cuda_error = anyhow::anyhow!("CUDA device error");
        assert_eq!(manager.classify_error(&cuda_error), ErrorCategory::Compute);
    }

    // Breaker opens at the failure threshold and closes again on success.
    #[test]
    fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_secs(1));

        assert!(breaker.can_execute());

        breaker.on_failure();
        assert!(breaker.can_execute());

        breaker.on_failure();
        assert!(!breaker.can_execute()); // Should be open now

        breaker.on_success();
        assert!(breaker.can_execute()); // Should be closed again
    }

    // Backoff delays are strictly increasing across attempts (below the cap).
    #[test]
    fn test_backoff_calculation() {
        let config = RecoveryConfig::default();
        let manager = ErrorRecoveryManager::new(config);

        let delay0 = manager.calculate_backoff_delay(0);
        let delay1 = manager.calculate_backoff_delay(1);
        let delay2 = manager.calculate_backoff_delay(2);

        assert!(delay1 > delay0);
        assert!(delay2 > delay1);
    }

    // Builder methods override the corresponding defaults.
    #[test]
    fn test_recovery_config_builder() {
        let config = RecoveryConfig::default()
            .with_max_retries(5)
            .with_fallback_enabled(false)
            .with_memory_threshold(2048.0);

        assert_eq!(config.max_retries, 5);
        assert!(!config.enable_fallback);
        assert_eq!(config.memory_pressure_threshold_mb, 2048.0);
    }
}