sklears_core/
error.rs

1use thiserror::Error;
2
3/// Main error type for sklears
4#[derive(Error, Debug)]
5pub enum SklearsError {
6    /// Error during model fitting
7    #[error("Fit error: {0}")]
8    FitError(String),
9
10    /// Error during prediction
11    #[error("Prediction error: {0}")]
12    PredictError(String),
13
14    /// Error during data transformation
15    #[error("Transform error: {0}")]
16    TransformError(String),
17
18    /// Invalid input data
19    #[error("Invalid input: {0}")]
20    InvalidInput(String),
21
22    /// Invalid data quality
23    #[error("Invalid data: {reason}")]
24    InvalidData { reason: String },
25
26    /// Shape mismatch between arrays
27    #[error("Shape mismatch: expected {expected}, got {actual}")]
28    ShapeMismatch { expected: String, actual: String },
29
30    /// Invalid parameter value
31    #[error("Invalid parameter '{name}': {reason}")]
32    InvalidParameter { name: String, reason: String },
33
34    /// Dimension mismatch between arrays
35    #[error("Dimension mismatch: expected {expected}, got {actual}")]
36    DimensionMismatch { expected: usize, actual: usize },
37
38    /// Model not fitted
39    #[error("Model not fitted. Call fit() before {operation}")]
40    NotFitted { operation: String },
41
42    /// Numerical computation error
43    #[error("Numerical error: {0}")]
44    NumericalError(String),
45
46    /// Convergence failure
47    #[error("Failed to converge after {iterations} iterations")]
48    ConvergenceError { iterations: usize },
49
50    /// Feature dimension mismatch
51    #[error("Feature dimension mismatch: model expects {expected} features, got {actual}")]
52    FeatureMismatch { expected: usize, actual: usize },
53
54    /// Missing dependency error
55    #[error("Missing dependency '{dependency}' required for {feature}")]
56    MissingDependency { dependency: String, feature: String },
57
58    /// IO error
59    #[error("IO error: {0}")]
60    IoError(#[from] std::io::Error),
61
62    /// File operation error
63    #[error("File error: {0}")]
64    FileError(String),
65
66    /// Serialization error
67    #[error("Serialization error: {0}")]
68    SerializationError(String),
69
70    /// Deserialization error
71    #[error("Deserialization error: {0}")]
72    DeserializationError(String),
73
74    /// Not implemented
75    #[error("Not implemented: {0}")]
76    NotImplemented(String),
77
78    /// Invalid operation
79    #[error("Invalid operation: {0}")]
80    InvalidOperation(String),
81
82    /// Invalid state error
83    #[error("Invalid state: {0}")]
84    InvalidState(String),
85
86    /// Configuration error
87    #[error("Configuration error: {0}")]
88    Configuration(String),
89
90    /// Trait not found
91    #[error("Trait not found: {0}")]
92    TraitNotFound(String),
93
94    /// Analysis error
95    #[error("Analysis error: {0}")]
96    AnalysisError(String),
97
98    /// Hardware error
99    #[error("Hardware error: {0}")]
100    HardwareError(String),
101
102    /// Resource allocation error
103    #[error("Resource allocation error: {0}")]
104    ResourceAllocationError(String),
105
106    /// Invalid configuration error
107    #[error("Invalid configuration: {0}")]
108    InvalidConfiguration(String),
109
110    /// Processing error
111    #[error("Processing error: {0}")]
112    ProcessingError(String),
113
114    /// Model error
115    #[error("Model error: {0}")]
116    ModelError(String),
117
118    /// Validation error
119    #[error("Validation error: {0}")]
120    ValidationError(String),
121
122    /// Other errors
123    #[error("{0}")]
124    Other(String),
125}
126
127impl Clone for SklearsError {
128    fn clone(&self) -> Self {
129        match self {
130            SklearsError::FitError(s) => SklearsError::FitError(s.clone()),
131            SklearsError::PredictError(s) => SklearsError::PredictError(s.clone()),
132            SklearsError::TransformError(s) => SklearsError::TransformError(s.clone()),
133            SklearsError::InvalidInput(s) => SklearsError::InvalidInput(s.clone()),
134            SklearsError::InvalidData { reason } => SklearsError::InvalidData {
135                reason: reason.clone(),
136            },
137            SklearsError::ShapeMismatch { expected, actual } => SklearsError::ShapeMismatch {
138                expected: expected.clone(),
139                actual: actual.clone(),
140            },
141            SklearsError::InvalidParameter { name, reason } => SklearsError::InvalidParameter {
142                name: name.clone(),
143                reason: reason.clone(),
144            },
145            SklearsError::DimensionMismatch { expected, actual } => {
146                SklearsError::DimensionMismatch {
147                    expected: *expected,
148                    actual: *actual,
149                }
150            }
151            SklearsError::NotFitted { operation } => SklearsError::NotFitted {
152                operation: operation.clone(),
153            },
154            SklearsError::NumericalError(s) => SklearsError::NumericalError(s.clone()),
155            SklearsError::ConvergenceError { iterations } => SklearsError::ConvergenceError {
156                iterations: *iterations,
157            },
158            SklearsError::FeatureMismatch { expected, actual } => SklearsError::FeatureMismatch {
159                expected: *expected,
160                actual: *actual,
161            },
162            SklearsError::IoError(io_err) => {
163                // Since std::io::Error doesn't implement Clone, we create a new one with the same kind and message
164                SklearsError::IoError(std::io::Error::new(io_err.kind(), format!("{io_err}")))
165            }
166            SklearsError::FileError(s) => SklearsError::FileError(s.clone()),
167            SklearsError::SerializationError(s) => SklearsError::SerializationError(s.clone()),
168            SklearsError::DeserializationError(s) => SklearsError::DeserializationError(s.clone()),
169            SklearsError::NotImplemented(s) => SklearsError::NotImplemented(s.clone()),
170            SklearsError::InvalidOperation(s) => SklearsError::InvalidOperation(s.clone()),
171            SklearsError::InvalidState(s) => SklearsError::InvalidState(s.clone()),
172            SklearsError::Configuration(s) => SklearsError::Configuration(s.clone()),
173            SklearsError::MissingDependency {
174                dependency,
175                feature,
176            } => SklearsError::MissingDependency {
177                dependency: dependency.clone(),
178                feature: feature.clone(),
179            },
180            SklearsError::TraitNotFound(s) => SklearsError::TraitNotFound(s.clone()),
181            SklearsError::AnalysisError(s) => SklearsError::AnalysisError(s.clone()),
182            SklearsError::HardwareError(s) => SklearsError::HardwareError(s.clone()),
183            SklearsError::ResourceAllocationError(s) => {
184                SklearsError::ResourceAllocationError(s.clone())
185            }
186            SklearsError::InvalidConfiguration(s) => SklearsError::InvalidConfiguration(s.clone()),
187            SklearsError::ProcessingError(s) => SklearsError::ProcessingError(s.clone()),
188            SklearsError::ModelError(s) => SklearsError::ModelError(s.clone()),
189            SklearsError::ValidationError(s) => SklearsError::ValidationError(s.clone()),
190            SklearsError::Other(s) => SklearsError::Other(s.clone()),
191        }
192    }
193}
194
195// Convert from String
196impl From<String> for SklearsError {
197    fn from(error: String) -> Self {
198        SklearsError::Other(error)
199    }
200}
201
202// Convert from &str
203impl From<&str> for SklearsError {
204    fn from(error: &str) -> Self {
205        SklearsError::Other(error.to_string())
206    }
207}
208
209// Convert from ndarray ShapeError
210impl From<scirs2_core::ndarray::ShapeError> for SklearsError {
211    fn from(error: scirs2_core::ndarray::ShapeError) -> Self {
212        SklearsError::InvalidInput(format!("Array shape error: {error}"))
213    }
214}
215
216// Convert from serde_json::Error
217impl From<serde_json::Error> for SklearsError {
218    fn from(error: serde_json::Error) -> Self {
219        SklearsError::SerializationError(format!("JSON serialization error: {error}"))
220    }
221}
222
223/// Result type alias for sklears operations
224pub type Result<T> = std::result::Result<T, SklearsError>;
225
226/// Enhanced error context trait for better error propagation
227pub trait ErrorContext<T> {
228    /// Add context to an error
229    fn context(self, msg: &str) -> Result<T>;
230
231    /// Add context with a lazy-evaluated closure
232    fn with_context<F>(self, f: F) -> Result<T>
233    where
234        F: FnOnce() -> String;
235
236    /// Add operation context for debugging
237    fn with_operation(self, operation: &str) -> Result<T>;
238
239    /// Add location context for debugging  
240    fn with_location(self, file: &str, line: u32) -> Result<T>;
241}
242
243impl<T, E> ErrorContext<T> for std::result::Result<T, E>
244where
245    E: std::error::Error,
246{
247    fn context(self, msg: &str) -> Result<T> {
248        self.map_err(|e| SklearsError::Other(format!("{msg}: {e}")))
249    }
250
251    fn with_context<F>(self, f: F) -> Result<T>
252    where
253        F: FnOnce() -> String,
254    {
255        self.map_err(|e| SklearsError::Other(format!("{}: {e}", f())))
256    }
257
258    fn with_operation(self, operation: &str) -> Result<T> {
259        self.map_err(|e| SklearsError::Other(format!("Operation '{operation}' failed: {e}")))
260    }
261
262    fn with_location(self, file: &str, line: u32) -> Result<T> {
263        self.map_err(|e| SklearsError::Other(format!("Error at {file}:{line}: {e}")))
264    }
265}
266
/// Attach the caller's source location, and optionally a message, to a
/// fallible result.
///
/// * `error_context!(result)` is equivalent to
///   `result.with_location(file!(), line!())`.
/// * `error_context!(result, msg)` first applies `.context(msg)` and then adds
///   the location. The rendered message therefore reads
///   `"Error at <file>:<line>: <msg>: <original error>"`.
///
/// The `ErrorContext` methods are invoked through their fully-qualified
/// `$crate::error::ErrorContext` path. The macro therefore works even when the
/// trait is not imported at the call site. It also cannot collide with a
/// same-named method from another trait in scope, such as
/// `anyhow::Context::context`.
#[macro_export]
macro_rules! error_context {
    ($result:expr) => {
        $crate::error::ErrorContext::with_location($result, file!(), line!())
    };
    ($result:expr, $msg:expr) => {
        $crate::error::ErrorContext::with_location(
            $crate::error::ErrorContext::context($result, $msg),
            file!(),
            line!(),
        )
    };
}
277
278/// Enhanced context propagation for sklearn-specific operations
279pub trait SklearnContext<T> {
280    /// Add context for fitting operations
281    fn fit_context(self, estimator: &str, samples: usize, features: usize) -> Result<T>;
282
283    /// Add context for prediction operations  
284    fn predict_context(self, estimator: &str, samples: usize) -> Result<T>;
285
286    /// Add context for transformation operations
287    fn transform_context(self, transformer: &str, samples: usize, features: usize) -> Result<T>;
288
289    /// Add context for validation operations
290    fn validation_context(self, parameter: &str, value: &str) -> Result<T>;
291}
292
293impl<T, E> SklearnContext<T> for std::result::Result<T, E>
294where
295    E: std::error::Error,
296{
297    fn fit_context(self, estimator: &str, samples: usize, features: usize) -> Result<T> {
298        self.with_context(|| {
299            format!("Failed to fit {estimator} with {samples} samples and {features} features")
300        })
301    }
302
303    fn predict_context(self, estimator: &str, samples: usize) -> Result<T> {
304        self.with_context(|| format!("Failed to predict using {estimator} with {samples} samples"))
305    }
306
307    fn transform_context(self, transformer: &str, samples: usize, features: usize) -> Result<T> {
308        self.with_context(|| {
309            format!("Failed to transform using {transformer} with {samples} samples and {features} features")
310        })
311    }
312
313    fn validation_context(self, parameter: &str, value: &str) -> Result<T> {
314        self.with_context(|| {
315            format!("Validation failed for parameter '{parameter}' with value '{value}'")
316        })
317    }
318}
319
/// Return early with `SklearsError::InvalidInput` when a condition is false.
///
/// Two forms are accepted:
///
/// * `validate!(cond, message)`: `message` is any expression with a
///   `to_string()` method. Its text is used verbatim; it is **not** treated as a
///   format string, so `{...}` is kept literally.
/// * `validate!(cond, "format {}", args...)`: the message is built with
///   `format!`.
///
/// The expansion is `return Err(...)`. It can therefore only be used inside a
/// function or closure whose return type is `Result<_, SklearsError>`, such as
/// this crate's `Result<T>`.
#[macro_export]
macro_rules! validate {
    ($condition:expr, $message:expr) => {
        if !($condition) {
            let message = $message.to_string();
            return Err($crate::error::SklearsError::InvalidInput(message));
        }
    };
    ($condition:expr, $message:expr, $($arg:tt)*) => {
        if !($condition) {
            let message = format!($message, $($arg)*);
            return Err($crate::error::SklearsError::InvalidInput(message));
        }
    };
}
334
335/// Chain multiple errors together for better debugging
336#[derive(Debug)]
337pub struct ErrorChain {
338    errors: Vec<Box<dyn std::error::Error + Send + Sync>>,
339    context: Vec<String>,
340}
341
342impl ErrorChain {
343    /// Create a new error chain
344    pub fn new() -> Self {
345        Self {
346            errors: Vec::new(),
347            context: Vec::new(),
348        }
349    }
350
351    /// Add an error to the chain
352    pub fn push_error<E>(mut self, error: E) -> Self
353    where
354        E: std::error::Error + Send + Sync + 'static,
355    {
356        self.errors.push(Box::new(error));
357        self
358    }
359
360    /// Add context to the chain
361    pub fn push_context<S: Into<String>>(mut self, context: S) -> Self {
362        self.context.push(context.into());
363        self
364    }
365
366    /// Convert to SklearsError
367    pub fn into_error(self) -> SklearsError {
368        let message = if self.context.is_empty() && self.errors.is_empty() {
369            "Unknown error chain".to_string()
370        } else {
371            let context_str = self.context.join(" -> ");
372            let error_str = self
373                .errors
374                .iter()
375                .map(|e| e.to_string())
376                .collect::<Vec<_>>()
377                .join("; ");
378
379            if context_str.is_empty() {
380                error_str
381            } else if error_str.is_empty() {
382                context_str
383            } else {
384                format!("{context_str}: {error_str}")
385            }
386        };
387
388        SklearsError::Other(message)
389    }
390}
391
392impl Default for ErrorChain {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
398/// Validation utilities
399pub mod validate {
400    use super::*;
401    use crate::types::{Array1, Array2, FloatBounds, Numeric};
402
403    /// Check if X and y have compatible shapes
404    pub fn check_consistent_length<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
405        let n_samples_x = x.nrows();
406        let n_samples_y = y.len();
407
408        if n_samples_x != n_samples_y {
409            return Err(SklearsError::ShapeMismatch {
410                expected: "X.shape[0] == y.shape[0]".to_string(),
411                actual: format!("X.shape[0]={n_samples_x}, y.shape[0]={n_samples_y}"),
412            });
413        }
414
415        Ok(())
416    }
417
418    /// Check if array has the expected number of features
419    pub fn check_n_features<T>(x: &Array2<T>, expected: usize) -> Result<()> {
420        let actual = x.ncols();
421        if actual != expected {
422            return Err(SklearsError::FeatureMismatch { expected, actual });
423        }
424        Ok(())
425    }
426
427    /// Check if value is finite (generic over floating point types)
428    pub fn check_finite<T: FloatBounds>(value: T, name: &str) -> Result<()> {
429        if !value.is_finite() {
430            return Err(SklearsError::InvalidParameter {
431                name: name.to_string(),
432                reason: "must be finite".to_string(),
433            });
434        }
435        Ok(())
436    }
437
438    /// Check if value is positive (generic over numeric types)
439    pub fn check_positive<T: Numeric + PartialOrd>(value: T, name: &str) -> Result<()> {
440        if value <= T::zero() {
441            return Err(SklearsError::InvalidParameter {
442                name: name.to_string(),
443                reason: "must be positive".to_string(),
444            });
445        }
446        Ok(())
447    }
448
449    /// Check if value is non-negative (generic over numeric types)
450    pub fn check_non_negative<T: Numeric + PartialOrd>(value: T, name: &str) -> Result<()> {
451        if value < T::zero() {
452            return Err(SklearsError::InvalidParameter {
453                name: name.to_string(),
454                reason: "must be non-negative".to_string(),
455            });
456        }
457        Ok(())
458    }
459
460    /// Check if value is in a specific range
461    pub fn check_in_range<T: Numeric + PartialOrd>(
462        value: T,
463        min: T,
464        max: T,
465        name: &str,
466    ) -> Result<()> {
467        if value < min || value > max {
468            return Err(SklearsError::InvalidParameter {
469                name: name.to_string(),
470                reason: format!("must be in range [{min}, {max}]"),
471            });
472        }
473        Ok(())
474    }
475
476    /// Check if arrays have compatible shapes for matrix multiplication
477    pub fn check_matmul_compatible<T, U>(a: &Array2<T>, b: &Array2<U>) -> Result<()> {
478        if a.ncols() != b.nrows() {
479            return Err(SklearsError::ShapeMismatch {
480                expected: "A.shape[1] == B.shape[0]".to_string(),
481                actual: format!("A.shape[1]={}, B.shape[0]={}", a.ncols(), b.nrows()),
482            });
483        }
484        Ok(())
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, ErrorKind};

    /// Build a failing `io` result carrying `kind` and `message`.
    fn io_failure(kind: ErrorKind, message: &str) -> std::result::Result<(), io::Error> {
        Err(io::Error::new(kind, message))
    }

    #[test]
    fn test_error_context() {
        let rendered = io_failure(ErrorKind::NotFound, "file not found")
            .context("Failed to read config file")
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("Failed to read config file"));
    }

    #[test]
    fn test_error_with_operation() {
        let rendered = io_failure(ErrorKind::PermissionDenied, "access denied")
            .with_operation("matrix_multiplication")
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("matrix_multiplication"));
    }

    #[test]
    fn test_sklearn_context() {
        let rendered = io_failure(ErrorKind::InvalidInput, "invalid data")
            .fit_context("LinearRegression", 100, 5)
            .unwrap_err()
            .to_string();
        for fragment in ["LinearRegression", "100 samples", "5 features"] {
            assert!(rendered.contains(fragment), "{fragment:?} not found in {rendered:?}");
        }
    }

    #[test]
    fn test_error_chain() {
        let rendered = ErrorChain::new()
            .push_context("Model training")
            .push_context("Data preprocessing")
            .push_error(io::Error::new(ErrorKind::NotFound, "data file missing"))
            .push_context("Feature scaling")
            .into_error()
            .to_string();
        for fragment in [
            "Model training",
            "Data preprocessing",
            "Feature scaling",
            "data file missing",
        ] {
            assert!(rendered.contains(fragment), "{fragment:?} not found in {rendered:?}");
        }
    }

    #[test]
    fn test_validation_context() {
        let rendered = io_failure(ErrorKind::InvalidInput, "negative value")
            .validation_context("learning_rate", "-0.1")
            .unwrap_err()
            .to_string();
        assert!(rendered.contains("learning_rate"));
        assert!(rendered.contains("-0.1"));
    }
}