scirs2_stats/
error_suggestions.rs

1//! Enhanced error suggestion system with context-aware recovery strategies
2//!
3//! This module provides an intelligent error suggestion system that analyzes
4//! error patterns and provides detailed, actionable recovery suggestions.
5
6use crate::error::StatsError;
7use std::collections::HashMap;
8
9/// Error suggestion engine that provides intelligent recovery suggestions
10pub struct SuggestionEngine {
11    /// Common error patterns and their solutions
12    patterns: HashMap<String, Vec<Suggestion>>,
13    /// Context-specific suggestions
14    context_suggestions: HashMap<String, Vec<Suggestion>>,
15}
16
/// A recovery suggestion with priority and detailed steps
///
/// All fields are public, so a value can be built with a plain struct literal.
#[derive(Debug, Clone)]
pub struct Suggestion {
    /// Brief description of the suggestion
    pub title: String,
    /// Detailed steps to implement the suggestion
    pub steps: Vec<String>,
    /// Priority level (1-5, where 1 is highest)
    // NOTE(review): the 1-5 range is not enforced by the `u8` type — confirm callers keep to it.
    pub priority: u8,
    /// Example code if applicable
    pub example: Option<String>,
    /// Links to relevant documentation
    // NOTE(review): the built-in entries store short topic names (e.g. "data_cleaning")
    // rather than URLs — confirm how consumers are expected to resolve them.
    pub docs: Vec<String>,
}
31
32impl Default for SuggestionEngine {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl SuggestionEngine {
39    /// Create a new suggestion engine with built-in patterns
40    pub fn new() -> Self {
41        let mut engine = Self {
42            patterns: HashMap::new(),
43            context_suggestions: HashMap::new(),
44        };
45
46        engine.initialize_patterns();
47        engine
48    }
49
    /// Initialize common error patterns and solutions
    ///
    /// Registers six lowercase keywords in `self.patterns`: `"nan"`, `"empty"`,
    /// `"dimension"`, `"converge"`, `"singular"` and `"overflow"`. Each keyword
    /// maps to one or two [`Suggestion`]s; the first has priority 1 and the
    /// second, where present, priority 2.
    ///
    /// The `example` snippets are stored as plain text and are never compiled.
    /// Some of them use names they do not define (for instance `valid_count`
    /// and `OptimizationConfig`), so they are illustrations rather than
    /// copy-paste-ready code.
    fn initialize_patterns(&mut self) {
        // NaN value patterns
        self.patterns.insert(
            "nan".to_string(),
            vec![
                Suggestion {
                    title: "Remove NaN values".to_string(),
                    steps: vec![
                        "Filter out NaN values using is_nan() check".to_string(),
                        "Use array.iter().filter(|x| !x.is_nan())".to_string(),
                        "Consider using ndarray's mapv() for element-wise operations".to_string(),
                    ],
                    priority: 1,
                    example: Some(
                        r#"
// Remove NaN values
let cleandata: Vec<f64> = data.iter()
    .filter(|&&x| !x.is_nan())
    .copied()
    .collect();
                    "#
                        .to_string(),
                    ),
                    docs: vec!["data_cleaning".to_string()],
                },
                Suggestion {
                    title: "Impute missing values".to_string(),
                    steps: vec![
                        "Calculate mean/median of non-NaN values".to_string(),
                        "Replace NaN with calculated statistic".to_string(),
                        "Consider forward/backward fill for time series".to_string(),
                    ],
                    priority: 2,
                    example: Some(
                        r#"
// Impute with mean
let mean = data.iter()
    .filter(|&&x| !x.is_nan())
    .sum::<f64>() / valid_count as f64;
let imputed = data.mapv(|x| if x.is_nan() { mean } else { x });
                    "#
                        .to_string(),
                    ),
                    docs: vec!["imputation_methods".to_string()],
                },
            ],
        );

        // Empty array patterns
        self.patterns.insert(
            "empty".to_string(),
            vec![Suggestion {
                title: "Check data loading process".to_string(),
                steps: vec![
                    "Verify file path and permissions".to_string(),
                    "Check if filters are too restrictive".to_string(),
                    "Add logging to data loading steps".to_string(),
                    "Validate data source is not empty".to_string(),
                ],
                priority: 1,
                example: Some(
                    r#"
// Add validation after loading
let data = loaddata(path)?;
if data.is_empty() {
    eprintln!("Warning: Loaded data is empty from {}", path);
    return Err(StatsError::invalid_argument("No data loaded"));
}
                    "#
                    .to_string(),
                ),
                docs: vec!["data_loading".to_string()],
            }],
        );

        // Dimension mismatch patterns
        self.patterns.insert(
            "dimension".to_string(),
            vec![
                Suggestion {
                    title: "Reshape arrays to match".to_string(),
                    steps: vec![
                        "Check shapes with .shape() or .dim()".to_string(),
                        "Use reshape() to adjust dimensions".to_string(),
                        "Ensure broadcasting rules are followed".to_string(),
                    ],
                    priority: 1,
                    example: Some(
                        r#"
// Check and match dimensions
println!("Array A shape: {:?}", a.shape());
println!("Array B shape: {:?}", b.shape());

// Reshape if needed
let b_reshaped = b.reshape((a.shape()[0], 1));
                    "#
                        .to_string(),
                    ),
                    docs: vec!["array_broadcasting".to_string()],
                },
                Suggestion {
                    title: "Transpose if needed".to_string(),
                    steps: vec![
                        "Check if arrays need transposition".to_string(),
                        "Use .t() or .transpose() methods".to_string(),
                    ],
                    priority: 2,
                    example: Some(
                        r#"
// Transpose for matrix multiplication
let result = a.dot(&b.t());
                    "#
                        .to_string(),
                    ),
                    docs: vec!["linear_algebra".to_string()],
                },
            ],
        );

        // Convergence failure patterns
        self.patterns.insert(
            "converge".to_string(),
            vec![
                Suggestion {
                    title: "Adjust algorithm parameters".to_string(),
                    steps: vec![
                        "Increase maximum iterations".to_string(),
                        "Relax convergence tolerance".to_string(),
                        "Try different learning rates".to_string(),
                    ],
                    priority: 1,
                    example: Some(
                        r#"
// Adjust parameters
let config = OptimizationConfig {
    max_iter: 10000,  // Increased from default
    tolerance: 1e-6,  // Relaxed from 1e-8
    learning_rate: 0.01,  // Reduced for stability
};
                    "#
                        .to_string(),
                    ),
                    docs: vec!["optimization_parameters".to_string()],
                },
                Suggestion {
                    title: "Preprocess data for better conditioning".to_string(),
                    steps: vec![
                        "Standardize features to zero mean, unit variance".to_string(),
                        "Remove highly correlated features".to_string(),
                        "Apply regularization techniques".to_string(),
                    ],
                    priority: 2,
                    example: Some(
                        r#"
// Standardize data
let mean = data.mean().unwrap();
let std = data.std(1);
let standardized = (data - mean) / std;
                    "#
                        .to_string(),
                    ),
                    docs: vec!["data_preprocessing".to_string()],
                },
            ],
        );

        // Singular matrix patterns
        self.patterns.insert(
            "singular".to_string(),
            vec![
                Suggestion {
                    title: "Add regularization".to_string(),
                    steps: vec![
                        "Add small value to diagonal (ridge regularization)".to_string(),
                        "Use SVD for pseudo-inverse".to_string(),
                        "Consider dimensionality reduction".to_string(),
                    ],
                    priority: 1,
                    example: Some(
                        r#"
// Ridge regularization
let lambda = 1e-4;
let regularized = matrix + lambda * Array2::eye(matrix.nrows());
                    "#
                        .to_string(),
                    ),
                    docs: vec!["regularization".to_string()],
                },
                Suggestion {
                    title: "Check for linear dependencies".to_string(),
                    steps: vec![
                        "Calculate correlation matrix".to_string(),
                        "Remove highly correlated features (|r| > 0.95)".to_string(),
                        "Use PCA to identify redundant dimensions".to_string(),
                    ],
                    priority: 2,
                    example: Some(
                        r#"
// Check correlations
let corr_matrix = corrcoef(&data.t(), "pearson")?;
for i in 0..n_features {
    for j in i+1..n_features {
        if corr_matrix[(i,j)].abs() > 0.95 {
            println!("Features {} and {} are highly correlated", i, j);
        }
    }
}
                    "#
                        .to_string(),
                    ),
                    docs: vec!["multicollinearity".to_string()],
                },
            ],
        );

        // Overflow patterns
        self.patterns.insert(
            "overflow".to_string(),
            vec![
                Suggestion {
                    title: "Scale input data".to_string(),
                    steps: vec![
                        "Normalize to [0, 1] or [-1, 1] range".to_string(),
                        "Use log transformation for large values".to_string(),
                        "Apply feature scaling techniques".to_string(),
                    ],
                    priority: 1,
                    example: Some(
                        r#"
// Min-max scaling
let min = data.min().unwrap();
let max = data.max().unwrap();
let scaled = (data - min) / (max - min);

// Log transformation
let log_transformed = data.mapv(|x| x.ln());
                    "#
                        .to_string(),
                    ),
                    docs: vec!["feature_scaling".to_string()],
                },
                Suggestion {
                    title: "Use numerically stable algorithms".to_string(),
                    steps: vec![
                        "Use log-sum-exp trick for exponentials".to_string(),
                        "Prefer stable implementations (e.g., log1p)".to_string(),
                        "Work in log space when possible".to_string(),
                    ],
                    priority: 2,
                    example: Some(
                        r#"
// Log-sum-exp trick
#[allow(dead_code)]
fn log_sum_exp(values: &[f64]) -> f64 {
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
    let sum = values.iter().map(|&x| (x - max_val).exp()).sum::<f64>();
    max_val + sum.ln()
}
                    "#
                        .to_string(),
                    ),
                    docs: vec!["numerical_stability".to_string()],
                },
            ],
        );
    }
317
318    /// Get suggestions for a specific error
319    pub fn get_suggestions(&self, error: &StatsError) -> Vec<Suggestion> {
320        let error_str = error.to_string().to_lowercase();
321        let mut suggestions = Vec::new();
322
323        // Check each pattern
324        for (pattern, pattern_suggestions) in &self.patterns {
325            if error_str.contains(pattern) {
326                suggestions.extend_from_slice(pattern_suggestions);
327            }
328        }
329
330        // Sort by priority
331        suggestions.sort_by_key(|s| s.priority);
332        suggestions
333    }
334
335    /// Add context-specific suggestions
336    pub fn add_context_suggestions(&mut self, context: String, suggestions: Vec<Suggestion>) {
337        self.context_suggestions.insert(context, suggestions);
338    }
339
340    /// Get suggestions for a specific context
341    pub fn get_context_suggestions(&self, context: &str) -> Option<&Vec<Suggestion>> {
342        self.context_suggestions.get(context)
343    }
344}
345
/// Enhanced error formatter with suggestions
pub struct ErrorFormatter {
    /// Engine queried for suggestions whenever an error is formatted
    // NOTE(review): nothing visible in this file exposes this engine mutably, so
    // context suggestions cannot be registered on it and the context lookup in
    // `format_error` looks like it can never hit — confirm against the rest of the crate.
    suggestion_engine: SuggestionEngine,
}
350
351impl Default for ErrorFormatter {
352    fn default() -> Self {
353        Self::new()
354    }
355}
356
357impl ErrorFormatter {
358    /// Create a new error formatter
359    pub fn new() -> Self {
360        Self {
361            suggestion_engine: SuggestionEngine::new(),
362        }
363    }
364
365    /// Format an error with detailed suggestions
366    pub fn format_error(&self, error: StatsError, context: Option<&str>) -> String {
367        let mut output = format!("Error: {}\n", error);
368
369        // Get automatic suggestions
370        let mut suggestions = self.suggestion_engine.get_suggestions(&error);
371
372        // Add context-specific suggestions if available
373        if let Some(ctx) = context {
374            if let Some(ctx_suggestions) = self.suggestion_engine.get_context_suggestions(ctx) {
375                suggestions.extend_from_slice(ctx_suggestions);
376            }
377        }
378
379        if !suggestions.is_empty() {
380            output.push_str("\nšŸ“‹ Suggested Solutions:\n");
381
382            for (i, suggestion) in suggestions.iter().enumerate() {
383                output.push_str(&format!(
384                    "\n{}. {} (Priority: {})\n",
385                    i + 1,
386                    suggestion.title,
387                    suggestion.priority
388                ));
389
390                output.push_str("   Steps:\n");
391                for step in &suggestion.steps {
392                    output.push_str(&format!("   • {}\n", step));
393                }
394
395                if let Some(example) = &suggestion.example {
396                    output.push_str("\n   Example:\n");
397                    for line in example.lines() {
398                        output.push_str(&format!("   {}\n", line));
399                    }
400                }
401
402                if !suggestion.docs.is_empty() {
403                    output.push_str("\n   See also: ");
404                    output.push_str(&suggestion.docs.join(", "));
405                    output.push('\n');
406                }
407            }
408        }
409
410        output
411    }
412}
413
414/// Quick error diagnosis tool
415#[allow(dead_code)]
416pub fn diagnose_error(error: &StatsError) -> DiagnosisReport {
417    let error_str = error.to_string().to_lowercase();
418
419    let error_type = if error_str.contains("dimension") {
420        ErrorType::DimensionMismatch
421    } else if error_str.contains("empty") {
422        ErrorType::EmptyData
423    } else if error_str.contains("nan") {
424        ErrorType::InvalidValues
425    } else if error_str.contains("converge") {
426        ErrorType::ConvergenceFailure
427    } else if error_str.contains("singular") {
428        ErrorType::SingularMatrix
429    } else if error_str.contains("overflow") {
430        ErrorType::NumericalOverflow
431    } else if error_str.contains("domain") {
432        ErrorType::DomainError
433    } else {
434        ErrorType::Other
435    };
436
437    let severity = match error_type {
438        ErrorType::NumericalOverflow | ErrorType::SingularMatrix => Severity::High,
439        ErrorType::ConvergenceFailure | ErrorType::InvalidValues => Severity::Medium,
440        _ => Severity::Low,
441    };
442
443    let likely_causes = match error_type {
444        ErrorType::DimensionMismatch => vec![
445            "Arrays have incompatible shapes".to_string(),
446            "Missing transpose operation".to_string(),
447            "Incorrect axis specification".to_string(),
448        ],
449        ErrorType::EmptyData => vec![
450            "Data loading failed".to_string(),
451            "Filters removed all data".to_string(),
452            "Incorrect file path".to_string(),
453        ],
454        ErrorType::InvalidValues => vec![
455            "Missing data not handled".to_string(),
456            "Division by zero".to_string(),
457            "Invalid mathematical operation".to_string(),
458        ],
459        ErrorType::ConvergenceFailure => vec![
460            "Poor initial values".to_string(),
461            "Ill-conditioned problem".to_string(),
462            "Insufficient iterations".to_string(),
463        ],
464        ErrorType::SingularMatrix => vec![
465            "Linear dependencies in data".to_string(),
466            "Insufficient observations".to_string(),
467            "Perfect multicollinearity".to_string(),
468        ],
469        ErrorType::NumericalOverflow => vec![
470            "Values too large".to_string(),
471            "Exponential growth".to_string(),
472            "Insufficient precision".to_string(),
473        ],
474        ErrorType::DomainError => vec![
475            "Invalid parameter values".to_string(),
476            "Out of bounds input".to_string(),
477            "Constraint violation".to_string(),
478        ],
479        ErrorType::Other => vec!["Unknown cause".to_string()],
480    };
481
482    DiagnosisReport {
483        error_type,
484        severity,
485        likely_causes,
486    }
487}
488
/// Error diagnosis report
#[derive(Debug)]
pub struct DiagnosisReport {
    /// Category assigned to the error
    pub error_type: ErrorType,
    /// Severity level assigned to that category
    pub severity: Severity,
    /// Human-readable descriptions of possible causes for that category
    pub likely_causes: Vec<String>,
}
496
/// Common error types
///
/// A fieldless enum, so it is cheap to copy and can be used as a map or set
/// key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Array shapes or lengths do not agree
    DimensionMismatch,
    /// No data was available to operate on
    EmptyData,
    /// The data contains invalid values such as NaN
    InvalidValues,
    /// An iterative procedure did not converge
    ConvergenceFailure,
    /// A matrix was singular
    SingularMatrix,
    /// A numerical overflow occurred
    NumericalOverflow,
    /// An input was outside the valid domain
    DomainError,
    /// Anything that fits none of the other categories
    Other,
}
509
/// Error severity levels
///
/// Variants are declared in ascending order, so the derived comparisons give
/// `Low < Medium < High`. The ordering is total (`Ord`), which allows
/// `sort`, `max` and `min` to be used directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Least severe level
    Low,
    /// Intermediate level
    Medium,
    /// Most severe level
    High,
}
517
/// Helper macro for creating errors with suggestions
///
/// Builds `StatsError::$error_type($msg)`, prints the error together with
/// its suggestions to stderr, and evaluates to the error so it can be
/// returned or propagated by the caller.
///
/// Each `$suggestion` becomes a priority-1 [`Suggestion`] with only a title.
/// These custom suggestions are merged with the built-in ones that match the
/// error message and listed in ascending priority order. A trailing comma
/// after the last suggestion is accepted.
///
/// The error is only borrowed while printing, so it is still available as
/// the macro's value afterwards.
///
/// `StatsError`, `SuggestionEngine` and `Suggestion` are referenced by their
/// bare names and therefore must be in scope at the call site.
#[macro_export]
macro_rules! stats_error_with_suggestions {
    ($error_type:ident, $msg:expr, $($suggestion:expr),+ $(,)?) => {
        {
            let error = StatsError::$error_type($msg);

            // Register the caller's suggestions on the same engine that is queried
            // below, so that they actually show up in the output.
            let mut engine = SuggestionEngine::new();
            engine.add_context_suggestions(
                "custom".to_string(),
                vec![
                    $(
                        Suggestion {
                            title: $suggestion.to_string(),
                            steps: vec![],
                            priority: 1,
                            example: None,
                            docs: vec![],
                        },
                    )+
                ],
            );

            let mut suggestions = engine.get_suggestions(&error);
            if let Some(custom) = engine.get_context_suggestions("custom") {
                suggestions.extend_from_slice(custom);
            }
            // Stable sort: built-in suggestions stay ahead of custom ones of equal priority.
            suggestions.sort_by_key(|s| s.priority);

            eprintln!("Error: {}", error);
            eprintln!("\n📋 Suggested Solutions:");
            for (i, suggestion) in suggestions.iter().enumerate() {
                eprintln!(
                    "\n{}. {} (Priority: {})",
                    i + 1,
                    suggestion.title,
                    suggestion.priority
                );
                if !suggestion.steps.is_empty() {
                    eprintln!("   Steps:");
                    for step in &suggestion.steps {
                        eprintln!("   • {}", step);
                    }
                }
                if let Some(example) = &suggestion.example {
                    eprintln!("\n   Example:");
                    for line in example.lines() {
                        eprintln!("   {}", line);
                    }
                }
                if !suggestion.docs.is_empty() {
                    eprintln!("\n   See also: {}", suggestion.docs.join(", "));
                }
            }

            error
        }
    };
}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// A NaN-related message yields at least one suggestion, and the first
    /// one returned has priority 1.
    #[test]
    fn test_suggestion_engine() {
        let found = SuggestionEngine::new()
            .get_suggestions(&StatsError::invalid_argument("Found NaN values"));

        assert!(!found.is_empty());
        assert_eq!(found[0].priority, 1);
    }

    /// A dimension-mismatch error is classified as such, with low severity.
    #[test]
    fn test_error_diagnosis() {
        let report =
            diagnose_error(&StatsError::dimension_mismatch("Arrays must have same length"));

        assert_eq!(report.error_type, ErrorType::DimensionMismatch);
        assert_eq!(report.severity, Severity::Low);
    }

    /// Formatting a NaN-related error without a context includes the
    /// suggestions section and the built-in NaN suggestion.
    #[test]
    fn test_error_formatter() {
        let rendered = ErrorFormatter::new().format_error(
            StatsError::invalid_argument("Array contains NaN values"),
            None,
        );

        assert!(rendered.contains("Suggested Solutions"));
        assert!(rendered.contains("Remove NaN values"));
    }
}