sklears_utils/
error_handling.rs

1//! Enhanced error handling utilities for machine learning workflows
2//!
3//! This module provides comprehensive error handling with context, stack traces,
4//! error aggregation, and structured error reporting.
5
6use crate::{UtilsError, UtilsResult};
7use std::collections::HashMap;
8use std::fmt;
9use std::sync::{Arc, Mutex};
10
11// ===== ENHANCED ERROR TYPES =====
12
13/// Enhanced error with context and stack trace information
14#[derive(Debug, Clone)]
15pub struct EnhancedError {
16    pub error: UtilsError,
17    pub context: Vec<String>,
18    pub stack_trace: Vec<String>,
19    pub timestamp: std::time::SystemTime,
20    pub error_id: String,
21    pub metadata: HashMap<String, String>,
22}
23
24impl EnhancedError {
25    /// Create a new enhanced error
26    pub fn new(error: UtilsError) -> Self {
27        Self {
28            error,
29            context: Vec::new(),
30            stack_trace: Self::capture_stack_trace(),
31            timestamp: std::time::SystemTime::now(),
32            error_id: Self::generate_error_id(),
33            metadata: HashMap::new(),
34        }
35    }
36
37    /// Add context to the error
38    pub fn with_context<S: Into<String>>(mut self, context: S) -> Self {
39        self.context.push(context.into());
40        self
41    }
42
43    /// Add metadata to the error
44    pub fn with_metadata<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
45        self.metadata.insert(key.into(), value.into());
46        self
47    }
48
49    /// Generate a unique error ID
50    fn generate_error_id() -> String {
51        use std::time::{SystemTime, UNIX_EPOCH};
52        let timestamp = SystemTime::now()
53            .duration_since(UNIX_EPOCH)
54            .unwrap_or_default()
55            .as_nanos();
56        format!("ERR-{timestamp:016x}")
57    }
58
59    /// Capture stack trace (simplified implementation)
60    fn capture_stack_trace() -> Vec<String> {
61        // In a real implementation, you might use backtrace crate
62        // For now, we'll create a simple stack trace
63        vec![
64            "stack_trace: enhanced_error.rs:capture_stack_trace".to_string(),
65            "stack_trace: error_handling.rs:new".to_string(),
66        ]
67    }
68
69    /// Format the error for display
70    pub fn format_detailed(&self) -> String {
71        let mut output = String::new();
72
73        output.push_str(&format!("Error ID: {}\n", self.error_id));
74        output.push_str(&format!("Timestamp: {:?}\n", self.timestamp));
75        output.push_str(&format!("Error: {}\n", self.error));
76
77        if !self.context.is_empty() {
78            output.push_str("Context:\n");
79            for (i, ctx) in self.context.iter().enumerate() {
80                output.push_str(&format!("  {}: {ctx}\n", i + 1));
81            }
82        }
83
84        if !self.metadata.is_empty() {
85            output.push_str("Metadata:\n");
86            for (key, value) in &self.metadata {
87                output.push_str(&format!("  {key}: {value}\n"));
88            }
89        }
90
91        if !self.stack_trace.is_empty() {
92            output.push_str("Stack Trace:\n");
93            for frame in &self.stack_trace {
94                output.push_str(&format!("  {frame}\n"));
95            }
96        }
97
98        output
99    }
100}
101
102impl fmt::Display for EnhancedError {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        write!(f, "{} (ID: {})", self.error, self.error_id)
105    }
106}
107
108impl std::error::Error for EnhancedError {
109    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
110        Some(&self.error)
111    }
112}
113
114// ===== ERROR CONTEXT BUILDER =====
115
/// Builder that collects an operation name, an optional location and named
/// parameters, and then attaches them to an error via `wrap_error`.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Name of the operation being performed.
    operation: String,
    /// Named parameters of the operation; copied into the error's metadata.
    parameters: HashMap<String, String>,
    /// Optional location description (e.g. `"file.rs:123"`).
    location: Option<String>,
}
122
123impl ErrorContext {
124    /// Create a new error context
125    pub fn new<S: Into<String>>(operation: S) -> Self {
126        Self {
127            operation: operation.into(),
128            parameters: HashMap::new(),
129            location: None,
130        }
131    }
132
133    /// Add a parameter to the context
134    pub fn with_param<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
135        self.parameters.insert(key.into(), value.into());
136        self
137    }
138
139    /// Add location information
140    pub fn at_location<S: Into<String>>(mut self, location: S) -> Self {
141        self.location = Some(location.into());
142        self
143    }
144
145    /// Wrap an error with this context
146    pub fn wrap_error<E: Into<UtilsError>>(self, error: E) -> EnhancedError {
147        let mut enhanced = EnhancedError::new(error.into());
148
149        enhanced = enhanced.with_context(format!("Operation: {}", self.operation));
150
151        if let Some(location) = self.location {
152            enhanced = enhanced.with_context(format!("Location: {location}"));
153        }
154
155        for (key, value) in self.parameters {
156            enhanced = enhanced.with_metadata(key, value);
157        }
158
159        enhanced
160    }
161}
162
163// ===== ERROR AGGREGATION =====
164
165/// Aggregates multiple errors into a single report
166#[derive(Debug, Clone)]
167pub struct ErrorAggregator {
168    errors: Vec<EnhancedError>,
169    max_errors: usize,
170    continue_on_error: bool,
171}
172
173impl ErrorAggregator {
174    /// Create a new error aggregator
175    pub fn new(max_errors: usize, continue_on_error: bool) -> Self {
176        Self {
177            errors: Vec::new(),
178            max_errors,
179            continue_on_error,
180        }
181    }
182
183    /// Add an error to the aggregator
184    pub fn add_error(&mut self, error: EnhancedError) -> UtilsResult<()> {
185        self.errors.push(error);
186
187        if self.errors.len() >= self.max_errors {
188            if self.continue_on_error {
189                // Remove oldest error to make room
190                self.errors.remove(0);
191            } else {
192                return Err(UtilsError::InvalidParameter(
193                    "Maximum error count reached".to_string(),
194                ));
195            }
196        }
197
198        Ok(())
199    }
200
201    /// Check if there are any errors
202    pub fn has_errors(&self) -> bool {
203        !self.errors.is_empty()
204    }
205
206    /// Get the number of errors
207    pub fn error_count(&self) -> usize {
208        self.errors.len()
209    }
210
211    /// Get all errors
212    pub fn get_errors(&self) -> &[EnhancedError] {
213        &self.errors
214    }
215
216    /// Clear all errors
217    pub fn clear(&mut self) {
218        self.errors.clear();
219    }
220
221    /// Generate an error summary
222    pub fn generate_summary(&self) -> ErrorSummary {
223        let mut summary = ErrorSummary::default();
224
225        for error in &self.errors {
226            summary.total_errors += 1;
227
228            match &error.error {
229                UtilsError::ShapeMismatch { .. } => summary.shape_errors += 1,
230                UtilsError::InvalidParameter(_) => summary.parameter_errors += 1,
231                UtilsError::EmptyInput => summary.input_errors += 1,
232                UtilsError::InvalidRandomState(_) => summary.random_state_errors += 1,
233                UtilsError::InsufficientData { .. } => summary.data_errors += 1,
234            }
235        }
236
237        summary
238    }
239
240    /// Export errors to a structured format
241    pub fn export_errors(&self) -> Vec<HashMap<String, String>> {
242        self.errors
243            .iter()
244            .map(|error| {
245                let mut export = HashMap::new();
246                export.insert("id".to_string(), error.error_id.clone());
247                export.insert("error".to_string(), error.error.to_string());
248                export.insert("timestamp".to_string(), format!("{:?}", error.timestamp));
249                export.insert("context".to_string(), error.context.join("; "));
250
251                for (key, value) in &error.metadata {
252                    export.insert(format!("meta_{key}"), value.clone());
253                }
254
255                export
256            })
257            .collect()
258    }
259}
260
/// Per-variant error counts produced by `ErrorAggregator::generate_summary`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ErrorSummary {
    /// Total number of errors counted.
    pub total_errors: usize,
    /// Count of `UtilsError::ShapeMismatch` errors.
    pub shape_errors: usize,
    /// Count of `UtilsError::InvalidParameter` errors.
    pub parameter_errors: usize,
    /// Count of `UtilsError::EmptyInput` errors.
    pub input_errors: usize,
    /// Count of `UtilsError::InvalidRandomState` errors.
    pub random_state_errors: usize,
    /// Count of `UtilsError::InsufficientData` errors.
    pub data_errors: usize,
}
271
272// ===== ERROR RECOVERY STRATEGIES =====
273
/// Error recovery strategies for common ML scenarios
pub struct ErrorRecovery;

