1use thiserror::Error;
2
3#[derive(Error, Debug)]
5pub enum SklearsError {
6 #[error("Fit error: {0}")]
8 FitError(String),
9
10 #[error("Prediction error: {0}")]
12 PredictError(String),
13
14 #[error("Transform error: {0}")]
16 TransformError(String),
17
18 #[error("Invalid input: {0}")]
20 InvalidInput(String),
21
22 #[error("Invalid data: {reason}")]
24 InvalidData { reason: String },
25
26 #[error("Shape mismatch: expected {expected}, got {actual}")]
28 ShapeMismatch { expected: String, actual: String },
29
30 #[error("Invalid parameter '{name}': {reason}")]
32 InvalidParameter { name: String, reason: String },
33
34 #[error("Dimension mismatch: expected {expected}, got {actual}")]
36 DimensionMismatch { expected: usize, actual: usize },
37
38 #[error("Model not fitted. Call fit() before {operation}")]
40 NotFitted { operation: String },
41
42 #[error("Numerical error: {0}")]
44 NumericalError(String),
45
46 #[error("Failed to converge after {iterations} iterations")]
48 ConvergenceError { iterations: usize },
49
50 #[error("Feature dimension mismatch: model expects {expected} features, got {actual}")]
52 FeatureMismatch { expected: usize, actual: usize },
53
54 #[error("Missing dependency '{dependency}' required for {feature}")]
56 MissingDependency { dependency: String, feature: String },
57
58 #[error("IO error: {0}")]
60 IoError(#[from] std::io::Error),
61
62 #[error("File error: {0}")]
64 FileError(String),
65
66 #[error("Serialization error: {0}")]
68 SerializationError(String),
69
70 #[error("Deserialization error: {0}")]
72 DeserializationError(String),
73
74 #[error("Not implemented: {0}")]
76 NotImplemented(String),
77
78 #[error("Invalid operation: {0}")]
80 InvalidOperation(String),
81
82 #[error("Invalid state: {0}")]
84 InvalidState(String),
85
86 #[error("Configuration error: {0}")]
88 Configuration(String),
89
90 #[error("Trait not found: {0}")]
92 TraitNotFound(String),
93
94 #[error("Analysis error: {0}")]
96 AnalysisError(String),
97
98 #[error("Hardware error: {0}")]
100 HardwareError(String),
101
102 #[error("Resource allocation error: {0}")]
104 ResourceAllocationError(String),
105
106 #[error("Invalid configuration: {0}")]
108 InvalidConfiguration(String),
109
110 #[error("Processing error: {0}")]
112 ProcessingError(String),
113
114 #[error("Model error: {0}")]
116 ModelError(String),
117
118 #[error("Validation error: {0}")]
120 ValidationError(String),
121
122 #[error("{0}")]
124 Other(String),
125}
126
127impl Clone for SklearsError {
128 fn clone(&self) -> Self {
129 match self {
130 SklearsError::FitError(s) => SklearsError::FitError(s.clone()),
131 SklearsError::PredictError(s) => SklearsError::PredictError(s.clone()),
132 SklearsError::TransformError(s) => SklearsError::TransformError(s.clone()),
133 SklearsError::InvalidInput(s) => SklearsError::InvalidInput(s.clone()),
134 SklearsError::InvalidData { reason } => SklearsError::InvalidData {
135 reason: reason.clone(),
136 },
137 SklearsError::ShapeMismatch { expected, actual } => SklearsError::ShapeMismatch {
138 expected: expected.clone(),
139 actual: actual.clone(),
140 },
141 SklearsError::InvalidParameter { name, reason } => SklearsError::InvalidParameter {
142 name: name.clone(),
143 reason: reason.clone(),
144 },
145 SklearsError::DimensionMismatch { expected, actual } => {
146 SklearsError::DimensionMismatch {
147 expected: *expected,
148 actual: *actual,
149 }
150 }
151 SklearsError::NotFitted { operation } => SklearsError::NotFitted {
152 operation: operation.clone(),
153 },
154 SklearsError::NumericalError(s) => SklearsError::NumericalError(s.clone()),
155 SklearsError::ConvergenceError { iterations } => SklearsError::ConvergenceError {
156 iterations: *iterations,
157 },
158 SklearsError::FeatureMismatch { expected, actual } => SklearsError::FeatureMismatch {
159 expected: *expected,
160 actual: *actual,
161 },
162 SklearsError::IoError(io_err) => {
163 SklearsError::IoError(std::io::Error::new(io_err.kind(), format!("{io_err}")))
165 }
166 SklearsError::FileError(s) => SklearsError::FileError(s.clone()),
167 SklearsError::SerializationError(s) => SklearsError::SerializationError(s.clone()),
168 SklearsError::DeserializationError(s) => SklearsError::DeserializationError(s.clone()),
169 SklearsError::NotImplemented(s) => SklearsError::NotImplemented(s.clone()),
170 SklearsError::InvalidOperation(s) => SklearsError::InvalidOperation(s.clone()),
171 SklearsError::InvalidState(s) => SklearsError::InvalidState(s.clone()),
172 SklearsError::Configuration(s) => SklearsError::Configuration(s.clone()),
173 SklearsError::MissingDependency {
174 dependency,
175 feature,
176 } => SklearsError::MissingDependency {
177 dependency: dependency.clone(),
178 feature: feature.clone(),
179 },
180 SklearsError::TraitNotFound(s) => SklearsError::TraitNotFound(s.clone()),
181 SklearsError::AnalysisError(s) => SklearsError::AnalysisError(s.clone()),
182 SklearsError::HardwareError(s) => SklearsError::HardwareError(s.clone()),
183 SklearsError::ResourceAllocationError(s) => {
184 SklearsError::ResourceAllocationError(s.clone())
185 }
186 SklearsError::InvalidConfiguration(s) => SklearsError::InvalidConfiguration(s.clone()),
187 SklearsError::ProcessingError(s) => SklearsError::ProcessingError(s.clone()),
188 SklearsError::ModelError(s) => SklearsError::ModelError(s.clone()),
189 SklearsError::ValidationError(s) => SklearsError::ValidationError(s.clone()),
190 SklearsError::Other(s) => SklearsError::Other(s.clone()),
191 }
192 }
193}
194
195impl From<String> for SklearsError {
197 fn from(error: String) -> Self {
198 SklearsError::Other(error)
199 }
200}
201
202impl From<&str> for SklearsError {
204 fn from(error: &str) -> Self {
205 SklearsError::Other(error.to_string())
206 }
207}
208
209impl From<scirs2_core::ndarray::ShapeError> for SklearsError {
211 fn from(error: scirs2_core::ndarray::ShapeError) -> Self {
212 SklearsError::InvalidInput(format!("Array shape error: {error}"))
213 }
214}
215
216impl From<serde_json::Error> for SklearsError {
218 fn from(error: serde_json::Error) -> Self {
219 SklearsError::SerializationError(format!("JSON serialization error: {error}"))
220 }
221}
222
223pub type Result<T> = std::result::Result<T, SklearsError>;
225
226pub trait ErrorContext<T> {
228 fn context(self, msg: &str) -> Result<T>;
230
231 fn with_context<F>(self, f: F) -> Result<T>
233 where
234 F: FnOnce() -> String;
235
236 fn with_operation(self, operation: &str) -> Result<T>;
238
239 fn with_location(self, file: &str, line: u32) -> Result<T>;
241}
242
243impl<T, E> ErrorContext<T> for std::result::Result<T, E>
244where
245 E: std::error::Error,
246{
247 fn context(self, msg: &str) -> Result<T> {
248 self.map_err(|e| SklearsError::Other(format!("{msg}: {e}")))
249 }
250
251 fn with_context<F>(self, f: F) -> Result<T>
252 where
253 F: FnOnce() -> String,
254 {
255 self.map_err(|e| SklearsError::Other(format!("{}: {e}", f())))
256 }
257
258 fn with_operation(self, operation: &str) -> Result<T> {
259 self.map_err(|e| SklearsError::Other(format!("Operation '{operation}' failed: {e}")))
260 }
261
262 fn with_location(self, file: &str, line: u32) -> Result<T> {
263 self.map_err(|e| SklearsError::Other(format!("Error at {file}:{line}: {e}")))
264 }
265}
266
/// Annotates a failed result with the caller's source location.
///
/// * `error_context!(result)` is equivalent to
///   `ErrorContext::with_location(result, file!(), line!())`.
/// * `error_context!(result, msg)` first applies `ErrorContext::context(result, msg)`
///   and then adds the location.
///
/// The trait methods are called through the fully qualified
/// `$crate::error::ErrorContext` path. Callers therefore do not need to import
/// the trait. Another trait in scope with a method named `context` also cannot
/// make the expansion ambiguous.
#[macro_export]
macro_rules! error_context {
    ($result:expr) => {
        $crate::error::ErrorContext::with_location($result, file!(), line!())
    };
    ($result:expr, $msg:expr) => {
        $crate::error::ErrorContext::with_location(
            $crate::error::ErrorContext::context($result, $msg),
            file!(),
            line!(),
        )
    };
}
277
278pub trait SklearnContext<T> {
280 fn fit_context(self, estimator: &str, samples: usize, features: usize) -> Result<T>;
282
283 fn predict_context(self, estimator: &str, samples: usize) -> Result<T>;
285
286 fn transform_context(self, transformer: &str, samples: usize, features: usize) -> Result<T>;
288
289 fn validation_context(self, parameter: &str, value: &str) -> Result<T>;
291}
292
293impl<T, E> SklearnContext<T> for std::result::Result<T, E>
294where
295 E: std::error::Error,
296{
297 fn fit_context(self, estimator: &str, samples: usize, features: usize) -> Result<T> {
298 self.with_context(|| {
299 format!("Failed to fit {estimator} with {samples} samples and {features} features")
300 })
301 }
302
303 fn predict_context(self, estimator: &str, samples: usize) -> Result<T> {
304 self.with_context(|| format!("Failed to predict using {estimator} with {samples} samples"))
305 }
306
307 fn transform_context(self, transformer: &str, samples: usize, features: usize) -> Result<T> {
308 self.with_context(|| {
309 format!("Failed to transform using {transformer} with {samples} samples and {features} features")
310 })
311 }
312
313 fn validation_context(self, parameter: &str, value: &str) -> Result<T> {
314 self.with_context(|| {
315 format!("Validation failed for parameter '{parameter}' with value '{value}'")
316 })
317 }
318}
319
/// Returns early with `SklearsError::InvalidInput` when `$condition` is false.
///
/// * `validate!(cond, message)`: `message` is converted with `.to_string()`
///   as-is. No `{}` formatting is applied.
/// * `validate!(cond, "fmt {}", args...)`: the message is built with `format!`.
///
/// Must be used inside a function whose return type accepts
/// `Err(SklearsError)`.
#[macro_export]
macro_rules! validate {
    ($condition:expr, $message:expr) => {
        if !($condition) {
            return ::core::result::Result::Err($crate::error::SklearsError::InvalidInput(
                $message.to_string(),
            ));
        }
    };
    ($condition:expr, $message:expr, $($arg:tt)*) => {
        if !($condition) {
            let rendered = format!($message, $($arg)*);
            return ::core::result::Result::Err($crate::error::SklearsError::InvalidInput(
                rendered,
            ));
        }
    };
}
334
/// Collects error values and context labels so that they can later be
/// collapsed into a single `SklearsError` with `into_error`.
#[derive(Debug)]
pub struct ErrorChain {
    /// Errors in the order they were pushed.
    errors: Vec<Box<dyn std::error::Error + Send + Sync>>,
    /// Context labels in the order they were pushed; stored separately from
    /// `errors`, so relative ordering between the two lists is not recorded.
    context: Vec<String>,
}
341
342impl ErrorChain {
343 pub fn new() -> Self {
345 Self {
346 errors: Vec::new(),
347 context: Vec::new(),
348 }
349 }
350
351 pub fn push_error<E>(mut self, error: E) -> Self
353 where
354 E: std::error::Error + Send + Sync + 'static,
355 {
356 self.errors.push(Box::new(error));
357 self
358 }
359
360 pub fn push_context<S: Into<String>>(mut self, context: S) -> Self {
362 self.context.push(context.into());
363 self
364 }
365
366 pub fn into_error(self) -> SklearsError {
368 let message = if self.context.is_empty() && self.errors.is_empty() {
369 "Unknown error chain".to_string()
370 } else {
371 let context_str = self.context.join(" -> ");
372 let error_str = self
373 .errors
374 .iter()
375 .map(|e| e.to_string())
376 .collect::<Vec<_>>()
377 .join("; ");
378
379 if context_str.is_empty() {
380 error_str
381 } else if error_str.is_empty() {
382 context_str
383 } else {
384 format!("{context_str}: {error_str}")
385 }
386 };
387
388 SklearsError::Other(message)
389 }
390}
391
392impl Default for ErrorChain {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398pub mod validate {
400 use super::*;
401 use crate::types::{Array1, Array2, FloatBounds, Numeric};
402
403 pub fn check_consistent_length<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
405 let n_samples_x = x.nrows();
406 let n_samples_y = y.len();
407
408 if n_samples_x != n_samples_y {
409 return Err(SklearsError::ShapeMismatch {
410 expected: "X.shape[0] == y.shape[0]".to_string(),
411 actual: format!("X.shape[0]={n_samples_x}, y.shape[0]={n_samples_y}"),
412 });
413 }
414
415 Ok(())
416 }
417
418 pub fn check_n_features<T>(x: &Array2<T>, expected: usize) -> Result<()> {
420 let actual = x.ncols();
421 if actual != expected {
422 return Err(SklearsError::FeatureMismatch { expected, actual });
423 }
424 Ok(())
425 }
426
427 pub fn check_finite<T: FloatBounds>(value: T, name: &str) -> Result<()> {
429 if !value.is_finite() {
430 return Err(SklearsError::InvalidParameter {
431 name: name.to_string(),
432 reason: "must be finite".to_string(),
433 });
434 }
435 Ok(())
436 }
437
438 pub fn check_positive<T: Numeric + PartialOrd>(value: T, name: &str) -> Result<()> {
440 if value <= T::zero() {
441 return Err(SklearsError::InvalidParameter {
442 name: name.to_string(),
443 reason: "must be positive".to_string(),
444 });
445 }
446 Ok(())
447 }
448
449 pub fn check_non_negative<T: Numeric + PartialOrd>(value: T, name: &str) -> Result<()> {
451 if value < T::zero() {
452 return Err(SklearsError::InvalidParameter {
453 name: name.to_string(),
454 reason: "must be non-negative".to_string(),
455 });
456 }
457 Ok(())
458 }
459
460 pub fn check_in_range<T: Numeric + PartialOrd>(
462 value: T,
463 min: T,
464 max: T,
465 name: &str,
466 ) -> Result<()> {
467 if value < min || value > max {
468 return Err(SklearsError::InvalidParameter {
469 name: name.to_string(),
470 reason: format!("must be in range [{min}, {max}]"),
471 });
472 }
473 Ok(())
474 }
475
476 pub fn check_matmul_compatible<T, U>(a: &Array2<T>, b: &Array2<U>) -> Result<()> {
478 if a.ncols() != b.nrows() {
479 return Err(SklearsError::ShapeMismatch {
480 expected: "A.shape[1] == B.shape[0]".to_string(),
481 actual: format!("A.shape[1]={}, B.shape[0]={}", a.ncols(), b.nrows()),
482 });
483 }
484 Ok(())
485 }
486}
487
// NOTE(review): `non_snake_case` is allowed for the whole module although no
// test below uses a non-snake-case name; presumably kept for ML-style names
// such as `X`. Confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, ErrorKind};

    /// Builds an `Err` holding an `io::Error` with the given kind and message.
    fn io_failure(kind: ErrorKind, message: &str) -> std::result::Result<(), io::Error> {
        Err(io::Error::new(kind, message))
    }

    #[test]
    fn test_error_context() {
        let err = io_failure(ErrorKind::NotFound, "file not found")
            .context("Failed to read config file")
            .expect_err("context must keep the failure");
        assert!(err.to_string().contains("Failed to read config file"));
    }

    #[test]
    fn test_error_with_operation() {
        let err = io_failure(ErrorKind::PermissionDenied, "access denied")
            .with_operation("matrix_multiplication")
            .expect_err("with_operation must keep the failure");
        assert!(err.to_string().contains("matrix_multiplication"));
    }

    #[test]
    fn test_sklearn_context() {
        let rendered = io_failure(ErrorKind::InvalidInput, "invalid data")
            .fit_context("LinearRegression", 100, 5)
            .expect_err("fit_context must keep the failure")
            .to_string();
        for needle in ["LinearRegression", "100 samples", "5 features"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }

    #[test]
    fn test_error_chain() {
        let rendered = ErrorChain::new()
            .push_context("Model training")
            .push_context("Data preprocessing")
            .push_error(io::Error::new(ErrorKind::NotFound, "data file missing"))
            .push_context("Feature scaling")
            .into_error()
            .to_string();
        for needle in [
            "Model training",
            "Data preprocessing",
            "Feature scaling",
            "data file missing",
        ] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }

    #[test]
    fn test_validation_context() {
        let rendered = io_failure(ErrorKind::InvalidInput, "negative value")
            .validation_context("learning_rate", "-0.1")
            .expect_err("validation_context must keep the failure")
            .to_string();
        assert!(rendered.contains("learning_rate"));
        assert!(rendered.contains("-0.1"));
    }
}