1use crate::error::StatsError;
7use std::collections::HashMap;
8
9pub struct SuggestionEngine {
11 patterns: HashMap<String, Vec<Suggestion>>,
13 context_suggestions: HashMap<String, Vec<Suggestion>>,
15}
16
/// One actionable piece of advice attached to an error report.
///
/// `PartialEq`/`Eq` are derived so suggestions can be compared and
/// de-duplicated by callers and in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Suggestion {
    /// Short headline shown next to the suggestion number.
    pub title: String,
    /// Steps to follow, printed one per line in the given order.
    pub steps: Vec<String>,
    /// Sort key used by `SuggestionEngine::get_suggestions`; lower values are
    /// listed first.
    pub priority: u8,
    /// Optional code snippet; it is printed line by line and never compiled.
    pub example: Option<String>,
    /// Names of related documentation topics, printed as a "See also" list.
    pub docs: Vec<String>,
}
31
32impl Default for SuggestionEngine {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl SuggestionEngine {
39 pub fn new() -> Self {
41 let mut engine = Self {
42 patterns: HashMap::new(),
43 context_suggestions: HashMap::new(),
44 };
45
46 engine.initialize_patterns();
47 engine
48 }
49
50 fn initialize_patterns(&mut self) {
52 self.patterns.insert(
54 "nan".to_string(),
55 vec![
56 Suggestion {
57 title: "Remove NaN values".to_string(),
58 steps: vec![
59 "Filter out NaN values using is_nan() check".to_string(),
60 "Use array.iter().filter(|x| !x.is_nan())".to_string(),
61 "Consider using ndarray's mapv() for element-wise operations".to_string(),
62 ],
63 priority: 1,
64 example: Some(
65 r#"
66// Remove NaN values
67let cleandata: Vec<f64> = data.iter()
68 .filter(|&&x| !x.is_nan())
69 .copied()
70 .collect();
71 "#
72 .to_string(),
73 ),
74 docs: vec!["data_cleaning".to_string()],
75 },
76 Suggestion {
77 title: "Impute missing values".to_string(),
78 steps: vec![
79 "Calculate mean/median of non-NaN values".to_string(),
80 "Replace NaN with calculated statistic".to_string(),
81 "Consider forward/backward fill for time series".to_string(),
82 ],
83 priority: 2,
84 example: Some(
85 r#"
86// Impute with mean
87let mean = data.iter()
88 .filter(|&&x| !x.is_nan())
89 .sum::<f64>() / valid_count as f64;
90let imputed = data.mapv(|x| if x.is_nan() { mean } else { x });
91 "#
92 .to_string(),
93 ),
94 docs: vec!["imputation_methods".to_string()],
95 },
96 ],
97 );
98
99 self.patterns.insert(
101 "empty".to_string(),
102 vec![Suggestion {
103 title: "Check data loading process".to_string(),
104 steps: vec![
105 "Verify file path and permissions".to_string(),
106 "Check if filters are too restrictive".to_string(),
107 "Add logging to data loading steps".to_string(),
108 "Validate data source is not empty".to_string(),
109 ],
110 priority: 1,
111 example: Some(
112 r#"
113// Add validation after loading
114let data = loaddata(path)?;
115if data.is_empty() {
116 eprintln!("Warning: Loaded data is empty from {}", path);
117 return Err(StatsError::invalid_argument("No data loaded"));
118}
119 "#
120 .to_string(),
121 ),
122 docs: vec!["data_loading".to_string()],
123 }],
124 );
125
126 self.patterns.insert(
128 "dimension".to_string(),
129 vec![
130 Suggestion {
131 title: "Reshape arrays to match".to_string(),
132 steps: vec![
133 "Check shapes with .shape() or .dim()".to_string(),
134 "Use reshape() to adjust dimensions".to_string(),
135 "Ensure broadcasting rules are followed".to_string(),
136 ],
137 priority: 1,
138 example: Some(
139 r#"
140// Check and match dimensions
141println!("Array A shape: {:?}", a.shape());
142println!("Array B shape: {:?}", b.shape());
143
144// Reshape if needed
145let b_reshaped = b.reshape((a.shape()[0], 1));
146 "#
147 .to_string(),
148 ),
149 docs: vec!["array_broadcasting".to_string()],
150 },
151 Suggestion {
152 title: "Transpose if needed".to_string(),
153 steps: vec![
154 "Check if arrays need transposition".to_string(),
155 "Use .t() or .transpose() methods".to_string(),
156 ],
157 priority: 2,
158 example: Some(
159 r#"
160// Transpose for matrix multiplication
161let result = a.dot(&b.t());
162 "#
163 .to_string(),
164 ),
165 docs: vec!["linear_algebra".to_string()],
166 },
167 ],
168 );
169
170 self.patterns.insert(
172 "converge".to_string(),
173 vec![
174 Suggestion {
175 title: "Adjust algorithm parameters".to_string(),
176 steps: vec![
177 "Increase maximum iterations".to_string(),
178 "Relax convergence tolerance".to_string(),
179 "Try different learning rates".to_string(),
180 ],
181 priority: 1,
182 example: Some(
183 r#"
184// Adjust parameters
185let config = OptimizationConfig {
186 max_iter: 10000, // Increased from default
187 tolerance: 1e-6, // Relaxed from 1e-8
188 learning_rate: 0.01, // Reduced for stability
189};
190 "#
191 .to_string(),
192 ),
193 docs: vec!["optimization_parameters".to_string()],
194 },
195 Suggestion {
196 title: "Preprocess data for better conditioning".to_string(),
197 steps: vec![
198 "Standardize features to zero mean, unit variance".to_string(),
199 "Remove highly correlated features".to_string(),
200 "Apply regularization techniques".to_string(),
201 ],
202 priority: 2,
203 example: Some(
204 r#"
205// Standardize data
206let mean = data.mean().unwrap();
207let std = data.std(1);
208let standardized = (data - mean) / std;
209 "#
210 .to_string(),
211 ),
212 docs: vec!["data_preprocessing".to_string()],
213 },
214 ],
215 );
216
217 self.patterns.insert(
219 "singular".to_string(),
220 vec![
221 Suggestion {
222 title: "Add regularization".to_string(),
223 steps: vec![
224 "Add small value to diagonal (ridge regularization)".to_string(),
225 "Use SVD for pseudo-inverse".to_string(),
226 "Consider dimensionality reduction".to_string(),
227 ],
228 priority: 1,
229 example: Some(
230 r#"
231// Ridge regularization
232let lambda = 1e-4;
233let regularized = matrix + lambda * Array2::eye(matrix.nrows());
234 "#
235 .to_string(),
236 ),
237 docs: vec!["regularization".to_string()],
238 },
239 Suggestion {
240 title: "Check for linear dependencies".to_string(),
241 steps: vec![
242 "Calculate correlation matrix".to_string(),
243 "Remove highly correlated features (|r| > 0.95)".to_string(),
244 "Use PCA to identify redundant dimensions".to_string(),
245 ],
246 priority: 2,
247 example: Some(
248 r#"
249// Check correlations
250let corr_matrix = corrcoef(&data.t(), "pearson")?;
251for i in 0..n_features {
252 for j in i+1..n_features {
253 if corr_matrix[(i,j)].abs() > 0.95 {
254 println!("Features {} and {} are highly correlated", i, j);
255 }
256 }
257}
258 "#
259 .to_string(),
260 ),
261 docs: vec!["multicollinearity".to_string()],
262 },
263 ],
264 );
265
266 self.patterns.insert(
268 "overflow".to_string(),
269 vec![
270 Suggestion {
271 title: "Scale input data".to_string(),
272 steps: vec![
273 "Normalize to [0, 1] or [-1, 1] range".to_string(),
274 "Use log transformation for large values".to_string(),
275 "Apply feature scaling techniques".to_string(),
276 ],
277 priority: 1,
278 example: Some(
279 r#"
280// Min-max scaling
281let min = data.min().unwrap();
282let max = data.max().unwrap();
283let scaled = (data - min) / (max - min);
284
285// Log transformation
286let log_transformed = data.mapv(|x| x.ln());
287 "#
288 .to_string(),
289 ),
290 docs: vec!["feature_scaling".to_string()],
291 },
292 Suggestion {
293 title: "Use numerically stable algorithms".to_string(),
294 steps: vec![
295 "Use log-sum-exp trick for exponentials".to_string(),
296 "Prefer stable implementations (e.g., log1p)".to_string(),
297 "Work in log space when possible".to_string(),
298 ],
299 priority: 2,
300 example: Some(
301 r#"
302// Log-sum-exp trick
303#[allow(dead_code)]
304fn log_sum_exp(values: &[f64]) -> f64 {
305 let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
306 let sum = values.iter().map(|&x| (x - max_val).exp()).sum::<f64>();
307 max_val + sum.ln()
308}
309 "#
310 .to_string(),
311 ),
312 docs: vec!["numerical_stability".to_string()],
313 },
314 ],
315 );
316 }
317
318 pub fn get_suggestions(&self, error: &StatsError) -> Vec<Suggestion> {
320 let error_str = error.to_string().to_lowercase();
321 let mut suggestions = Vec::new();
322
323 for (pattern, pattern_suggestions) in &self.patterns {
325 if error_str.contains(pattern) {
326 suggestions.extend_from_slice(pattern_suggestions);
327 }
328 }
329
330 suggestions.sort_by_key(|s| s.priority);
332 suggestions
333 }
334
335 pub fn add_context_suggestions(&mut self, context: String, suggestions: Vec<Suggestion>) {
337 self.context_suggestions.insert(context, suggestions);
338 }
339
340 pub fn get_context_suggestions(&self, context: &str) -> Option<&Vec<Suggestion>> {
342 self.context_suggestions.get(context)
343 }
344}
345
/// Turns a `StatsError` into a multi-line, human-readable report that
/// includes the suggestions its `SuggestionEngine` finds for the error.
pub struct ErrorFormatter {
    /// Engine queried by `format_error`; created by `ErrorFormatter::new` and
    /// not reachable from outside this type.
    suggestion_engine: SuggestionEngine,
}
350
351impl Default for ErrorFormatter {
352 fn default() -> Self {
353 Self::new()
354 }
355}
356
357impl ErrorFormatter {
358 pub fn new() -> Self {
360 Self {
361 suggestion_engine: SuggestionEngine::new(),
362 }
363 }
364
365 pub fn format_error(&self, error: StatsError, context: Option<&str>) -> String {
367 let mut output = format!("Error: {}\n", error);
368
369 let mut suggestions = self.suggestion_engine.get_suggestions(&error);
371
372 if let Some(ctx) = context {
374 if let Some(ctx_suggestions) = self.suggestion_engine.get_context_suggestions(ctx) {
375 suggestions.extend_from_slice(ctx_suggestions);
376 }
377 }
378
379 if !suggestions.is_empty() {
380 output.push_str("\nš Suggested Solutions:\n");
381
382 for (i, suggestion) in suggestions.iter().enumerate() {
383 output.push_str(&format!(
384 "\n{}. {} (Priority: {})\n",
385 i + 1,
386 suggestion.title,
387 suggestion.priority
388 ));
389
390 output.push_str(" Steps:\n");
391 for step in &suggestion.steps {
392 output.push_str(&format!(" ⢠{}\n", step));
393 }
394
395 if let Some(example) = &suggestion.example {
396 output.push_str("\n Example:\n");
397 for line in example.lines() {
398 output.push_str(&format!(" {}\n", line));
399 }
400 }
401
402 if !suggestion.docs.is_empty() {
403 output.push_str("\n See also: ");
404 output.push_str(&suggestion.docs.join(", "));
405 output.push('\n');
406 }
407 }
408 }
409
410 output
411 }
412}
413
414#[allow(dead_code)]
416pub fn diagnose_error(error: &StatsError) -> DiagnosisReport {
417 let error_str = error.to_string().to_lowercase();
418
419 let error_type = if error_str.contains("dimension") {
420 ErrorType::DimensionMismatch
421 } else if error_str.contains("empty") {
422 ErrorType::EmptyData
423 } else if error_str.contains("nan") {
424 ErrorType::InvalidValues
425 } else if error_str.contains("converge") {
426 ErrorType::ConvergenceFailure
427 } else if error_str.contains("singular") {
428 ErrorType::SingularMatrix
429 } else if error_str.contains("overflow") {
430 ErrorType::NumericalOverflow
431 } else if error_str.contains("domain") {
432 ErrorType::DomainError
433 } else {
434 ErrorType::Other
435 };
436
437 let severity = match error_type {
438 ErrorType::NumericalOverflow | ErrorType::SingularMatrix => Severity::High,
439 ErrorType::ConvergenceFailure | ErrorType::InvalidValues => Severity::Medium,
440 _ => Severity::Low,
441 };
442
443 let likely_causes = match error_type {
444 ErrorType::DimensionMismatch => vec![
445 "Arrays have incompatible shapes".to_string(),
446 "Missing transpose operation".to_string(),
447 "Incorrect axis specification".to_string(),
448 ],
449 ErrorType::EmptyData => vec![
450 "Data loading failed".to_string(),
451 "Filters removed all data".to_string(),
452 "Incorrect file path".to_string(),
453 ],
454 ErrorType::InvalidValues => vec![
455 "Missing data not handled".to_string(),
456 "Division by zero".to_string(),
457 "Invalid mathematical operation".to_string(),
458 ],
459 ErrorType::ConvergenceFailure => vec![
460 "Poor initial values".to_string(),
461 "Ill-conditioned problem".to_string(),
462 "Insufficient iterations".to_string(),
463 ],
464 ErrorType::SingularMatrix => vec![
465 "Linear dependencies in data".to_string(),
466 "Insufficient observations".to_string(),
467 "Perfect multicollinearity".to_string(),
468 ],
469 ErrorType::NumericalOverflow => vec![
470 "Values too large".to_string(),
471 "Exponential growth".to_string(),
472 "Insufficient precision".to_string(),
473 ],
474 ErrorType::DomainError => vec![
475 "Invalid parameter values".to_string(),
476 "Out of bounds input".to_string(),
477 "Constraint violation".to_string(),
478 ],
479 ErrorType::Other => vec!["Unknown cause".to_string()],
480 };
481
482 DiagnosisReport {
483 error_type,
484 severity,
485 likely_causes,
486 }
487}
488
/// Result of `diagnose_error`: the inferred category of an error together
/// with its severity and a list of plausible causes.
#[derive(Debug)]
pub struct DiagnosisReport {
    /// Category inferred from keywords in the error message.
    pub error_type: ErrorType,
    /// Severity derived from `error_type`.
    pub severity: Severity,
    /// Human-readable candidate causes for this category of error.
    pub likely_causes: Vec<String>,
}
496
/// Error categories produced by `diagnose_error`.
///
/// Each variant (except `Other`) is selected when the lower-cased error
/// message contains the keyword noted on it. The type is a plain tag, so it
/// is `Copy`, fully comparable and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorType {
    /// Message contains "dimension".
    DimensionMismatch,
    /// Message contains "empty".
    EmptyData,
    /// Message contains "nan".
    InvalidValues,
    /// Message contains "converge".
    ConvergenceFailure,
    /// Message contains "singular".
    SingularMatrix,
    /// Message contains "overflow".
    NumericalOverflow,
    /// Message contains "domain".
    DomainError,
    /// None of the known keywords was found.
    Other,
}
509
/// How serious a diagnosed error is, ordered `Low < Medium < High`.
///
/// The ordering follows declaration order, so the variants must stay listed
/// from least to most severe. The type is a plain tag, so it is `Copy` and
/// totally ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Assigned to every category not listed under `Medium` or `High`.
    Low,
    /// Assigned to convergence failures and invalid (NaN) values.
    Medium,
    /// Assigned to numerical overflow and singular matrices.
    High,
}
517
/// Builds a `StatsError`, prints it to stderr together with suggestions, and
/// evaluates to the error so it can be returned or propagated.
///
/// * `$error_type` - identifier used as `StatsError::$error_type($msg)`.
/// * `$msg` - expression passed to that constructor.
/// * `$suggestion` - one or more expressions (a trailing comma is allowed);
///   the `to_string()` of each becomes the title of a priority-1 suggestion
///   listed after the engine's built-in matches.
///
/// `StatsError`, `SuggestionEngine` and `Suggestion` are referenced by bare
/// name and must be in scope where the macro is invoked.
///
/// The report is rendered here rather than through
/// `ErrorFormatter::format_error` for two reasons: that method takes the
/// error by value, so the error could not be handed back afterwards, and the
/// formatter owns a private engine that offers no way to add the custom
/// suggestions.
#[macro_export]
macro_rules! stats_error_with_suggestions {
    ($error_type:ident, $msg:expr, $($suggestion:expr),+ $(,)?) => {
        {
            let error = StatsError::$error_type($msg);

            // Borrow the error for the lookup so it is still owned below.
            let mut entries = SuggestionEngine::new().get_suggestions(&error);
            entries.extend(vec![
                $(
                    Suggestion {
                        title: $suggestion.to_string(),
                        steps: vec![],
                        priority: 1,
                        example: None,
                        docs: vec![],
                    },
                )+
            ]);

            let mut report = format!("Error: {}\n", error);
            report.push_str("\nSuggested Solutions:\n");
            for (index, entry) in entries.iter().enumerate() {
                report.push_str(&format!(
                    "\n{}. {} (Priority: {})\n",
                    index + 1,
                    entry.title,
                    entry.priority
                ));

                // Custom suggestions carry no steps; skip the empty section.
                if !entry.steps.is_empty() {
                    report.push_str(" Steps:\n");
                    for step in &entry.steps {
                        report.push_str(&format!(" \u{2022} {}\n", step));
                    }
                }

                if let Some(example) = &entry.example {
                    report.push_str("\n Example:\n");
                    for line in example.lines() {
                        report.push_str(&format!(" {}\n", line));
                    }
                }

                if !entry.docs.is_empty() {
                    report.push_str("\n See also: ");
                    report.push_str(&entry.docs.join(", "));
                    report.push('\n');
                }
            }

            eprintln!("{}", report);
            error
        }
    };
}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// A message mentioning NaN yields suggestions, best priority first.
    #[test]
    fn test_suggestion_engine() {
        let error = StatsError::invalid_argument("Found NaN values");
        let found = SuggestionEngine::new().get_suggestions(&error);

        assert!(!found.is_empty());
        assert_eq!(found[0].priority, 1);
    }

    /// A dimension-mismatch error is categorised as such with low severity.
    #[test]
    fn test_error_diagnosis() {
        let error = StatsError::dimension_mismatch("Arrays must have same length");
        let report = diagnose_error(&error);

        assert_eq!(report.error_type, ErrorType::DimensionMismatch);
        assert_eq!(report.severity, Severity::Low);
    }

    /// The formatted report contains the header and the NaN suggestion title.
    #[test]
    fn test_error_formatter() {
        let error = StatsError::invalid_argument("Array contains NaN values");
        let report = ErrorFormatter::new().format_error(error, None);

        assert!(report.contains("Suggested Solutions"));
        assert!(report.contains("Remove NaN values"));
    }
}