impl ErrorRecovery {
    /// Suggest a shape that `actual` data could be reshaped into to satisfy `expected`.
    ///
    /// Rules, in order:
    /// 1. If both shapes hold the same number of elements, `expected` itself is
    ///    returned. A reshape only has to preserve the element count, whatever
    ///    the ranks.
    /// 2. 2-D data where 1-D is expected: flatten to `[rows * cols]`.
    /// 3. 1-D data of length `n` where 2-D is expected: the most square
    ///    factorisation `[r, n / r]`, where `r` is the largest divisor of `n`
    ///    not exceeding `sqrt(n)` (e.g. `12 -> [3, 4]`, prime `p -> [1, p]`,
    ///    `0 -> [1, 0]`).
    ///
    /// Returns `None` when no rule applies or an element count overflows `usize`.
    pub fn recover_shape_mismatch(expected: &[usize], actual: &[usize]) -> Option<Vec<usize>> {
        let expected_count = Self::element_count(expected);
        let actual_count = Self::element_count(actual);

        if expected_count.is_some() && expected_count == actual_count {
            return Some(expected.to_vec());
        }

        if expected.len() == 1 && actual.len() == 2 {
            return actual_count.map(|total| vec![total]);
        }

        if expected.len() == 2 && actual.len() == 1 {
            let size = actual[0];
            // Walk down from floor(sqrt(size)) to the first divisor. 1 always
            // divides, and `max(1)` keeps a zero-length input from dividing by zero.
            let mut rows = Self::floor_sqrt(size).max(1);
            while size % rows != 0 {
                rows -= 1;
            }
            return Some(vec![rows, size / rows]);
        }

        None
    }

