Skip to main content

trustformers_training/
error_handling.rs

1use anyhow::{Context as AnyhowContext, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::{Arc, RwLock};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use crate::error_codes::{get_error_info, is_critical_error};
9
10/// Enhanced error handling system for training
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TrainingError {
13    pub error_type: ErrorType,
14    pub message: String,
15    pub error_code: String,
16    pub severity: ErrorSeverity,
17    pub context: ErrorContext,
18    pub timestamp: u64,
19    pub recovery_suggestions: Vec<RecoverySuggestion>,
20    pub related_errors: Vec<String>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, Eq, Hash, PartialEq)]
24pub enum ErrorType {
25    Configuration,
26    DataLoading,
27    ModelInitialization,
28    Training,
29    Validation,
30    Checkpoint,
31    Resource,
32    Network,
33    Hardware,
34    UserInput,
35    Internal,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, Eq, Hash, PartialEq)]
39pub enum ErrorSeverity {
40    Critical, // Training cannot continue
41    High,     // Training can continue with reduced functionality
42    Medium,   // Warning that may affect performance
43    Low,      // Informational, no impact on training
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ErrorContext {
48    pub component: String,
49    pub operation: String,
50    pub epoch: Option<u32>,
51    pub step: Option<u32>,
52    pub batch_size: Option<usize>,
53    pub learning_rate: Option<f64>,
54    pub model_state: Option<String>,
55    pub system_info: SystemInfo,
56    pub additional_data: HashMap<String, String>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct SystemInfo {
61    pub memory_usage: Option<u64>,
62    pub gpu_memory_usage: Option<u64>,
63    pub cpu_usage: Option<f32>,
64    pub disk_space: Option<u64>,
65    pub network_status: Option<String>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct RecoverySuggestion {
70    pub action: String,
71    pub description: String,
72    pub priority: u8,    // 1-10, where 10 is highest priority
73    pub automatic: bool, // Whether this can be applied automatically
74}
75
76impl fmt::Display for TrainingError {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        write!(
79            f,
80            "[{}] {}: {} ({})",
81            self.severity, self.error_type, self.message, self.error_code
82        )
83    }
84}
85
86impl fmt::Display for ErrorType {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            ErrorType::Configuration => write!(f, "CONFIGURATION"),
90            ErrorType::DataLoading => write!(f, "DATA_LOADING"),
91            ErrorType::ModelInitialization => write!(f, "MODEL_INIT"),
92            ErrorType::Training => write!(f, "TRAINING"),
93            ErrorType::Validation => write!(f, "VALIDATION"),
94            ErrorType::Checkpoint => write!(f, "CHECKPOINT"),
95            ErrorType::Resource => write!(f, "RESOURCE"),
96            ErrorType::Network => write!(f, "NETWORK"),
97            ErrorType::Hardware => write!(f, "HARDWARE"),
98            ErrorType::UserInput => write!(f, "USER_INPUT"),
99            ErrorType::Internal => write!(f, "INTERNAL"),
100        }
101    }
102}
103
104impl fmt::Display for ErrorSeverity {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        match self {
107            ErrorSeverity::Critical => write!(f, "CRITICAL"),
108            ErrorSeverity::High => write!(f, "HIGH"),
109            ErrorSeverity::Medium => write!(f, "MEDIUM"),
110            ErrorSeverity::Low => write!(f, "LOW"),
111        }
112    }
113}
114
115impl std::error::Error for TrainingError {}
116
117/// Error manager for collecting, analyzing, and handling training errors
118pub struct ErrorManager {
119    errors: Arc<RwLock<Vec<TrainingError>>>,
120    error_patterns: Arc<RwLock<HashMap<String, ErrorPattern>>>,
121    recovery_strategies: Arc<RwLock<HashMap<ErrorType, Vec<RecoveryStrategy>>>>,
122    statistics: Arc<RwLock<ErrorStatistics>>,
123}
124
125#[derive(Debug, Clone)]
126pub struct ErrorPattern {
127    pub pattern_id: String,
128    pub error_codes: Vec<String>,
129    pub frequency_threshold: u32,
130    pub time_window_seconds: u64,
131    pub suggested_actions: Vec<RecoverySuggestion>,
132}
133
134#[derive(Debug, Clone)]
135pub struct RecoveryStrategy {
136    pub strategy_id: String,
137    pub name: String,
138    pub applicable_errors: Vec<String>,
139    pub handler: fn(&TrainingError) -> Result<RecoveryAction>,
140    pub auto_apply: bool,
141}
142
/// The action a recovery strategy asks the manager to take.
///
/// `ErrorManager` currently only logs these actions; `Abort` is reported as an
/// execution failure.
#[derive(Clone, Debug)]
pub enum RecoveryAction {
    /// Keep training as if nothing happened.
    Continue,
    /// Retry the failed operation up to `max_attempts` times.
    Retry { max_attempts: u32 },
    /// Restart training, optionally from the named checkpoint.
    Restart { checkpoint: Option<String> },
    /// Stop training.
    Abort,
    /// Scale resource usage down by `factor`.
    ReduceResources { factor: f32 },
    /// Apply the given key/value configuration overrides.
    ChangeConfiguration { config_changes: HashMap<String, String> },
    /// Switch to the named fallback configuration.
    SwitchFallback { fallback_config: String },
}
163
164#[derive(Debug, Default, Clone)]
165pub struct ErrorStatistics {
166    pub total_errors: u64,
167    pub errors_by_type: HashMap<ErrorType, u64>,
168    pub errors_by_severity: HashMap<ErrorSeverity, u64>,
169    pub errors_by_component: HashMap<String, u64>,
170    pub recovery_success_rate: f64,
171    pub most_common_errors: Vec<(String, u64)>,
172    pub error_trends: Vec<ErrorTrend>,
173}
174
/// One sample of the error-rate history kept in `ErrorStatistics::error_trends`.
///
/// A sample is taken each time an error is recorded.
#[derive(Debug, Clone)]
pub struct ErrorTrend {
    /// Time the sample was taken, in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Number of recorded errors whose timestamp lies within the 60 seconds before
    /// `timestamp` (inclusive), counting the error that triggered the sample.
    pub error_count: u64,
    /// Average errors per second over that 60-second window
    /// (`error_count / 60`); multiply by 60 for a per-minute rate.
    pub error_rate: f64,
}
181
182impl Default for ErrorManager {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl ErrorManager {
189    pub fn new() -> Self {
190        Self {
191            errors: Arc::new(RwLock::new(Vec::new())),
192            error_patterns: Arc::new(RwLock::new(HashMap::new())),
193            recovery_strategies: Arc::new(RwLock::new(HashMap::new())),
194            statistics: Arc::new(RwLock::new(ErrorStatistics::default())),
195        }
196    }
197
198    /// Record a new error
199    pub fn record_error(&self, error: TrainingError) -> Result<()> {
200        // Add to error log
201        {
202            let mut errors = self
203                .errors
204                .write()
205                .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on errors"))?;
206            errors.push(error.clone());
207        }
208
209        // Update statistics
210        self.update_statistics(&error)?;
211
212        // Check for error patterns
213        self.check_error_patterns(&error)?;
214
215        // Attempt automatic recovery if applicable
216        self.attempt_recovery(&error)?;
217
218        Ok(())
219    }
220
221    /// Create and record an error with context
222    pub fn create_error(
223        &self,
224        error_type: ErrorType,
225        message: String,
226        error_code: String,
227        severity: ErrorSeverity,
228        context: ErrorContext,
229    ) -> TrainingError {
230        let error = TrainingError {
231            error_type: error_type.clone(),
232            message,
233            error_code: error_code.clone(),
234            severity: severity.clone(),
235            context,
236            timestamp: SystemTime::now()
237                .duration_since(UNIX_EPOCH)
238                .expect("SystemTime should be after UNIX_EPOCH")
239                .as_secs(),
240            recovery_suggestions: self.get_recovery_suggestions(&error_type, &error_code),
241            related_errors: Vec::new(),
242        };
243
244        if let Err(e) = self.record_error(error.clone()) {
245            eprintln!("Failed to record error: {}", e);
246        }
247
248        error
249    }
250
251    /// Add an error pattern for detection
252    pub fn add_error_pattern(&self, pattern: ErrorPattern) -> Result<()> {
253        let mut patterns = self
254            .error_patterns
255            .write()
256            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on error patterns"))?;
257        patterns.insert(pattern.pattern_id.clone(), pattern);
258        Ok(())
259    }
260
261    /// Add a recovery strategy
262    pub fn add_recovery_strategy(
263        &self,
264        error_type: ErrorType,
265        strategy: RecoveryStrategy,
266    ) -> Result<()> {
267        let mut strategies = self
268            .recovery_strategies
269            .write()
270            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on recovery strategies"))?;
271        strategies.entry(error_type).or_insert_with(Vec::new).push(strategy);
272        Ok(())
273    }
274
275    fn update_statistics(&self, error: &TrainingError) -> Result<()> {
276        let mut stats = self
277            .statistics
278            .write()
279            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
280
281        stats.total_errors += 1;
282        *stats.errors_by_type.entry(error.error_type.clone()).or_insert(0) += 1;
283        *stats.errors_by_severity.entry(error.severity.clone()).or_insert(0) += 1;
284        *stats.errors_by_component.entry(error.context.component.clone()).or_insert(0) += 1;
285
286        // Update trends
287        let current_time = SystemTime::now()
288            .duration_since(UNIX_EPOCH)
289            .expect("SystemTime should be after UNIX_EPOCH")
290            .as_secs();
291
292        // Calculate error rate for the last minute
293        let errors = self
294            .errors
295            .read()
296            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
297        let recent_errors =
298            errors.iter().filter(|e| current_time - e.timestamp <= 60).count() as u64;
299
300        stats.error_trends.push(ErrorTrend {
301            timestamp: current_time,
302            error_count: recent_errors,
303            error_rate: recent_errors as f64 / 60.0,
304        });
305
306        // Keep only last 100 trend points
307        if stats.error_trends.len() > 100 {
308            stats.error_trends.remove(0);
309        }
310
311        Ok(())
312    }
313
314    fn check_error_patterns(&self, error: &TrainingError) -> Result<()> {
315        let patterns = self
316            .error_patterns
317            .read()
318            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on error patterns"))?;
319
320        for pattern in patterns.values() {
321            if pattern.error_codes.contains(&error.error_code) {
322                // Check if this pattern has occurred frequently
323                let recent_matching_errors = self.count_recent_matching_errors(pattern)?;
324
325                if recent_matching_errors >= pattern.frequency_threshold {
326                    println!(
327                        "๐Ÿšจ Error pattern detected: {} (occurred {} times)",
328                        pattern.pattern_id, recent_matching_errors
329                    );
330
331                    // Apply suggested actions
332                    for suggestion in &pattern.suggested_actions {
333                        println!(
334                            "๐Ÿ’ก Suggestion: {} - {}",
335                            suggestion.action, suggestion.description
336                        );
337                    }
338                }
339            }
340        }
341
342        Ok(())
343    }
344
345    fn count_recent_matching_errors(&self, pattern: &ErrorPattern) -> Result<u32> {
346        let errors = self
347            .errors
348            .read()
349            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
350
351        let current_time = SystemTime::now()
352            .duration_since(UNIX_EPOCH)
353            .expect("SystemTime should be after UNIX_EPOCH")
354            .as_secs();
355
356        let count = errors
357            .iter()
358            .filter(|error| {
359                current_time - error.timestamp <= pattern.time_window_seconds
360                    && pattern.error_codes.contains(&error.error_code)
361            })
362            .count() as u32;
363
364        Ok(count)
365    }
366
367    fn attempt_recovery(&self, error: &TrainingError) -> Result<()> {
368        // Check if this is a critical error that requires immediate attention
369        if is_critical_error(&error.error_code) {
370            println!(
371                "๐Ÿšจ Critical error detected: {} - Manual intervention required",
372                error.error_code
373            );
374            return Ok(()); // Don't attempt automatic recovery for critical errors
375        }
376
377        let strategies = self
378            .recovery_strategies
379            .read()
380            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on recovery strategies"))?;
381
382        // Try built-in recovery strategies first
383        if let Some(success) = self.try_builtin_recovery(error)? {
384            if success {
385                println!(
386                    "โœ… Built-in recovery successful for error: {}",
387                    error.error_code
388                );
389                return Ok(());
390            }
391        }
392
393        // Try custom recovery strategies
394        if let Some(type_strategies) = strategies.get(&error.error_type) {
395            for strategy in type_strategies {
396                if strategy.auto_apply && strategy.applicable_errors.contains(&error.error_code) {
397                    println!("๐Ÿ”ง Attempting automatic recovery: {}", strategy.name);
398
399                    match (strategy.handler)(error) {
400                        Ok(action) => {
401                            println!("โœ… Recovery action determined: {:?}", action);
402
403                            // Execute the recovery action
404                            if let Err(e) = self.execute_recovery_action(&action, error) {
405                                println!("โŒ Failed to execute recovery action: {}", e);
406                                continue;
407                            }
408
409                            println!("โœ… Recovery action executed successfully");
410                            return Ok(());
411                        },
412                        Err(e) => {
413                            println!("โŒ Recovery strategy failed: {}", e);
414                        },
415                    }
416                }
417            }
418        }
419
420        // If no automatic recovery worked, suggest manual recovery
421        self.suggest_manual_recovery(error);
422        Ok(())
423    }
424
425    /// Try built-in recovery strategies based on error code
426    fn try_builtin_recovery(&self, error: &TrainingError) -> Result<Option<bool>> {
427        match error.error_code.as_str() {
428            "RESOURCE_OOM" | "RESOURCE_GPU_OOM" => {
429                println!("๐Ÿ”ง Attempting memory recovery for OOM error");
430                // Simulate memory cleanup
431                self.simulate_memory_cleanup()?;
432                Ok(Some(true))
433            },
434            "TRAIN_NAN_LOSS" | "TRAIN_INF_LOSS" => {
435                println!("๐Ÿ”ง Attempting numerical stability recovery");
436                // Suggest lower learning rate and gradient clipping
437                self.suggest_numerical_fixes(error)?;
438                Ok(Some(false)) // Don't automatically apply, just suggest
439            },
440            "DATA_FILE_NOT_FOUND" => {
441                println!("๐Ÿ”ง Attempting data path recovery");
442                // Try to find alternative data paths
443                self.suggest_data_path_fixes(error)?;
444                Ok(Some(false))
445            },
446            "NETWORK_CONNECTION_TIMEOUT" => {
447                println!("๐Ÿ”ง Attempting network recovery");
448                // Try retry with exponential backoff
449                self.attempt_network_retry(error)?;
450                Ok(Some(true))
451            },
452            _ => Ok(None), // No built-in recovery for this error code
453        }
454    }
455
456    /// Execute a recovery action
457    fn execute_recovery_action(
458        &self,
459        action: &RecoveryAction,
460        _error: &TrainingError,
461    ) -> Result<()> {
462        match action {
463            RecoveryAction::Continue => {
464                println!("๐Ÿ“ Recovery action: Continue training");
465                Ok(())
466            },
467            RecoveryAction::Retry { max_attempts } => {
468                println!(
469                    "๐Ÿ“ Recovery action: Retry operation (max {} attempts)",
470                    max_attempts
471                );
472                // In a real implementation, would retry the failed operation
473                Ok(())
474            },
475            RecoveryAction::Restart { checkpoint } => {
476                println!(
477                    "๐Ÿ“ Recovery action: Restart from checkpoint: {:?}",
478                    checkpoint
479                );
480                // In a real implementation, would restart training from checkpoint
481                Ok(())
482            },
483            RecoveryAction::Abort => {
484                println!("๐Ÿ“ Recovery action: Abort training");
485                Err(anyhow::anyhow!(
486                    "Training aborted due to unrecoverable error"
487                ))
488            },
489            RecoveryAction::ReduceResources { factor } => {
490                println!("๐Ÿ“ Recovery action: Reduce resources by factor {}", factor);
491                // In a real implementation, would reduce batch size, model size, etc.
492                Ok(())
493            },
494            RecoveryAction::ChangeConfiguration { config_changes } => {
495                println!(
496                    "๐Ÿ“ Recovery action: Change configuration: {:?}",
497                    config_changes
498                );
499                // In a real implementation, would apply configuration changes
500                Ok(())
501            },
502            RecoveryAction::SwitchFallback { fallback_config } => {
503                println!(
504                    "๐Ÿ“ Recovery action: Switch to fallback configuration: {}",
505                    fallback_config
506                );
507                // In a real implementation, would switch to fallback configuration
508                Ok(())
509            },
510        }
511    }
512
513    /// Simulate memory cleanup for OOM errors
514    fn simulate_memory_cleanup(&self) -> Result<()> {
515        println!("๐Ÿงน Simulating memory cleanup...");
516        println!("  - Clearing unused tensors");
517        println!("  - Running garbage collection");
518        println!("  - Reducing batch size temporarily");
519        Ok(())
520    }
521
522    /// Suggest numerical stability fixes
523    fn suggest_numerical_fixes(&self, error: &TrainingError) -> Result<()> {
524        println!("๐Ÿ’ก Numerical stability suggestions:");
525        println!("  - Reduce learning rate by factor of 10");
526        println!("  - Enable gradient clipping (max_norm=1.0)");
527        println!("  - Check input data normalization");
528        println!("  - Consider using mixed precision training");
529
530        if let Some(lr) = error.context.learning_rate {
531            println!("  - Current learning rate: {}, suggested: {}", lr, lr * 0.1);
532        }
533
534        Ok(())
535    }
536
537    /// Suggest data path fixes
538    fn suggest_data_path_fixes(&self, _error: &TrainingError) -> Result<()> {
539        println!("๐Ÿ’ก Data path suggestions:");
540        println!("  - Check if file path is correct");
541        println!("  - Verify file permissions");
542        println!("  - Try relative vs absolute paths");
543        println!("  - Check if data is in expected location");
544        Ok(())
545    }
546
547    /// Attempt network retry with exponential backoff
548    fn attempt_network_retry(&self, _error: &TrainingError) -> Result<()> {
549        println!("๐Ÿ”„ Attempting network retry with exponential backoff...");
550
551        for attempt in 1..=3 {
552            println!("  Attempt {}/3", attempt);
553
554            // Simulate network operation
555            std::thread::sleep(std::time::Duration::from_millis(100 * (1 << attempt)));
556
557            // Simulate random success/failure
558            if fastrand::bool() {
559                println!("  โœ… Network operation succeeded");
560                return Ok(());
561            }
562
563            println!("  โŒ Network operation failed, retrying...");
564        }
565
566        Err(anyhow::anyhow!("Network operation failed after 3 attempts"))
567    }
568
569    /// Suggest manual recovery steps
570    fn suggest_manual_recovery(&self, error: &TrainingError) {
571        println!("๐Ÿ”ง Manual recovery suggestions for {}:", error.error_code);
572
573        for suggestion in &error.recovery_suggestions {
574            println!(
575                "  {} (Priority: {}) - {}",
576                if suggestion.automatic { "๐Ÿค– AUTO" } else { "๐Ÿ‘ค MANUAL" },
577                suggestion.priority,
578                suggestion.action
579            );
580            println!("    ๐Ÿ“ {}", suggestion.description);
581        }
582
583        // Additional context-specific suggestions
584        if let Some(epoch) = error.context.epoch {
585            println!("  ๐Ÿ“Š Error occurred at epoch {}", epoch);
586            if epoch < 5 {
587                println!(
588                    "    ๐Ÿ’ก Early training failure - check data loading and model initialization"
589                );
590            }
591        }
592
593        if let Some(step) = error.context.step {
594            println!("  ๐Ÿ“Š Error occurred at step {}", step);
595        }
596    }
597
598    fn get_recovery_suggestions(
599        &self,
600        error_type: &ErrorType,
601        error_code: &str,
602    ) -> Vec<RecoverySuggestion> {
603        // First, try to get suggestions from the error code registry
604        if let Some(error_info) = get_error_info(error_code) {
605            return error_info
606                .solutions
607                .iter()
608                .enumerate()
609                .map(|(i, solution)| {
610                    RecoverySuggestion {
611                        action: solution.to_string(),
612                        description: format!(
613                            "See documentation: {}",
614                            error_info
615                                .documentation_url
616                                .unwrap_or("https://docs.trustformers.rs/errors")
617                        ),
618                        priority: 10 - i as u8, // Higher priority for earlier solutions
619                        automatic: error_info.severity != "CRITICAL", // Auto-apply only for non-critical
620                    }
621                })
622                .collect();
623        }
624
625        // Fallback to built-in recovery suggestions based on error type
626        match error_type {
627            ErrorType::Configuration => vec![
628                RecoverySuggestion {
629                    action: "Check configuration file".to_string(),
630                    description: "Verify that all required parameters are set correctly"
631                        .to_string(),
632                    priority: 9,
633                    automatic: false,
634                },
635                RecoverySuggestion {
636                    action: "Use default configuration".to_string(),
637                    description: "Fall back to known good default settings".to_string(),
638                    priority: 7,
639                    automatic: true,
640                },
641            ],
642            ErrorType::DataLoading => vec![
643                RecoverySuggestion {
644                    action: "Check data path".to_string(),
645                    description: "Verify that the dataset path exists and is accessible"
646                        .to_string(),
647                    priority: 9,
648                    automatic: false,
649                },
650                RecoverySuggestion {
651                    action: "Reduce batch size".to_string(),
652                    description: "Try reducing batch size to avoid memory issues".to_string(),
653                    priority: 8,
654                    automatic: true,
655                },
656            ],
657            ErrorType::Training => vec![
658                RecoverySuggestion {
659                    action: "Reduce learning rate".to_string(),
660                    description: "Lower the learning rate to stabilize training".to_string(),
661                    priority: 8,
662                    automatic: true,
663                },
664                RecoverySuggestion {
665                    action: "Check for NaN/Inf values".to_string(),
666                    description: "Inspect model weights and gradients for numerical issues"
667                        .to_string(),
668                    priority: 9,
669                    automatic: false,
670                },
671            ],
672            ErrorType::Resource => vec![
673                RecoverySuggestion {
674                    action: "Free up memory".to_string(),
675                    description: "Clear unused variables and reduce model size".to_string(),
676                    priority: 9,
677                    automatic: true,
678                },
679                RecoverySuggestion {
680                    action: "Use gradient checkpointing".to_string(),
681                    description: "Enable gradient checkpointing to reduce memory usage".to_string(),
682                    priority: 7,
683                    automatic: true,
684                },
685            ],
686            _ => vec![RecoverySuggestion {
687                action: "Restart training".to_string(),
688                description: "Restart training from the last checkpoint".to_string(),
689                priority: 5,
690                automatic: false,
691            }],
692        }
693    }
694
695    /// Get error statistics
696    pub fn get_statistics(&self) -> Result<ErrorStatistics> {
697        let stats = self
698            .statistics
699            .read()
700            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on statistics"))?;
701        Ok((*stats).clone())
702    }
703
704    /// Get recent errors
705    pub fn get_recent_errors(&self, limit: usize) -> Result<Vec<TrainingError>> {
706        let errors = self
707            .errors
708            .read()
709            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
710
711        let recent: Vec<_> = errors.iter().rev().take(limit).cloned().collect();
712
713        Ok(recent)
714    }
715
716    /// Clear error history
717    pub fn clear_errors(&self) -> Result<()> {
718        let mut errors = self
719            .errors
720            .write()
721            .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on errors"))?;
722        errors.clear();
723        Ok(())
724    }
725
726    /// Export errors to JSON for external analysis
727    pub fn export_errors(&self) -> Result<String> {
728        let errors = self
729            .errors
730            .read()
731            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
732
733        serde_json::to_string_pretty(&*errors).context("Failed to serialize errors to JSON")
734    }
735}
736
/// Creates a `TrainingError` through an `ErrorManager` (which also records it).
///
/// Arguments, in order: manager, error type, message, error code, severity, and
/// context. `message` and `error_code` may be anything implementing `ToString`
/// (typically string literals). The remaining arguments are passed unchanged to
/// `create_error`. A trailing comma is accepted.
#[macro_export]
macro_rules! training_error {
    (
        $error_manager:expr,
        $error_type:expr,
        $message:expr,
        $error_code:expr,
        $severity:expr,
        $context:expr $(,)?
    ) => {
        $error_manager.create_error(
            $error_type,
            $message.to_string(),
            $error_code.to_string(),
            $severity,
            $context,
        )
    };
}
750
/// Builds an `ErrorContext` whose optional fields are all empty.
///
/// Two forms are accepted, each with an optional trailing comma:
///
/// * `create_context!(component, operation)`
/// * `create_context!(component, operation, epoch: e, step: s)`, which additionally
///   sets `epoch` and `step`.
///
/// `component` and `operation` may be anything implementing `ToString`. The
/// expansion names `ErrorContext` and `SystemInfo` unqualified, so both must be in
/// scope at the call site. `HashMap` is referenced by its full path and needs no
/// import.
#[macro_export]
macro_rules! create_context {
    ($component:expr, $operation:expr, epoch: $epoch:expr, step: $step:expr $(,)?) => {
        ErrorContext {
            component: $component.to_string(),
            operation: $operation.to_string(),
            epoch: Some($epoch),
            step: Some($step),
            batch_size: None,
            learning_rate: None,
            model_state: None,
            system_info: SystemInfo {
                memory_usage: None,
                gpu_memory_usage: None,
                cpu_usage: None,
                disk_space: None,
                network_status: None,
            },
            additional_data: ::std::collections::HashMap::new(),
        }
    };
    ($component:expr, $operation:expr $(,)?) => {
        ErrorContext {
            component: $component.to_string(),
            operation: $operation.to_string(),
            epoch: None,
            step: None,
            batch_size: None,
            learning_rate: None,
            model_state: None,
            system_info: SystemInfo {
                memory_usage: None,
                gpu_memory_usage: None,
                cpu_usage: None,
                disk_space: None,
                network_status: None,
            },
            additional_data: ::std::collections::HashMap::new(),
        }
    };
}
792
793/// Result type with training error
794pub type TrainingResult<T> = Result<T, TrainingError>;
795
796/// Extension trait for converting anyhow errors to training errors
797pub trait TrainingErrorExt<T> {
798    fn training_error(
799        self,
800        error_manager: &ErrorManager,
801        error_type: ErrorType,
802        error_code: &str,
803        severity: ErrorSeverity,
804        context: ErrorContext,
805    ) -> TrainingResult<T>;
806}
807
808impl<T> TrainingErrorExt<T> for Result<T> {
809    fn training_error(
810        self,
811        error_manager: &ErrorManager,
812        error_type: ErrorType,
813        error_code: &str,
814        severity: ErrorSeverity,
815        context: ErrorContext,
816    ) -> TrainingResult<T> {
817        match self {
818            Ok(value) => Ok(value),
819            Err(e) => {
820                let training_error = error_manager.create_error(
821                    error_type,
822                    e.to_string(),
823                    error_code.to_string(),
824                    severity,
825                    context,
826                );
827                Err(training_error)
828            },
829        }
830    }
831}
832
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully populated context for a trainer forward pass at epoch 1, step 100,
    /// with a few system-usage figures filled in.
    fn forward_pass_context() -> ErrorContext {
        let system_info = SystemInfo {
            memory_usage: Some(1024),
            gpu_memory_usage: Some(512),
            cpu_usage: Some(50.0),
            disk_space: None,
            network_status: None,
        };

        ErrorContext {
            component: "trainer".to_string(),
            operation: "forward_pass".to_string(),
            epoch: Some(1),
            step: Some(100),
            batch_size: Some(32),
            learning_rate: Some(0.001),
            model_state: None,
            system_info,
            additional_data: HashMap::new(),
        }
    }

    #[test]
    fn test_error_manager_creation() {
        let fresh_manager = ErrorManager::new();
        let stats = fresh_manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_errors, 0);
    }

    #[test]
    fn test_error_recording() {
        let manager = ErrorManager::new();

        let recorded = manager.create_error(
            ErrorType::Training,
            "NaN detected in loss".to_string(),
            "TRAINING_NAN_LOSS".to_string(),
            ErrorSeverity::Critical,
            forward_pass_context(),
        );

        // The returned error mirrors its inputs and carries at least one suggestion.
        assert_eq!(recorded.error_type, ErrorType::Training);
        assert_eq!(recorded.error_code, "TRAINING_NAN_LOSS");
        assert_eq!(recorded.severity, ErrorSeverity::Critical);
        assert!(!recorded.recovery_suggestions.is_empty());

        // Creating the error also records it in the manager's statistics.
        let stats = manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_errors, 1);
        let training_count = stats
            .errors_by_type
            .get(&ErrorType::Training)
            .expect("expected value not found");
        assert_eq!(*training_count, 1);
    }

    #[test]
    fn test_error_pattern_detection() {
        let manager = ErrorManager::new();

        let oom_pattern = ErrorPattern {
            pattern_id: "frequent_oom".to_string(),
            error_codes: vec!["RESOURCE_OOM".to_string()],
            frequency_threshold: 3,
            time_window_seconds: 300,
            suggested_actions: vec![RecoverySuggestion {
                action: "Reduce batch size".to_string(),
                description: "Lower batch size to reduce memory usage".to_string(),
                priority: 9,
                automatic: true,
            }],
        };
        manager.add_error_pattern(oom_pattern).expect("add operation failed");

        // Emit as many OOM errors as the pattern's frequency_threshold (3).
        let ctx = create_context!("trainer", "forward_pass");
        (0..3).for_each(|_| {
            manager.create_error(
                ErrorType::Resource,
                "Out of memory".to_string(),
                "RESOURCE_OOM".to_string(),
                ErrorSeverity::Critical,
                ctx.clone(),
            );
        });

        // NOTE(review): despite the test name, only the recorded count is checked here.
        // Nothing asserts that `frequent_oom` was actually detected. TODO: assert on the
        // pattern-detection outcome once the manager exposes it.
        let stats = manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_errors, 3);
    }

    #[test]
    fn test_recovery_suggestions() {
        let manager = ErrorManager::new();

        let suggestions =
            manager.get_recovery_suggestions(&ErrorType::Training, "TRAINING_NAN_LOSS");
        assert!(!suggestions.is_empty());

        // A NaN loss should yield at least one learning-rate related remedy.
        assert!(suggestions
            .iter()
            .any(|suggestion| suggestion.action.contains("learning rate")));
    }

    #[test]
    fn test_error_statistics() {
        let manager = ErrorManager::new();
        let ctx = create_context!("trainer", "test");

        // (type, message, code, severity) of every error to record, in order.
        let cases = [
            (ErrorType::Training, "Test error 1", "TEST_001", ErrorSeverity::Critical),
            (ErrorType::DataLoading, "Test error 2", "TEST_002", ErrorSeverity::Medium),
            (ErrorType::Training, "Test error 3", "TEST_003", ErrorSeverity::High),
        ];
        for (kind, message, code, severity) in cases {
            manager.create_error(
                kind,
                message.to_string(),
                code.to_string(),
                severity,
                ctx.clone(),
            );
        }

        let stats = manager.get_statistics().expect("operation failed in test");
        assert_eq!(stats.total_errors, 3);

        for (kind, expected) in [(ErrorType::Training, 2), (ErrorType::DataLoading, 1)] {
            let count = stats.errors_by_type.get(&kind).expect("expected value not found");
            assert_eq!(*count, expected);
        }

        for (severity, expected) in [
            (ErrorSeverity::Critical, 1),
            (ErrorSeverity::Medium, 1),
            (ErrorSeverity::High, 1),
        ] {
            let count = stats
                .errors_by_severity
                .get(&severity)
                .expect("expected value not found");
            assert_eq!(*count, expected);
        }
    }

    #[test]
    fn test_recent_errors() {
        let manager = ErrorManager::new();
        let ctx = create_context!("trainer", "test");

        // Record TEST_000 through TEST_004, oldest first.
        for idx in 0..5 {
            manager.create_error(
                ErrorType::Training,
                format!("Test error {}", idx),
                format!("TEST_{:03}", idx),
                ErrorSeverity::Medium,
                ctx.clone(),
            );
        }

        let recent_errors = manager.get_recent_errors(3).expect("operation failed in test");
        assert_eq!(recent_errors.len(), 3);

        // Most recent first: the last three codes, in reverse order of creation.
        let codes: Vec<&str> = recent_errors
            .iter()
            .map(|err| err.error_code.as_str())
            .collect();
        assert_eq!(codes, ["TEST_004", "TEST_003", "TEST_002"]);
    }

    #[test]
    fn test_error_export() {
        let manager = ErrorManager::new();

        manager.create_error(
            ErrorType::Training,
            "Test error".to_string(),
            "TEST_001".to_string(),
            ErrorSeverity::Medium,
            create_context!("trainer", "test"),
        );

        let json_export = manager.export_errors().expect("operation failed in test");
        assert!(!json_export.is_empty());
        // Both the code and the message of the recorded error must appear in the export.
        for needle in ["TEST_001", "Test error"] {
            assert!(json_export.contains(needle));
        }
    }
}