1use anyhow::{Context as AnyhowContext, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5use std::sync::{Arc, RwLock};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use crate::error_codes::{get_error_info, is_critical_error};
9
/// A structured record of a failure observed during training.
///
/// Instances are normally built through `ErrorManager::create_error`, which
/// fills in the timestamp and recovery suggestions and records the error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingError {
    /// Broad category of the failure.
    pub error_type: ErrorType,
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Code identifying the specific error (e.g. `"RESOURCE_OOM"`).
    pub error_code: String,
    /// Severity level of the error.
    pub severity: ErrorSeverity,
    /// Where the error occurred, plus optional training state.
    pub context: ErrorContext,
    /// Seconds since the Unix epoch (as set by `ErrorManager::create_error`).
    pub timestamp: u64,
    /// Suggestions attached when the error was created.
    pub recovery_suggestions: Vec<RecoverySuggestion>,
    /// Related errors; `ErrorManager::create_error` leaves this empty.
    pub related_errors: Vec<String>,
}
22
/// Broad category of a [`TrainingError`].
///
/// Used as the key for per-type statistics and for registered recovery
/// strategies. Each variant's `Display` label is noted on the variant.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, Hash, PartialEq)]
pub enum ErrorType {
    /// Displayed as `CONFIGURATION`.
    Configuration,
    /// Displayed as `DATA_LOADING`.
    DataLoading,
    /// Displayed as `MODEL_INIT`.
    ModelInitialization,
    /// Displayed as `TRAINING`.
    Training,
    /// Displayed as `VALIDATION`.
    Validation,
    /// Displayed as `CHECKPOINT`.
    Checkpoint,
    /// Displayed as `RESOURCE`.
    Resource,
    /// Displayed as `NETWORK`.
    Network,
    /// Displayed as `HARDWARE`.
    Hardware,
    /// Displayed as `USER_INPUT`.
    UserInput,
    /// Displayed as `INTERNAL`.
    Internal,
}
37
/// Severity level attached to a [`TrainingError`].
///
/// Used as the key of `ErrorStatistics::errors_by_severity`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, Hash, PartialEq)]
pub enum ErrorSeverity {
    /// Displayed as `CRITICAL`.
    Critical,
    /// Displayed as `HIGH`.
    High,
    /// Displayed as `MEDIUM`.
    Medium,
    /// Displayed as `LOW`.
    Low,
}
45
/// Describes where and under what training state an error occurred.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContext {
    /// Name of the component that raised the error; keys the per-component
    /// error counter in [`ErrorStatistics`].
    pub component: String,
    /// Operation that was in progress.
    pub operation: String,
    /// Training epoch, if known.
    pub epoch: Option<u32>,
    /// Training step, if known.
    pub step: Option<u32>,
    /// Batch size in use, if known.
    pub batch_size: Option<usize>,
    /// Learning rate in use, if known.
    pub learning_rate: Option<f64>,
    /// Free-form description of the model state, if any.
    pub model_state: Option<String>,
    /// Snapshot of system resource readings.
    pub system_info: SystemInfo,
    /// Arbitrary extra key/value data.
    pub additional_data: HashMap<String, String>,
}
58
/// Optional system resource readings attached to an [`ErrorContext`].
///
/// NOTE(review): units (bytes vs. MiB, percent vs. fraction) are not fixed by
/// anything in this file — confirm against the code that populates it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// Host memory usage, if measured.
    pub memory_usage: Option<u64>,
    /// GPU memory usage, if measured.
    pub gpu_memory_usage: Option<u64>,
    /// CPU usage, if measured.
    pub cpu_usage: Option<f32>,
    /// Disk space reading, if measured.
    pub disk_space: Option<u64>,
    /// Free-form network status, if known.
    pub network_status: Option<String>,
}
67
/// A single suggested remedy for an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoverySuggestion {
    /// Short name of the action to take.
    pub action: String,
    /// Longer explanation of the action.
    pub description: String,
    /// Relative priority; the suggestions built in this file use values up
    /// to 10, with larger numbers given to earlier catalog solutions.
    pub priority: u8,
    /// Whether the suggestion is flagged as automatically applicable.
    pub automatic: bool,
}
75
76impl fmt::Display for TrainingError {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 write!(
79 f,
80 "[{}] {}: {} ({})",
81 self.severity, self.error_type, self.message, self.error_code
82 )
83 }
84}
85
86impl fmt::Display for ErrorType {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 match self {
89 ErrorType::Configuration => write!(f, "CONFIGURATION"),
90 ErrorType::DataLoading => write!(f, "DATA_LOADING"),
91 ErrorType::ModelInitialization => write!(f, "MODEL_INIT"),
92 ErrorType::Training => write!(f, "TRAINING"),
93 ErrorType::Validation => write!(f, "VALIDATION"),
94 ErrorType::Checkpoint => write!(f, "CHECKPOINT"),
95 ErrorType::Resource => write!(f, "RESOURCE"),
96 ErrorType::Network => write!(f, "NETWORK"),
97 ErrorType::Hardware => write!(f, "HARDWARE"),
98 ErrorType::UserInput => write!(f, "USER_INPUT"),
99 ErrorType::Internal => write!(f, "INTERNAL"),
100 }
101 }
102}
103
104impl fmt::Display for ErrorSeverity {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 ErrorSeverity::Critical => write!(f, "CRITICAL"),
108 ErrorSeverity::High => write!(f, "HIGH"),
109 ErrorSeverity::Medium => write!(f, "MEDIUM"),
110 ErrorSeverity::Low => write!(f, "LOW"),
111 }
112 }
113}
114
// Marker impl so `TrainingError` can be used wherever a standard error type
// is expected; it relies on the `Debug` and `Display` impls above.
impl std::error::Error for TrainingError {}
116
/// Central registry that records training errors, tracks statistics, detects
/// recurring patterns, and attempts recovery.
///
/// Each piece of state sits behind its own `Arc<RwLock<_>>`.
pub struct ErrorManager {
    /// Every error recorded so far, in insertion order.
    errors: Arc<RwLock<Vec<TrainingError>>>,
    /// Registered patterns, keyed by `ErrorPattern::pattern_id`.
    error_patterns: Arc<RwLock<HashMap<String, ErrorPattern>>>,
    /// Registered recovery strategies, grouped by error category.
    recovery_strategies: Arc<RwLock<HashMap<ErrorType, Vec<RecoveryStrategy>>>>,
    /// Running aggregate statistics.
    statistics: Arc<RwLock<ErrorStatistics>>,
}
124
/// A rule that flags an error code recurring too often within a time window.
#[derive(Debug, Clone)]
pub struct ErrorPattern {
    /// Unique identifier; also the key under which the pattern is stored.
    pub pattern_id: String,
    /// Error codes that count towards this pattern.
    pub error_codes: Vec<String>,
    /// Number of recent matching errors at which the pattern is reported.
    pub frequency_threshold: u32,
    /// How far back, in seconds, an error still counts as recent.
    pub time_window_seconds: u64,
    /// Suggestions printed when the pattern is detected.
    pub suggested_actions: Vec<RecoverySuggestion>,
}
133
/// A user-registered recovery handler for a set of error codes.
#[derive(Debug, Clone)]
pub struct RecoveryStrategy {
    /// Identifier of the strategy.
    pub strategy_id: String,
    /// Display name, printed when the strategy is attempted.
    pub name: String,
    /// Error codes this strategy applies to.
    pub applicable_errors: Vec<String>,
    /// Decides which [`RecoveryAction`] to take for a given error.
    pub handler: fn(&TrainingError) -> Result<RecoveryAction>,
    /// Only strategies with this flag set are run by automatic recovery.
    pub auto_apply: bool,
}
142
/// The action a [`RecoveryStrategy`] handler asks the manager to take.
#[derive(Debug, Clone)]
pub enum RecoveryAction {
    /// Carry on with training.
    Continue,
    /// Retry the failed operation.
    Retry {
        /// Upper bound on retry attempts.
        max_attempts: u32,
    },
    /// Restart training.
    Restart {
        /// Checkpoint to restart from, if any.
        checkpoint: Option<String>,
    },
    /// Stop training; executing this action yields an error.
    Abort,
    /// Scale resource usage down.
    ReduceResources {
        /// Reduction factor.
        factor: f32,
    },
    /// Apply configuration overrides.
    ChangeConfiguration {
        /// Configuration keys and their new values.
        config_changes: HashMap<String, String>,
    },
    /// Switch to a fallback configuration.
    SwitchFallback {
        /// Name of the fallback configuration.
        fallback_config: String,
    },
}
163
/// Aggregate counters maintained by [`ErrorManager`].
#[derive(Debug, Default, Clone)]
pub struct ErrorStatistics {
    /// Total number of errors recorded.
    pub total_errors: u64,
    /// Error count per category.
    pub errors_by_type: HashMap<ErrorType, u64>,
    /// Error count per severity level.
    pub errors_by_severity: HashMap<ErrorSeverity, u64>,
    /// Error count per `ErrorContext::component`.
    pub errors_by_component: HashMap<String, u64>,
    /// NOTE(review): not written by `ErrorManager::update_statistics`;
    /// confirm where (or whether) this is populated.
    pub recovery_success_rate: f64,
    /// NOTE(review): not written by `ErrorManager::update_statistics`;
    /// confirm where (or whether) this is populated.
    pub most_common_errors: Vec<(String, u64)>,
    /// One sample per recorded error; capped at the latest 100 samples.
    pub error_trends: Vec<ErrorTrend>,
}
174
/// One sample of recent error activity, taken each time an error is recorded.
#[derive(Debug, Clone)]
pub struct ErrorTrend {
    /// Seconds since the Unix epoch when the sample was taken.
    pub timestamp: u64,
    /// Number of stored errors from the 60 seconds before the sample.
    pub error_count: u64,
    /// `error_count` divided by 60, i.e. errors per second over that window.
    pub error_rate: f64,
}
181
182impl Default for ErrorManager {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl ErrorManager {
189 pub fn new() -> Self {
190 Self {
191 errors: Arc::new(RwLock::new(Vec::new())),
192 error_patterns: Arc::new(RwLock::new(HashMap::new())),
193 recovery_strategies: Arc::new(RwLock::new(HashMap::new())),
194 statistics: Arc::new(RwLock::new(ErrorStatistics::default())),
195 }
196 }
197
198 pub fn record_error(&self, error: TrainingError) -> Result<()> {
200 {
202 let mut errors = self
203 .errors
204 .write()
205 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on errors"))?;
206 errors.push(error.clone());
207 }
208
209 self.update_statistics(&error)?;
211
212 self.check_error_patterns(&error)?;
214
215 self.attempt_recovery(&error)?;
217
218 Ok(())
219 }
220
221 pub fn create_error(
223 &self,
224 error_type: ErrorType,
225 message: String,
226 error_code: String,
227 severity: ErrorSeverity,
228 context: ErrorContext,
229 ) -> TrainingError {
230 let error = TrainingError {
231 error_type: error_type.clone(),
232 message,
233 error_code: error_code.clone(),
234 severity: severity.clone(),
235 context,
236 timestamp: SystemTime::now()
237 .duration_since(UNIX_EPOCH)
238 .expect("SystemTime should be after UNIX_EPOCH")
239 .as_secs(),
240 recovery_suggestions: self.get_recovery_suggestions(&error_type, &error_code),
241 related_errors: Vec::new(),
242 };
243
244 if let Err(e) = self.record_error(error.clone()) {
245 eprintln!("Failed to record error: {}", e);
246 }
247
248 error
249 }
250
251 pub fn add_error_pattern(&self, pattern: ErrorPattern) -> Result<()> {
253 let mut patterns = self
254 .error_patterns
255 .write()
256 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on error patterns"))?;
257 patterns.insert(pattern.pattern_id.clone(), pattern);
258 Ok(())
259 }
260
261 pub fn add_recovery_strategy(
263 &self,
264 error_type: ErrorType,
265 strategy: RecoveryStrategy,
266 ) -> Result<()> {
267 let mut strategies = self
268 .recovery_strategies
269 .write()
270 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on recovery strategies"))?;
271 strategies.entry(error_type).or_insert_with(Vec::new).push(strategy);
272 Ok(())
273 }
274
275 fn update_statistics(&self, error: &TrainingError) -> Result<()> {
276 let mut stats = self
277 .statistics
278 .write()
279 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on statistics"))?;
280
281 stats.total_errors += 1;
282 *stats.errors_by_type.entry(error.error_type.clone()).or_insert(0) += 1;
283 *stats.errors_by_severity.entry(error.severity.clone()).or_insert(0) += 1;
284 *stats.errors_by_component.entry(error.context.component.clone()).or_insert(0) += 1;
285
286 let current_time = SystemTime::now()
288 .duration_since(UNIX_EPOCH)
289 .expect("SystemTime should be after UNIX_EPOCH")
290 .as_secs();
291
292 let errors = self
294 .errors
295 .read()
296 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
297 let recent_errors =
298 errors.iter().filter(|e| current_time - e.timestamp <= 60).count() as u64;
299
300 stats.error_trends.push(ErrorTrend {
301 timestamp: current_time,
302 error_count: recent_errors,
303 error_rate: recent_errors as f64 / 60.0,
304 });
305
306 if stats.error_trends.len() > 100 {
308 stats.error_trends.remove(0);
309 }
310
311 Ok(())
312 }
313
314 fn check_error_patterns(&self, error: &TrainingError) -> Result<()> {
315 let patterns = self
316 .error_patterns
317 .read()
318 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on error patterns"))?;
319
320 for pattern in patterns.values() {
321 if pattern.error_codes.contains(&error.error_code) {
322 let recent_matching_errors = self.count_recent_matching_errors(pattern)?;
324
325 if recent_matching_errors >= pattern.frequency_threshold {
326 println!(
327 "๐จ Error pattern detected: {} (occurred {} times)",
328 pattern.pattern_id, recent_matching_errors
329 );
330
331 for suggestion in &pattern.suggested_actions {
333 println!(
334 "๐ก Suggestion: {} - {}",
335 suggestion.action, suggestion.description
336 );
337 }
338 }
339 }
340 }
341
342 Ok(())
343 }
344
345 fn count_recent_matching_errors(&self, pattern: &ErrorPattern) -> Result<u32> {
346 let errors = self
347 .errors
348 .read()
349 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
350
351 let current_time = SystemTime::now()
352 .duration_since(UNIX_EPOCH)
353 .expect("SystemTime should be after UNIX_EPOCH")
354 .as_secs();
355
356 let count = errors
357 .iter()
358 .filter(|error| {
359 current_time - error.timestamp <= pattern.time_window_seconds
360 && pattern.error_codes.contains(&error.error_code)
361 })
362 .count() as u32;
363
364 Ok(count)
365 }
366
367 fn attempt_recovery(&self, error: &TrainingError) -> Result<()> {
368 if is_critical_error(&error.error_code) {
370 println!(
371 "๐จ Critical error detected: {} - Manual intervention required",
372 error.error_code
373 );
374 return Ok(()); }
376
377 let strategies = self
378 .recovery_strategies
379 .read()
380 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on recovery strategies"))?;
381
382 if let Some(success) = self.try_builtin_recovery(error)? {
384 if success {
385 println!(
386 "โ
Built-in recovery successful for error: {}",
387 error.error_code
388 );
389 return Ok(());
390 }
391 }
392
393 if let Some(type_strategies) = strategies.get(&error.error_type) {
395 for strategy in type_strategies {
396 if strategy.auto_apply && strategy.applicable_errors.contains(&error.error_code) {
397 println!("๐ง Attempting automatic recovery: {}", strategy.name);
398
399 match (strategy.handler)(error) {
400 Ok(action) => {
401 println!("โ
Recovery action determined: {:?}", action);
402
403 if let Err(e) = self.execute_recovery_action(&action, error) {
405 println!("โ Failed to execute recovery action: {}", e);
406 continue;
407 }
408
409 println!("โ
Recovery action executed successfully");
410 return Ok(());
411 },
412 Err(e) => {
413 println!("โ Recovery strategy failed: {}", e);
414 },
415 }
416 }
417 }
418 }
419
420 self.suggest_manual_recovery(error);
422 Ok(())
423 }
424
425 fn try_builtin_recovery(&self, error: &TrainingError) -> Result<Option<bool>> {
427 match error.error_code.as_str() {
428 "RESOURCE_OOM" | "RESOURCE_GPU_OOM" => {
429 println!("๐ง Attempting memory recovery for OOM error");
430 self.simulate_memory_cleanup()?;
432 Ok(Some(true))
433 },
434 "TRAIN_NAN_LOSS" | "TRAIN_INF_LOSS" => {
435 println!("๐ง Attempting numerical stability recovery");
436 self.suggest_numerical_fixes(error)?;
438 Ok(Some(false)) },
440 "DATA_FILE_NOT_FOUND" => {
441 println!("๐ง Attempting data path recovery");
442 self.suggest_data_path_fixes(error)?;
444 Ok(Some(false))
445 },
446 "NETWORK_CONNECTION_TIMEOUT" => {
447 println!("๐ง Attempting network recovery");
448 self.attempt_network_retry(error)?;
450 Ok(Some(true))
451 },
452 _ => Ok(None), }
454 }
455
    /// Carries out the recovery action chosen by a strategy handler.
    ///
    /// Every variant is only logged to stdout here. `Abort` additionally
    /// returns an error; all other variants return `Ok(())`.
    ///
    /// NOTE(review): the leading glyph of each message looks mis-encoded in
    /// this view and cannot be decoded unambiguously — confirm the file's
    /// encoding and the intended emoji.
    fn execute_recovery_action(
        &self,
        action: &RecoveryAction,
        _error: &TrainingError,
    ) -> Result<()> {
        match action {
            RecoveryAction::Continue => {
                println!("๐ Recovery action: Continue training");
                Ok(())
            },
            RecoveryAction::Retry { max_attempts } => {
                println!(
                    "๐ Recovery action: Retry operation (max {} attempts)",
                    max_attempts
                );
                Ok(())
            },
            RecoveryAction::Restart { checkpoint } => {
                println!(
                    "๐ Recovery action: Restart from checkpoint: {:?}",
                    checkpoint
                );
                Ok(())
            },
            RecoveryAction::Abort => {
                println!("๐ Recovery action: Abort training");
                // Reported as an error so the caller does not treat the abort
                // as a successful recovery.
                Err(anyhow::anyhow!(
                    "Training aborted due to unrecoverable error"
                ))
            },
            RecoveryAction::ReduceResources { factor } => {
                println!("๐ Recovery action: Reduce resources by factor {}", factor);
                Ok(())
            },
            RecoveryAction::ChangeConfiguration { config_changes } => {
                println!(
                    "๐ Recovery action: Change configuration: {:?}",
                    config_changes
                );
                Ok(())
            },
            RecoveryAction::SwitchFallback { fallback_config } => {
                println!(
                    "๐ Recovery action: Switch to fallback configuration: {}",
                    fallback_config
                );
                Ok(())
            },
        }
    }
512
513 fn simulate_memory_cleanup(&self) -> Result<()> {
515 println!("๐งน Simulating memory cleanup...");
516 println!(" - Clearing unused tensors");
517 println!(" - Running garbage collection");
518 println!(" - Reducing batch size temporarily");
519 Ok(())
520 }
521
522 fn suggest_numerical_fixes(&self, error: &TrainingError) -> Result<()> {
524 println!("๐ก Numerical stability suggestions:");
525 println!(" - Reduce learning rate by factor of 10");
526 println!(" - Enable gradient clipping (max_norm=1.0)");
527 println!(" - Check input data normalization");
528 println!(" - Consider using mixed precision training");
529
530 if let Some(lr) = error.context.learning_rate {
531 println!(" - Current learning rate: {}, suggested: {}", lr, lr * 0.1);
532 }
533
534 Ok(())
535 }
536
537 fn suggest_data_path_fixes(&self, _error: &TrainingError) -> Result<()> {
539 println!("๐ก Data path suggestions:");
540 println!(" - Check if file path is correct");
541 println!(" - Verify file permissions");
542 println!(" - Try relative vs absolute paths");
543 println!(" - Check if data is in expected location");
544 Ok(())
545 }
546
547 fn attempt_network_retry(&self, _error: &TrainingError) -> Result<()> {
549 println!("๐ Attempting network retry with exponential backoff...");
550
551 for attempt in 1..=3 {
552 println!(" Attempt {}/3", attempt);
553
554 std::thread::sleep(std::time::Duration::from_millis(100 * (1 << attempt)));
556
557 if fastrand::bool() {
559 println!(" โ
Network operation succeeded");
560 return Ok(());
561 }
562
563 println!(" โ Network operation failed, retrying...");
564 }
565
566 Err(anyhow::anyhow!("Network operation failed after 3 attempts"))
567 }
568
    /// Prints the error's attached recovery suggestions, followed by the epoch
    /// and step at which it occurred when those are present in the context.
    ///
    /// An extra hint is printed for errors that occur before epoch 5.
    ///
    /// NOTE(review): the glyphs in these messages look mis-encoded in this
    /// view (the AUTO and MANUAL markers even render identically) — confirm
    /// the file's encoding and the intended emoji.
    fn suggest_manual_recovery(&self, error: &TrainingError) {
        println!("๐ง Manual recovery suggestions for {}:", error.error_code);

        for suggestion in &error.recovery_suggestions {
            println!(
                " {} (Priority: {}) - {}",
                if suggestion.automatic { "๐ค AUTO" } else { "๐ค MANUAL" },
                suggestion.priority,
                suggestion.action
            );
            println!(" ๐ {}", suggestion.description);
        }

        if let Some(epoch) = error.context.epoch {
            println!(" ๐ Error occurred at epoch {}", epoch);
            // Failures this early get an additional hint.
            if epoch < 5 {
                println!(
                    " ๐ก Early training failure - check data loading and model initialization"
                );
            }
        }

        if let Some(step) = error.context.step {
            println!(" ๐ Error occurred at step {}", step);
        }
    }
597
598 fn get_recovery_suggestions(
599 &self,
600 error_type: &ErrorType,
601 error_code: &str,
602 ) -> Vec<RecoverySuggestion> {
603 if let Some(error_info) = get_error_info(error_code) {
605 return error_info
606 .solutions
607 .iter()
608 .enumerate()
609 .map(|(i, solution)| {
610 RecoverySuggestion {
611 action: solution.to_string(),
612 description: format!(
613 "See documentation: {}",
614 error_info
615 .documentation_url
616 .unwrap_or("https://docs.trustformers.rs/errors")
617 ),
618 priority: 10 - i as u8, automatic: error_info.severity != "CRITICAL", }
621 })
622 .collect();
623 }
624
625 match error_type {
627 ErrorType::Configuration => vec![
628 RecoverySuggestion {
629 action: "Check configuration file".to_string(),
630 description: "Verify that all required parameters are set correctly"
631 .to_string(),
632 priority: 9,
633 automatic: false,
634 },
635 RecoverySuggestion {
636 action: "Use default configuration".to_string(),
637 description: "Fall back to known good default settings".to_string(),
638 priority: 7,
639 automatic: true,
640 },
641 ],
642 ErrorType::DataLoading => vec![
643 RecoverySuggestion {
644 action: "Check data path".to_string(),
645 description: "Verify that the dataset path exists and is accessible"
646 .to_string(),
647 priority: 9,
648 automatic: false,
649 },
650 RecoverySuggestion {
651 action: "Reduce batch size".to_string(),
652 description: "Try reducing batch size to avoid memory issues".to_string(),
653 priority: 8,
654 automatic: true,
655 },
656 ],
657 ErrorType::Training => vec![
658 RecoverySuggestion {
659 action: "Reduce learning rate".to_string(),
660 description: "Lower the learning rate to stabilize training".to_string(),
661 priority: 8,
662 automatic: true,
663 },
664 RecoverySuggestion {
665 action: "Check for NaN/Inf values".to_string(),
666 description: "Inspect model weights and gradients for numerical issues"
667 .to_string(),
668 priority: 9,
669 automatic: false,
670 },
671 ],
672 ErrorType::Resource => vec![
673 RecoverySuggestion {
674 action: "Free up memory".to_string(),
675 description: "Clear unused variables and reduce model size".to_string(),
676 priority: 9,
677 automatic: true,
678 },
679 RecoverySuggestion {
680 action: "Use gradient checkpointing".to_string(),
681 description: "Enable gradient checkpointing to reduce memory usage".to_string(),
682 priority: 7,
683 automatic: true,
684 },
685 ],
686 _ => vec![RecoverySuggestion {
687 action: "Restart training".to_string(),
688 description: "Restart training from the last checkpoint".to_string(),
689 priority: 5,
690 automatic: false,
691 }],
692 }
693 }
694
695 pub fn get_statistics(&self) -> Result<ErrorStatistics> {
697 let stats = self
698 .statistics
699 .read()
700 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on statistics"))?;
701 Ok((*stats).clone())
702 }
703
704 pub fn get_recent_errors(&self, limit: usize) -> Result<Vec<TrainingError>> {
706 let errors = self
707 .errors
708 .read()
709 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
710
711 let recent: Vec<_> = errors.iter().rev().take(limit).cloned().collect();
712
713 Ok(recent)
714 }
715
716 pub fn clear_errors(&self) -> Result<()> {
718 let mut errors = self
719 .errors
720 .write()
721 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on errors"))?;
722 errors.clear();
723 Ok(())
724 }
725
726 pub fn export_errors(&self) -> Result<String> {
728 let errors = self
729 .errors
730 .read()
731 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on errors"))?;
732
733 serde_json::to_string_pretty(&*errors).context("Failed to serialize errors to JSON")
734 }
735}
736
/// Shorthand for `ErrorManager::create_error`.
///
/// Arguments, in order: the manager, error type, message, error code,
/// severity, and context. The message and the code are converted with
/// `to_string()`, so anything implementing `ToString` is accepted for them.
#[macro_export]
macro_rules! training_error {
    ($error_manager:expr, $error_type:expr, $message:expr, $error_code:expr, $severity:expr, $context:expr) => {
        $error_manager.create_error(
            $error_type,
            $message.to_string(),
            $error_code.to_string(),
            $severity,
            $context,
        )
    };
}
750
/// Builds an `ErrorContext` with every optional field left empty.
///
/// Two forms are accepted:
///
/// * `create_context!(component, operation)`
/// * `create_context!(component, operation, epoch: e, step: s)`, which also
///   sets `epoch` and `step`.
///
/// `ErrorContext` and `SystemInfo` must be in scope at the call site.
/// `HashMap` does not need to be: the expansion names it by absolute path.
#[macro_export]
macro_rules! create_context {
    ($component:expr, $operation:expr) => {
        ErrorContext {
            component: $component.to_string(),
            operation: $operation.to_string(),
            epoch: None,
            step: None,
            batch_size: None,
            learning_rate: None,
            model_state: None,
            system_info: SystemInfo {
                memory_usage: None,
                gpu_memory_usage: None,
                cpu_usage: None,
                disk_space: None,
                network_status: None,
            },
            // Absolute path so call sites need not import `HashMap`.
            additional_data: ::std::collections::HashMap::new(),
        }
    };
    ($component:expr, $operation:expr, epoch: $epoch:expr, step: $step:expr) => {
        ErrorContext {
            component: $component.to_string(),
            operation: $operation.to_string(),
            epoch: Some($epoch),
            step: Some($step),
            batch_size: None,
            learning_rate: None,
            model_state: None,
            system_info: SystemInfo {
                memory_usage: None,
                gpu_memory_usage: None,
                cpu_usage: None,
                disk_space: None,
                network_status: None,
            },
            // Absolute path so call sites need not import `HashMap`.
            additional_data: ::std::collections::HashMap::new(),
        }
    };
}
792
/// Result type whose error variant is a [`TrainingError`].
pub type TrainingResult<T> = Result<T, TrainingError>;
795
/// Extension trait for converting a fallible value into a [`TrainingResult`].
pub trait TrainingErrorExt<T> {
    /// Converts `self` into a [`TrainingResult`], describing any failure with
    /// the given error type, code, severity, and context.
    ///
    /// `error_manager` is the manager used to create the [`TrainingError`].
    fn training_error(
        self,
        error_manager: &ErrorManager,
        error_type: ErrorType,
        error_code: &str,
        severity: ErrorSeverity,
        context: ErrorContext,
    ) -> TrainingResult<T>;
}
807
808impl<T> TrainingErrorExt<T> for Result<T> {
809 fn training_error(
810 self,
811 error_manager: &ErrorManager,
812 error_type: ErrorType,
813 error_code: &str,
814 severity: ErrorSeverity,
815 context: ErrorContext,
816 ) -> TrainingResult<T> {
817 match self {
818 Ok(value) => Ok(value),
819 Err(e) => {
820 let training_error = error_manager.create_error(
821 error_type,
822 e.to_string(),
823 error_code.to_string(),
824 severity,
825 context,
826 );
827 Err(training_error)
828 },
829 }
830 }
831}
832
833#[cfg(test)]
834mod tests {
835 use super::*;
836
837 #[test]
838 fn test_error_manager_creation() {
839 let manager = ErrorManager::new();
840 let stats = manager.get_statistics().expect("operation failed in test");
841 assert_eq!(stats.total_errors, 0);
842 }
843
844 #[test]
845 fn test_error_recording() {
846 let manager = ErrorManager::new();
847
848 let context = ErrorContext {
849 component: "trainer".to_string(),
850 operation: "forward_pass".to_string(),
851 epoch: Some(1),
852 step: Some(100),
853 batch_size: Some(32),
854 learning_rate: Some(0.001),
855 model_state: None,
856 system_info: SystemInfo {
857 memory_usage: Some(1024),
858 gpu_memory_usage: Some(512),
859 cpu_usage: Some(50.0),
860 disk_space: None,
861 network_status: None,
862 },
863 additional_data: HashMap::new(),
864 };
865
866 let error = manager.create_error(
867 ErrorType::Training,
868 "NaN detected in loss".to_string(),
869 "TRAINING_NAN_LOSS".to_string(),
870 ErrorSeverity::Critical,
871 context,
872 );
873
874 assert_eq!(error.error_type, ErrorType::Training);
875 assert_eq!(error.error_code, "TRAINING_NAN_LOSS");
876 assert_eq!(error.severity, ErrorSeverity::Critical);
877 assert!(!error.recovery_suggestions.is_empty());
878
879 let stats = manager.get_statistics().expect("operation failed in test");
880 assert_eq!(stats.total_errors, 1);
881 assert_eq!(
882 *stats
883 .errors_by_type
884 .get(&ErrorType::Training)
885 .expect("expected value not found"),
886 1
887 );
888 }
889
890 #[test]
891 fn test_error_pattern_detection() {
892 let manager = ErrorManager::new();
893
894 let pattern = ErrorPattern {
895 pattern_id: "frequent_oom".to_string(),
896 error_codes: vec!["RESOURCE_OOM".to_string()],
897 frequency_threshold: 3,
898 time_window_seconds: 300,
899 suggested_actions: vec![RecoverySuggestion {
900 action: "Reduce batch size".to_string(),
901 description: "Lower batch size to reduce memory usage".to_string(),
902 priority: 9,
903 automatic: true,
904 }],
905 };
906
907 manager.add_error_pattern(pattern).expect("add operation failed");
908
909 let context = create_context!("trainer", "forward_pass");
911 for _ in 0..3 {
912 manager.create_error(
913 ErrorType::Resource,
914 "Out of memory".to_string(),
915 "RESOURCE_OOM".to_string(),
916 ErrorSeverity::Critical,
917 context.clone(),
918 );
919 }
920
921 let stats = manager.get_statistics().expect("operation failed in test");
923 assert_eq!(stats.total_errors, 3);
924 }
925
926 #[test]
927 fn test_recovery_suggestions() {
928 let manager = ErrorManager::new();
929
930 let suggestions =
931 manager.get_recovery_suggestions(&ErrorType::Training, "TRAINING_NAN_LOSS");
932 assert!(!suggestions.is_empty());
933
934 let has_lr_suggestion = suggestions.iter().any(|s| s.action.contains("learning rate"));
935 assert!(has_lr_suggestion);
936 }
937
938 #[test]
939 fn test_error_statistics() {
940 let manager = ErrorManager::new();
941
942 let context = create_context!("trainer", "test");
944
945 manager.create_error(
946 ErrorType::Training,
947 "Test error 1".to_string(),
948 "TEST_001".to_string(),
949 ErrorSeverity::Critical,
950 context.clone(),
951 );
952
953 manager.create_error(
954 ErrorType::DataLoading,
955 "Test error 2".to_string(),
956 "TEST_002".to_string(),
957 ErrorSeverity::Medium,
958 context.clone(),
959 );
960
961 manager.create_error(
962 ErrorType::Training,
963 "Test error 3".to_string(),
964 "TEST_003".to_string(),
965 ErrorSeverity::High,
966 context,
967 );
968
969 let stats = manager.get_statistics().expect("operation failed in test");
970 assert_eq!(stats.total_errors, 3);
971 assert_eq!(
972 *stats
973 .errors_by_type
974 .get(&ErrorType::Training)
975 .expect("expected value not found"),
976 2
977 );
978 assert_eq!(
979 *stats
980 .errors_by_type
981 .get(&ErrorType::DataLoading)
982 .expect("expected value not found"),
983 1
984 );
985 assert_eq!(
986 *stats
987 .errors_by_severity
988 .get(&ErrorSeverity::Critical)
989 .expect("expected value not found"),
990 1
991 );
992 assert_eq!(
993 *stats
994 .errors_by_severity
995 .get(&ErrorSeverity::Medium)
996 .expect("expected value not found"),
997 1
998 );
999 assert_eq!(
1000 *stats
1001 .errors_by_severity
1002 .get(&ErrorSeverity::High)
1003 .expect("expected value not found"),
1004 1
1005 );
1006 }
1007
1008 #[test]
1009 fn test_recent_errors() {
1010 let manager = ErrorManager::new();
1011
1012 let context = create_context!("trainer", "test");
1013
1014 for i in 0..5 {
1016 manager.create_error(
1017 ErrorType::Training,
1018 format!("Test error {}", i),
1019 format!("TEST_{:03}", i),
1020 ErrorSeverity::Medium,
1021 context.clone(),
1022 );
1023 }
1024
1025 let recent_errors = manager.get_recent_errors(3).expect("operation failed in test");
1026 assert_eq!(recent_errors.len(), 3);
1027
1028 assert_eq!(recent_errors[0].error_code, "TEST_004");
1030 assert_eq!(recent_errors[1].error_code, "TEST_003");
1031 assert_eq!(recent_errors[2].error_code, "TEST_002");
1032 }
1033
1034 #[test]
1035 fn test_error_export() {
1036 let manager = ErrorManager::new();
1037
1038 let context = create_context!("trainer", "test");
1039 manager.create_error(
1040 ErrorType::Training,
1041 "Test error".to_string(),
1042 "TEST_001".to_string(),
1043 ErrorSeverity::Medium,
1044 context,
1045 );
1046
1047 let json_export = manager.export_errors().expect("operation failed in test");
1048 assert!(!json_export.is_empty());
1049 assert!(json_export.contains("TEST_001"));
1050 assert!(json_export.contains("Test error"));
1051 }
1052}