    /// Number of elements in `shape`, or `None` if the product overflows `usize`.
    fn element_count(shape: &[usize]) -> Option<usize> {
        shape.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }

    /// Exact `floor(sqrt(n))`, correcting the `f64` estimate for very large `n`.
    fn floor_sqrt(n: usize) -> usize {
        let mut root = (n as f64).sqrt() as usize;
        while root.checked_mul(root).map_or(true, |square| square > n) {
            root -= 1;
        }
        while (root + 1)
            .checked_mul(root + 1)
            .map_or(false, |square| square <= n)
        {
            root += 1;
        }
        root
    }

    /// Suggest how to proceed when only `available` of `required` samples exist.
    ///
    /// * `None` – nothing is missing (`available >= required`).
    /// * `GenerateSyntheticData(required)` – no data at all.
    /// * `ReduceRequirement(available)` – at least half of the requirement
    ///   (rounded up) is available.
    /// * `AugmentData(required - available)` – less than half is available.
    pub fn recover_insufficient_data(
        required: usize,
        available: usize,
    ) -> Option<RecoveryStrategy> {
        if available >= required {
            return None;
        }

        if available == 0 {
            return Some(RecoveryStrategy::GenerateSyntheticData(required));
        }

        // ceil(required / 2) without overflow. A plain `required / 2` would
        // count 1 of 3 samples as "half".
        let half = required - required / 2;
        let strategy = if available >= half {
            RecoveryStrategy::ReduceRequirement(available)
        } else {
            RecoveryStrategy::AugmentData(required - available)
        };
        Some(strategy)
    }

