torsh_core/
error.rs

1//! Error types for ToRSh - Clean Modular Interface
2//!
3//! This module provides a unified interface to the ToRSh error system.
4//! All error implementations have been organized into specialized modules
5//! for better maintainability and categorization.
6//!
7//! # Architecture
8//!
9//! The error system is organized into specialized modules:
10//!
11//! - **core**: Error infrastructure, location tracking, debug context
12//! - **shape_errors**: Shape mismatches, broadcasting, tensor operations
13//! - **index_errors**: Index bounds checking and access violations
14//! - **general_errors**: I/O, configuration, runtime, and miscellaneous errors
15//!
16//! All error types are unified through the main `TorshError` enum which provides
17//! backward compatibility while enabling modular error handling.
18
19// Modular error system
20mod core;
21mod general_errors;
22mod index_errors;
23mod shape_errors;
24
25// Re-export the complete modular interface
26pub use core::{
27    capture_minimal_stack_trace, capture_stack_trace, format_shape, ErrorCategory,
28    ErrorDebugContext, ErrorLocation, ErrorSeverity, ShapeDisplay, ThreadInfo,
29};
30
31pub use general_errors::GeneralError;
32pub use index_errors::IndexError;
33pub use shape_errors::ShapeError;
34
35// Re-export the unified error type and result
36pub use thiserror::Error;
37
/// Main ToRSh error enum - unified interface to all error types.
///
/// The `Shape`, `Index` and `General` variants wrap the modular error types:
/// their `Display` output is forwarded unchanged (`#[error(transparent)]`) and
/// `From` conversions are derived for them (`#[from]`). `WithContext` attaches
/// a message, category, severity and debug context to an optional underlying
/// error. All remaining variants are legacy variants kept for backward
/// compatibility.
#[derive(Error, Debug, Clone)]
pub enum TorshError {
    // Modular error variants
    /// Shape mismatches, broadcasting and tensor-operation shape errors.
    #[error(transparent)]
    Shape(#[from] ShapeError),

    /// Index bounds-checking and access-violation errors.
    #[error(transparent)]
    Index(#[from] IndexError),

    /// I/O, configuration, runtime and miscellaneous errors.
    #[error(transparent)]
    General(#[from] GeneralError),

    // Error with enhanced context information
    /// An error annotated with an extra context message.
    ///
    /// `Display` renders only `message`; the wrapped error, if any, is exposed
    /// through `std::error::Error::source` because the field is `#[source]`.
    #[error("{message}")]
    WithContext {
        /// Human-readable context message; this is the variant's `Display` output.
        message: String,
        /// Category returned by `TorshError::category` for this variant.
        error_category: ErrorCategory,
        /// Severity returned by `TorshError::severity` for this variant.
        severity: ErrorSeverity,
        /// Debug information attached alongside the message.
        debug_context: Box<ErrorDebugContext>,
        /// The underlying error being annotated, if any.
        #[source]
        source: Option<Box<TorshError>>,
    },

    // Legacy compatibility variants (for backward compatibility)
    /// Two shapes that were required to be equal differ.
    #[error(
        "Shape mismatch: expected {}, got {}",
        format_shape(expected),
        format_shape(got)
    )]
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Two shapes could not be broadcast together.
    #[error(
        "Broadcasting error: incompatible shapes {} and {}",
        format_shape(shape1),
        format_shape(shape2)
    )]
    BroadcastError {
        shape1: Vec<usize>,
        shape2: Vec<usize>,
    },

    /// `index` is not valid for a dimension of length `size`.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// An argument failed validation.
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// An I/O failure, stored as its message text.
    #[error("IO error: {0}")]
    IoError(String),

    /// Tensors used together are not on the same device.
    #[error("Device mismatch: tensors must be on the same device")]
    DeviceMismatch,

    /// The requested functionality is not implemented.
    #[error("Not implemented: {0}")]
    NotImplemented(String),

    // Additional legacy compatibility variants
    /// A thread-synchronization failure.
    #[error("Thread synchronization error: {0}")]
    SynchronizationError(String),

    /// A memory allocation failure.
    #[error("Memory allocation failed: {0}")]
    AllocationError(String),

    /// An operation that is not valid here.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    /// A numeric conversion failure.
    #[error("Numeric conversion error: {0}")]
    ConversionError(String),

    /// A failure reported by a compute backend.
    #[error("Backend error: {0}")]
    BackendError(String),

    /// A shape that is invalid, described by a free-form message.
    #[error("Invalid shape: {0}")]
    InvalidShape(String),

    /// A generic runtime failure.
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// A device-level failure.
    #[error("Device error: {0}")]
    DeviceError(String),

    /// Invalid configuration.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Something was found in an invalid state.
    #[error("Invalid state: {0}")]
    InvalidState(String),

    /// Operation `op` is not supported for data type `dtype`.
    #[error("Unsupported operation '{op}' for data type '{dtype}'")]
    UnsupportedOperation { op: String, dtype: String },

    /// A failure in automatic differentiation.
    #[error("Autograd error: {0}")]
    AutogradError(String),

    /// A failure during computation.
    #[error("Compute error: {0}")]
    ComputeError(String),

    /// A serialization or deserialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// Legacy twin of `IndexOutOfBounds`: same fields and an identical message.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexError { index: usize, size: usize },

    /// A failure while iterating.
    #[error("Iteration error: {0}")]
    IterationError(String),

    /// Any error that fits no other variant.
    #[error("Other error: {0}")]
    Other(String),
}
150
151/// Result type alias for ToRSh operations
152pub type Result<T> = std::result::Result<T, TorshError>;
153
154impl TorshError {
155    /// Create a shape mismatch error (backward compatibility)
156    pub fn shape_mismatch(expected: &[usize], got: &[usize]) -> Self {
157        Self::Shape(ShapeError::shape_mismatch(expected, got))
158    }
159
160    /// Create a dimension error during operation
161    pub fn dimension_error(msg: &str, operation: &str) -> Self {
162        Self::General(GeneralError::DimensionError(format!(
163            "{msg} during {operation}"
164        )))
165    }
166
167    /// Create an index error
168    pub fn index_error(index: usize, size: usize) -> Self {
169        Self::Index(IndexError::out_of_bounds(index, size))
170    }
171
172    /// Create a type mismatch error
173    pub fn type_mismatch(expected: &str, actual: &str) -> Self {
174        Self::General(GeneralError::TypeMismatch {
175            expected: expected.to_string(),
176            actual: actual.to_string(),
177        })
178    }
179
180    /// Create a dimension error with context (backward compatibility)
181    pub fn dimension_error_with_context(msg: &str, operation: &str) -> Self {
182        Self::General(GeneralError::DimensionError(format!(
183            "{msg} during {operation}"
184        )))
185    }
186
187    /// Create a synchronization error (backward compatibility)
188    pub fn synchronization_error(msg: &str) -> Self {
189        Self::SynchronizationError(msg.to_string())
190    }
191
192    /// Create an allocation error (backward compatibility)
193    pub fn allocation_error(msg: &str) -> Self {
194        Self::AllocationError(msg.to_string())
195    }
196
197    /// Create an invalid operation error (backward compatibility)
198    pub fn invalid_operation(msg: &str) -> Self {
199        Self::InvalidOperation(msg.to_string())
200    }
201
202    /// Create a conversion error (backward compatibility)
203    pub fn conversion_error(msg: &str) -> Self {
204        Self::ConversionError(msg.to_string())
205    }
206
207    /// Create an invalid argument error with context (backward compatibility)
208    pub fn invalid_argument_with_context(msg: &str, context: &str) -> Self {
209        Self::InvalidArgument(format!("{msg} (context: {context})"))
210    }
211
212    /// Create a config error with context (backward compatibility)
213    pub fn config_error_with_context(msg: &str, context: &str) -> Self {
214        Self::ConfigError(format!("{msg} (context: {context})"))
215    }
216
217    /// Create a dimension error (backward compatibility)
218    pub fn dimension_error_simple(msg: String) -> Self {
219        Self::InvalidShape(msg)
220    }
221
222    /// Create a formatted shape mismatch error (backward compatibility)
223    pub fn shape_mismatch_formatted(expected: &str, got: &str) -> Self {
224        Self::InvalidShape(format!("Shape mismatch: expected {expected}, got {got}"))
225    }
226
227    /// Create an operation error (backward compatibility)
228    pub fn operation_error(msg: &str) -> Self {
229        Self::InvalidOperation(msg.to_string())
230    }
231
232    /// Wrap an error with location information (backward compatibility)
233    pub fn wrap_with_location(self, location: String) -> Self {
234        // For backward compatibility, just add context
235        self.with_context(&location)
236    }
237
238    /// Get the error category
239    pub fn category(&self) -> ErrorCategory {
240        match self {
241            Self::Shape(e) => e.category(),
242            Self::Index(e) => e.category(),
243            Self::General(e) => e.category(),
244            Self::WithContext { error_category, .. } => error_category.clone(),
245            Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorCategory::Shape,
246            Self::IndexOutOfBounds { .. } => ErrorCategory::UserInput,
247            Self::InvalidArgument(_) => ErrorCategory::UserInput,
248            Self::IoError(_) => ErrorCategory::Io,
249            Self::DeviceMismatch => ErrorCategory::Device,
250            Self::NotImplemented(_) => ErrorCategory::Internal,
251            Self::SynchronizationError(_) => ErrorCategory::Threading,
252            Self::AllocationError(_) => ErrorCategory::Memory,
253            Self::InvalidOperation(_) => ErrorCategory::UserInput,
254            Self::ConversionError(_) => ErrorCategory::DataType,
255            Self::BackendError(_) => ErrorCategory::Device,
256            Self::InvalidShape(_) => ErrorCategory::Shape,
257            Self::RuntimeError(_) => ErrorCategory::Internal,
258            Self::DeviceError(_) => ErrorCategory::Device,
259            Self::ConfigError(_) => ErrorCategory::Configuration,
260            Self::InvalidState(_) => ErrorCategory::Internal,
261            Self::UnsupportedOperation { .. } => ErrorCategory::UserInput,
262            Self::AutogradError(_) => ErrorCategory::Internal,
263            Self::ComputeError(_) => ErrorCategory::Internal,
264            Self::SerializationError(_) => ErrorCategory::Io,
265            Self::IndexError { .. } => ErrorCategory::UserInput,
266            Self::IterationError(_) => ErrorCategory::Internal,
267            Self::Other(_) => ErrorCategory::Internal,
268        }
269    }
270
271    /// Get the error severity
272    pub fn severity(&self) -> ErrorSeverity {
273        match self {
274            Self::Shape(e) => e.severity(),
275            Self::Index(_) => ErrorSeverity::Medium,
276            Self::General(_) => ErrorSeverity::Low,
277            Self::WithContext { severity, .. } => severity.clone(),
278            Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorSeverity::High,
279            Self::IndexOutOfBounds { .. } => ErrorSeverity::Medium,
280            Self::DeviceMismatch => ErrorSeverity::High,
281            Self::SynchronizationError(_) => ErrorSeverity::Medium,
282            Self::AllocationError(_) => ErrorSeverity::High,
283            Self::InvalidOperation(_) => ErrorSeverity::Medium,
284            Self::ConversionError(_) => ErrorSeverity::Medium,
285            Self::BackendError(_) => ErrorSeverity::High,
286            Self::InvalidShape(_) => ErrorSeverity::High,
287            Self::RuntimeError(_) => ErrorSeverity::Medium,
288            Self::DeviceError(_) => ErrorSeverity::High,
289            Self::ConfigError(_) => ErrorSeverity::Medium,
290            Self::InvalidState(_) => ErrorSeverity::Medium,
291            Self::UnsupportedOperation { .. } => ErrorSeverity::Medium,
292            Self::AutogradError(_) => ErrorSeverity::Medium,
293            _ => ErrorSeverity::Low,
294        }
295    }
296
297    /// Add context to an error
298    pub fn with_context(self, message: &str) -> Self {
299        let category = self.category();
300        let severity = self.severity();
301
302        Self::WithContext {
303            message: message.to_string(),
304            error_category: category,
305            severity,
306            debug_context: Box::new(ErrorDebugContext::minimal()),
307            source: Some(Box::new(self)),
308        }
309    }
310}
311
312// Standard library error conversions
313impl From<std::io::Error> for TorshError {
314    fn from(err: std::io::Error) -> Self {
315        Self::General(GeneralError::IoError(err.to_string()))
316    }
317}
318
#[cfg(feature = "serialize")]
impl From<serde_json::Error> for TorshError {
    /// Converts a JSON failure into `GeneralError::SerializationError`,
    /// keeping its message. Only compiled with the `serialize` feature.
    fn from(err: serde_json::Error) -> Self {
        let rendered = err.to_string();
        TorshError::General(GeneralError::SerializationError(rendered))
    }
}
325
326impl<T> From<std::sync::PoisonError<T>> for TorshError {
327    fn from(err: std::sync::PoisonError<T>) -> Self {
328        Self::General(GeneralError::SynchronizationError(format!(
329            "Mutex poisoned: {err}"
330        )))
331    }
332}
333
334impl From<std::num::TryFromIntError> for TorshError {
335    fn from(err: std::num::TryFromIntError) -> Self {
336        Self::General(GeneralError::ConversionError(format!(
337            "Integer conversion failed: {err}"
338        )))
339    }
340}
341
342impl From<std::num::ParseIntError> for TorshError {
343    fn from(err: std::num::ParseIntError) -> Self {
344        Self::General(GeneralError::ConversionError(format!(
345            "Integer parsing failed: {err}"
346        )))
347    }
348}
349
350impl From<std::num::ParseFloatError> for TorshError {
351    fn from(err: std::num::ParseFloatError) -> Self {
352        Self::General(GeneralError::ConversionError(format!(
353            "Float parsing failed: {err}"
354        )))
355    }
356}
357
/// Builds a `TorshError::WithContext` from either a literal message or an error.
///
/// * `torsh_error_with_location!("message")` — a literal message with no
///   underlying source, categorized as `Internal` with `Medium` severity.
/// * `torsh_error_with_location!(err)` — any value convertible into
///   `TorshError`. The converted error supplies the message (its `Display`
///   output), category and severity, and is kept as the `source`.
///
/// In both cases the debug context is `ErrorDebugContext::minimal()`.
#[macro_export]
macro_rules! torsh_error_with_location {
    // Literals are matched first: an `expr` arm placed before this one would
    // accept them too and then try to call `.category()` on a string.
    ($message:literal $(,)?) => {
        $crate::error::TorshError::WithContext {
            message: $message.to_string(),
            error_category: $crate::error::ErrorCategory::Internal,
            severity: $crate::error::ErrorSeverity::Medium,
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::std::option::Option::None,
        }
    };
    ($error_type:expr $(,)?) => {{
        // Evaluate the argument exactly once, then query the converted
        // `TorshError`, which always provides `category()` and `severity()`.
        let inner: $crate::error::TorshError = ::std::convert::Into::into($error_type);
        $crate::error::TorshError::WithContext {
            message: ::std::format!("{}", inner),
            error_category: inner.category(),
            severity: inner.severity(),
            // The field is `Box<ErrorDebugContext>`; `minimal()` returns it unboxed.
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::std::option::Option::Some(::std::boxed::Box::new(inner)),
        }
    }};
}
380
/// Convenience macro for shape mismatch errors.
///
/// Expands to `TorshError::shape_mismatch($expected, $got)`; both arguments
/// are shape slices (`&[usize]`). A trailing comma is accepted, as with
/// `vec!`/`assert_eq!`.
#[macro_export]
macro_rules! shape_mismatch_error {
    ($expected:expr, $got:expr $(,)?) => {
        $crate::error::TorshError::shape_mismatch($expected, $got)
    };
}
388
/// Convenience macro for index errors.
///
/// Expands to `TorshError::index_error($index, $size)`, an out-of-bounds
/// error for `$index` in a dimension of length `$size`. A trailing comma is
/// accepted, as with `vec!`/`assert_eq!`.
#[macro_export]
macro_rules! index_error {
    ($index:expr, $size:expr $(,)?) => {
        $crate::error::TorshError::index_error($index, $size)
    };
}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Each modular error type converts into `TorshError` and keeps its category.
    #[test]
    fn test_modular_error_system() {
        // Shape errors land in the `Shape` category.
        let from_shape = TorshError::from(ShapeError::shape_mismatch(&[2, 3], &[3, 2]));
        assert_eq!(from_shape.category(), ErrorCategory::Shape);

        // Index errors are reported as user-input problems.
        let from_index = TorshError::from(IndexError::out_of_bounds(5, 3));
        assert_eq!(from_index.category(), ErrorCategory::UserInput);

        // General invalid-argument errors are user-input problems too.
        let from_general = TorshError::from(GeneralError::InvalidArgument(String::from("test")));
        assert_eq!(from_general.category(), ErrorCategory::UserInput);
    }

    /// The legacy constructor still yields a high-severity shape error.
    #[test]
    fn test_backward_compatibility() {
        let legacy = TorshError::shape_mismatch(&[2, 3], &[3, 2]);
        assert_eq!(
            (legacy.category(), legacy.severity()),
            (ErrorCategory::Shape, ErrorSeverity::High)
        );
    }

    /// `with_context` wraps the error and stores the given message.
    #[test]
    fn test_error_context() {
        let wrapped = TorshError::InvalidArgument(String::from("test"))
            .with_context("During tensor operation");

        if let TorshError::WithContext { message, .. } = wrapped {
            assert_eq!(message, "During tensor operation");
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// The exported helper macros build correctly categorized errors.
    #[test]
    fn test_convenience_macros() {
        let shape_err = shape_mismatch_error!(&[2, 3], &[3, 2]);
        let index_err = index_error!(5, 3);

        assert_eq!(shape_err.category(), ErrorCategory::Shape);
        assert_eq!(index_err.category(), ErrorCategory::UserInput);
    }

    /// Standard-library (and, with `serialize`, serde_json) errors convert via `From`.
    #[test]
    fn test_standard_conversions() {
        let not_found = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        assert_eq!(TorshError::from(not_found).category(), ErrorCategory::Io);

        #[cfg(feature = "serialize")]
        {
            let bad_json = serde_json::from_str::<i32>("invalid json").unwrap_err();
            assert_eq!(TorshError::from(bad_json).category(), ErrorCategory::Internal);
        }
    }

    /// Severities are ordered: a not-implemented error ranks below a shape mismatch.
    #[test]
    fn test_error_severity_ordering() {
        let low = TorshError::NotImplemented(String::from("test")).severity();
        let high = TorshError::shape_mismatch(&[2, 3], &[3, 2]).severity();

        assert!(low < high);
    }
}