    /// Suggest a replacement value for an invalid hyper-parameter.
    ///
    /// * Count parameters (`n_components`, `n_clusters`, `max_iter`): `"1"` if
    ///   the value is a non-positive integer, otherwise `"10"`.
    /// * `random_state`: `"42"`.
    /// * `tolerance`, `alpha`, `learning_rate`: `"0.01"`.
    /// * Any other name: `None`.
    pub fn recover_invalid_parameter(param_name: &str, param_value: &str) -> Option<String> {
        let suggestion = match param_name {
            "n_components" | "n_clusters" | "max_iter" => {
                // i64 so that large negatives such as "-9999999999" are still
                // recognized as non-positive instead of falling to the default.
                match param_value.trim().parse::<i64>() {
                    Ok(value) if value <= 0 => "1",
                    _ => "10",
                }
            }
            "random_state" => "42",
            // Every value (malformed, non-positive or otherwise) maps to one default.
            "tolerance" | "alpha" | "learning_rate" => "0.01",
            _ => return None,
        };
        Some(suggestion.to_string())
    }
}

/// Recovery strategy recommendations
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Generate this many synthetic samples.
    GenerateSyntheticData(usize),
    /// Lower the requirement to this many samples.
    ReduceRequirement(usize),
    /// Augment the data with this many additional samples.
    AugmentData(usize),
    /// Reshape the data to this shape.
    ReshapeData(Vec<usize>),
    /// Use this default parameter value.
    UseDefaultParameter(String),
}
371
372// ===== ERROR REPORTING =====
373
374/// Global error reporter for collecting and analyzing errors
375pub struct ErrorReporter {
376    errors: Arc<Mutex<Vec<EnhancedError>>>,
377    enabled: bool,
378}
379
380impl Default for ErrorReporter {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
386impl ErrorReporter {
387    /// Create a new error reporter
388    pub fn new() -> Self {
389        Self {
390            errors: Arc::new(Mutex::new(Vec::new())),
391            enabled: true,
392        }
393    }
394
395    /// Enable or disable error reporting
396    pub fn set_enabled(&mut self, enabled: bool) {
397        self.enabled = enabled;
398    }
399
400    /// Report an error
401    pub fn report_error(&self, error: EnhancedError) {
402        if !self.enabled {
403            return;
404        }
405
406        if let Ok(mut errors) = self.errors.lock() {
407            errors.push(error);
408
409            // Keep only the last 1000 errors to prevent memory issues
410            if errors.len() > 1000 {
411                errors.remove(0);
412            }
413        }
414    }
415
416    /// Get error statistics
417    pub fn get_statistics(&self) -> Option<ErrorStatistics> {
418        let errors = self.errors.lock().ok()?;
419
420        if errors.is_empty() {
421            return None;
422        }
423
424        let mut stats = ErrorStatistics {
425            total_errors: errors.len(),
426            ..Default::default()
427        };
428
429        for error in errors.iter() {
430            match &error.error {
431                UtilsError::ShapeMismatch { .. } => stats.shape_errors += 1,
432                UtilsError::InvalidParameter(_) => stats.parameter_errors += 1,
433                UtilsError::EmptyInput => stats.input_errors += 1,
434                UtilsError::InvalidRandomState(_) => stats.random_state_errors += 1,
435                UtilsError::InsufficientData { .. } => stats.data_errors += 1,
436            }
437        }
438
439        // Calculate error frequency over time windows
440        let now = std::time::SystemTime::now();
441        let one_hour_ago = now - std::time::Duration::from_secs(3600);
442        let one_day_ago = now - std::time::Duration::from_secs(86400);
443
444        stats.errors_last_hour = errors.iter().filter(|e| e.timestamp > one_hour_ago).count();
445
446        stats.errors_last_day = errors.iter().filter(|e| e.timestamp > one_day_ago).count();
447
448        Some(stats)
449    }
450
451    /// Clear all reported errors
452    pub fn clear(&self) {
453        if let Ok(mut errors) = self.errors.lock() {
454            errors.clear();
455        }
456    }
457}
458
/// Statistics computed by `ErrorReporter::get_statistics`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ErrorStatistics {
    /// Total number of stored errors.
    pub total_errors: usize,
    /// Count of `UtilsError::ShapeMismatch` errors.
    pub shape_errors: usize,
    /// Count of `UtilsError::InvalidParameter` errors.
    pub parameter_errors: usize,
    /// Count of `UtilsError::EmptyInput` errors.
    pub input_errors: usize,
    /// Count of `UtilsError::InvalidRandomState` errors.
    pub random_state_errors: usize,
    /// Count of `UtilsError::InsufficientData` errors.
    pub data_errors: usize,
    /// Errors whose timestamp falls within the last hour.
    pub errors_last_hour: usize,
    /// Errors whose timestamp falls within the last 24 hours.
    pub errors_last_day: usize,
}
471
472// ===== CONVENIENCE MACROS AND FUNCTIONS =====
473
474/// Create an enhanced error with context
475pub fn create_error<E: Into<UtilsError>>(error: E, operation: &str) -> EnhancedError {
476    ErrorContext::new(operation).wrap_error(error)
477}
478
479/// Create an enhanced error with context and location
480pub fn create_error_at<E: Into<UtilsError>>(
481    error: E,
482    operation: &str,
483    location: &str,
484) -> EnhancedError {
485    ErrorContext::new(operation)
486        .at_location(location)
487        .wrap_error(error)
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enhanced_error() {
        let enhanced = EnhancedError::new(UtilsError::InvalidParameter("test".to_string()))
            .with_context("Processing data")
            .with_metadata("operation", "test_operation");

        assert!(!enhanced.error_id.is_empty());
        assert_eq!(enhanced.context.len(), 1);
        assert_eq!(enhanced.metadata.len(), 1);

        let report = enhanced.format_detailed();
        for needle in ["Error ID:", "Processing data"] {
            assert!(report.contains(needle), "missing {needle:?} in report");
        }
    }

    #[test]
    fn test_error_context() {
        let enhanced = ErrorContext::new("test_operation")
            .with_param("param1", "value1")
            .at_location("test_file.rs:123")
            .wrap_error(UtilsError::EmptyInput);

        assert!(enhanced.context.len() >= 2);
        assert!(enhanced.metadata.contains_key("param1"));
    }

    #[test]
    fn test_error_aggregator() {
        let mut aggregator = ErrorAggregator::new(3, false);
        assert!(!aggregator.has_errors());
        assert_eq!(aggregator.error_count(), 0);

        aggregator
            .add_error(EnhancedError::new(UtilsError::EmptyInput))
            .unwrap();
        aggregator
            .add_error(EnhancedError::new(UtilsError::InvalidParameter(
                "test".to_string(),
            )))
            .unwrap();

        assert!(aggregator.has_errors());
        assert_eq!(aggregator.error_count(), 2);

        let summary = aggregator.generate_summary();
        assert_eq!(summary.total_errors, 2);
        assert_eq!(summary.input_errors, 1);
        assert_eq!(summary.parameter_errors, 1);
    }

    #[test]
    fn test_error_recovery() {
        assert_eq!(
            ErrorRecovery::recover_shape_mismatch(&[10], &[2, 5]),
            Some(vec![10])
        );

        let strategy = ErrorRecovery::recover_insufficient_data(100, 50);
        assert!(
            matches!(strategy, Some(RecoveryStrategy::ReduceRequirement(50))),
            "Expected ReduceRequirement strategy"
        );

        assert_eq!(
            ErrorRecovery::recover_invalid_parameter("n_clusters", "-5").as_deref(),
            Some("1")
        );
    }

    #[test]
    fn test_error_reporter() {
        let reporter = ErrorReporter::new();
        reporter.report_error(EnhancedError::new(UtilsError::EmptyInput));

        let stats = reporter
            .get_statistics()
            .expect("statistics should exist after reporting an error");
        assert_eq!((stats.total_errors, stats.input_errors), (1, 1));
    }

    #[test]
    fn test_convenience_functions() {
        let plain = create_error(UtilsError::EmptyInput, "test_operation");
        assert!(plain
            .context
            .iter()
            .any(|line| line.contains("test_operation")));

        let located =
            create_error_at(UtilsError::EmptyInput, "test_operation", "test_file.rs:123");
        assert!(located.context.len() >= 2);
    }
}