1mod core;
21mod general_errors;
22mod index_errors;
23mod shape_errors;
24
25pub use core::{
27 capture_minimal_stack_trace, capture_stack_trace, format_shape, ErrorCategory,
28 ErrorDebugContext, ErrorLocation, ErrorSeverity, ShapeDisplay, ThreadInfo,
29};
30
31pub use general_errors::GeneralError;
32pub use index_errors::IndexError;
33pub use shape_errors::ShapeError;
34
35pub use thiserror::Error;
37
38#[derive(Error, Debug, Clone)]
40pub enum TorshError {
41 #[error(transparent)]
43 Shape(#[from] ShapeError),
44
45 #[error(transparent)]
46 Index(#[from] IndexError),
47
48 #[error(transparent)]
49 General(#[from] GeneralError),
50
51 #[error("{message}")]
53 WithContext {
54 message: String,
55 error_category: ErrorCategory,
56 severity: ErrorSeverity,
57 debug_context: Box<ErrorDebugContext>,
58 #[source]
59 source: Option<Box<TorshError>>,
60 },
61
62 #[error(
64 "Shape mismatch: expected {}, got {}",
65 format_shape(expected),
66 format_shape(got)
67 )]
68 ShapeMismatch {
69 expected: Vec<usize>,
70 got: Vec<usize>,
71 },
72
73 #[error(
74 "Broadcasting error: incompatible shapes {} and {}",
75 format_shape(shape1),
76 format_shape(shape2)
77 )]
78 BroadcastError {
79 shape1: Vec<usize>,
80 shape2: Vec<usize>,
81 },
82
83 #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
84 IndexOutOfBounds { index: usize, size: usize },
85
86 #[error("Invalid argument: {0}")]
87 InvalidArgument(String),
88
89 #[error("IO error: {0}")]
90 IoError(String),
91
92 #[error("Device mismatch: tensors must be on the same device")]
93 DeviceMismatch,
94
95 #[error("Not implemented: {0}")]
96 NotImplemented(String),
97
98 #[error("Thread synchronization error: {0}")]
100 SynchronizationError(String),
101
102 #[error("Memory allocation failed: {0}")]
103 AllocationError(String),
104
105 #[error("Invalid operation: {0}")]
106 InvalidOperation(String),
107
108 #[error("Numeric conversion error: {0}")]
109 ConversionError(String),
110
111 #[error("Backend error: {0}")]
112 BackendError(String),
113
114 #[error("Invalid shape: {0}")]
115 InvalidShape(String),
116
117 #[error("Runtime error: {0}")]
118 RuntimeError(String),
119
120 #[error("Device error: {0}")]
121 DeviceError(String),
122
123 #[error("Configuration error: {0}")]
124 ConfigError(String),
125
126 #[error("Invalid state: {0}")]
127 InvalidState(String),
128
129 #[error("Unsupported operation '{op}' for data type '{dtype}'")]
130 UnsupportedOperation { op: String, dtype: String },
131
132 #[error("Autograd error: {0}")]
133 AutogradError(String),
134
135 #[error("Compute error: {0}")]
136 ComputeError(String),
137
138 #[error("Serialization error: {0}")]
139 SerializationError(String),
140
141 #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
142 IndexError { index: usize, size: usize },
143
144 #[error("Iteration error: {0}")]
145 IterationError(String),
146
147 #[error("Other error: {0}")]
148 Other(String),
149}
150
151pub type Result<T> = std::result::Result<T, TorshError>;
153
154impl TorshError {
155 pub fn shape_mismatch(expected: &[usize], got: &[usize]) -> Self {
157 Self::Shape(ShapeError::shape_mismatch(expected, got))
158 }
159
160 pub fn dimension_error(msg: &str, operation: &str) -> Self {
162 Self::General(GeneralError::DimensionError(format!(
163 "{msg} during {operation}"
164 )))
165 }
166
167 pub fn index_error(index: usize, size: usize) -> Self {
169 Self::Index(IndexError::out_of_bounds(index, size))
170 }
171
172 pub fn type_mismatch(expected: &str, actual: &str) -> Self {
174 Self::General(GeneralError::TypeMismatch {
175 expected: expected.to_string(),
176 actual: actual.to_string(),
177 })
178 }
179
180 pub fn dimension_error_with_context(msg: &str, operation: &str) -> Self {
182 Self::General(GeneralError::DimensionError(format!(
183 "{msg} during {operation}"
184 )))
185 }
186
187 pub fn synchronization_error(msg: &str) -> Self {
189 Self::SynchronizationError(msg.to_string())
190 }
191
192 pub fn allocation_error(msg: &str) -> Self {
194 Self::AllocationError(msg.to_string())
195 }
196
197 pub fn invalid_operation(msg: &str) -> Self {
199 Self::InvalidOperation(msg.to_string())
200 }
201
202 pub fn conversion_error(msg: &str) -> Self {
204 Self::ConversionError(msg.to_string())
205 }
206
207 pub fn invalid_argument_with_context(msg: &str, context: &str) -> Self {
209 Self::InvalidArgument(format!("{msg} (context: {context})"))
210 }
211
212 pub fn config_error_with_context(msg: &str, context: &str) -> Self {
214 Self::ConfigError(format!("{msg} (context: {context})"))
215 }
216
217 pub fn dimension_error_simple(msg: String) -> Self {
219 Self::InvalidShape(msg)
220 }
221
222 pub fn shape_mismatch_formatted(expected: &str, got: &str) -> Self {
224 Self::InvalidShape(format!("Shape mismatch: expected {expected}, got {got}"))
225 }
226
227 pub fn operation_error(msg: &str) -> Self {
229 Self::InvalidOperation(msg.to_string())
230 }
231
232 pub fn wrap_with_location(self, location: String) -> Self {
234 self.with_context(&location)
236 }
237
238 pub fn category(&self) -> ErrorCategory {
240 match self {
241 Self::Shape(e) => e.category(),
242 Self::Index(e) => e.category(),
243 Self::General(e) => e.category(),
244 Self::WithContext { error_category, .. } => error_category.clone(),
245 Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorCategory::Shape,
246 Self::IndexOutOfBounds { .. } => ErrorCategory::UserInput,
247 Self::InvalidArgument(_) => ErrorCategory::UserInput,
248 Self::IoError(_) => ErrorCategory::Io,
249 Self::DeviceMismatch => ErrorCategory::Device,
250 Self::NotImplemented(_) => ErrorCategory::Internal,
251 Self::SynchronizationError(_) => ErrorCategory::Threading,
252 Self::AllocationError(_) => ErrorCategory::Memory,
253 Self::InvalidOperation(_) => ErrorCategory::UserInput,
254 Self::ConversionError(_) => ErrorCategory::DataType,
255 Self::BackendError(_) => ErrorCategory::Device,
256 Self::InvalidShape(_) => ErrorCategory::Shape,
257 Self::RuntimeError(_) => ErrorCategory::Internal,
258 Self::DeviceError(_) => ErrorCategory::Device,
259 Self::ConfigError(_) => ErrorCategory::Configuration,
260 Self::InvalidState(_) => ErrorCategory::Internal,
261 Self::UnsupportedOperation { .. } => ErrorCategory::UserInput,
262 Self::AutogradError(_) => ErrorCategory::Internal,
263 Self::ComputeError(_) => ErrorCategory::Internal,
264 Self::SerializationError(_) => ErrorCategory::Io,
265 Self::IndexError { .. } => ErrorCategory::UserInput,
266 Self::IterationError(_) => ErrorCategory::Internal,
267 Self::Other(_) => ErrorCategory::Internal,
268 }
269 }
270
271 pub fn severity(&self) -> ErrorSeverity {
273 match self {
274 Self::Shape(e) => e.severity(),
275 Self::Index(_) => ErrorSeverity::Medium,
276 Self::General(_) => ErrorSeverity::Low,
277 Self::WithContext { severity, .. } => severity.clone(),
278 Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorSeverity::High,
279 Self::IndexOutOfBounds { .. } => ErrorSeverity::Medium,
280 Self::DeviceMismatch => ErrorSeverity::High,
281 Self::SynchronizationError(_) => ErrorSeverity::Medium,
282 Self::AllocationError(_) => ErrorSeverity::High,
283 Self::InvalidOperation(_) => ErrorSeverity::Medium,
284 Self::ConversionError(_) => ErrorSeverity::Medium,
285 Self::BackendError(_) => ErrorSeverity::High,
286 Self::InvalidShape(_) => ErrorSeverity::High,
287 Self::RuntimeError(_) => ErrorSeverity::Medium,
288 Self::DeviceError(_) => ErrorSeverity::High,
289 Self::ConfigError(_) => ErrorSeverity::Medium,
290 Self::InvalidState(_) => ErrorSeverity::Medium,
291 Self::UnsupportedOperation { .. } => ErrorSeverity::Medium,
292 Self::AutogradError(_) => ErrorSeverity::Medium,
293 _ => ErrorSeverity::Low,
294 }
295 }
296
297 pub fn with_context(self, message: &str) -> Self {
299 let category = self.category();
300 let severity = self.severity();
301
302 Self::WithContext {
303 message: message.to_string(),
304 error_category: category,
305 severity,
306 debug_context: Box::new(ErrorDebugContext::minimal()),
307 source: Some(Box::new(self)),
308 }
309 }
310}
311
312impl From<std::io::Error> for TorshError {
314 fn from(err: std::io::Error) -> Self {
315 Self::General(GeneralError::IoError(err.to_string()))
316 }
317}
318
/// Converts a JSON (de)serialization error into
/// [`GeneralError::SerializationError`], keeping only its message text.
///
/// Available only with the `serialize` feature.
#[cfg(feature = "serialize")]
impl From<serde_json::Error> for TorshError {
    fn from(json_err: serde_json::Error) -> Self {
        let text = json_err.to_string();
        TorshError::General(GeneralError::SerializationError(text))
    }
}
325
326impl<T> From<std::sync::PoisonError<T>> for TorshError {
327 fn from(err: std::sync::PoisonError<T>) -> Self {
328 Self::General(GeneralError::SynchronizationError(format!(
329 "Mutex poisoned: {err}"
330 )))
331 }
332}
333
334impl From<std::num::TryFromIntError> for TorshError {
335 fn from(err: std::num::TryFromIntError) -> Self {
336 Self::General(GeneralError::ConversionError(format!(
337 "Integer conversion failed: {err}"
338 )))
339 }
340}
341
342impl From<std::num::ParseIntError> for TorshError {
343 fn from(err: std::num::ParseIntError) -> Self {
344 Self::General(GeneralError::ConversionError(format!(
345 "Integer parsing failed: {err}"
346 )))
347 }
348}
349
350impl From<std::num::ParseFloatError> for TorshError {
351 fn from(err: std::num::ParseFloatError) -> Self {
352 Self::General(GeneralError::ConversionError(format!(
353 "Float parsing failed: {err}"
354 )))
355 }
356}
357
/// Builds a `TorshError::WithContext` from either a literal message or an error.
///
/// Two forms are accepted:
///
/// * `torsh_error_with_location!("some message")`: a literal becomes the
///   message of a context-only error with category `Internal`, severity
///   `Medium` and no source.
/// * `torsh_error_with_location!(err)`: `err` is any value convertible into
///   `TorshError`. It is evaluated exactly once. Its `Display` text becomes the
///   message, its category and severity are inherited, and it is kept as the
///   `source`.
///
/// NOTE(review): despite the name, the only call-site data attached is
/// whatever `ErrorDebugContext::minimal()` records. Confirm that it captures
/// the location.
#[macro_export]
macro_rules! torsh_error_with_location {
    // Must precede the `expr` arm: `expr` also matches literals, which is what
    // previously made a message-only arm unreachable.
    ($message:literal) => {
        $crate::error::TorshError::WithContext {
            message: $message.to_string(),
            error_category: $crate::error::ErrorCategory::Internal,
            severity: $crate::error::ErrorSeverity::Medium,
            // The field is `Box<ErrorDebugContext>`, so the context must be boxed.
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::std::option::Option::None,
        }
    };
    ($error_type:expr) => {{
        // Evaluate the argument once and normalise it to `TorshError`, so
        // category/severity/Display/source all describe the same value.
        let wrapped: $crate::error::TorshError = ::std::convert::Into::into($error_type);
        $crate::error::TorshError::WithContext {
            message: ::std::format!("{}", wrapped),
            error_category: wrapped.category(),
            severity: wrapped.severity(),
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::std::option::Option::Some(::std::boxed::Box::new(wrapped)),
        }
    }};
}
380
/// Shorthand for [`TorshError::shape_mismatch`]: `shape_mismatch_error!(expected, got)`.
///
/// Both arguments are forwarded unchanged, so they must be usable where
/// `&[usize]` is expected (e.g. `&[2, 3]`).
#[macro_export]
macro_rules! shape_mismatch_error {
    ($want:expr, $have:expr) => {
        $crate::error::TorshError::shape_mismatch($want, $have)
    };
}
388
/// Shorthand for [`TorshError::index_error`]: `index_error!(index, size)`.
///
/// Both arguments are forwarded unchanged and must be `usize`.
#[macro_export]
macro_rules! index_error {
    ($idx:expr, $len:expr) => {
        $crate::error::TorshError::index_error($idx, $len)
    };
}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Each modular sub-error converts into `TorshError` with its category intact.
    #[test]
    fn test_modular_error_system() {
        let from_shape = TorshError::from(ShapeError::shape_mismatch(&[2, 3], &[3, 2]));
        assert_eq!(from_shape.category(), ErrorCategory::Shape);

        let from_index = TorshError::from(IndexError::out_of_bounds(5, 3));
        assert_eq!(from_index.category(), ErrorCategory::UserInput);

        let from_general = TorshError::from(GeneralError::InvalidArgument(String::from("test")));
        assert_eq!(from_general.category(), ErrorCategory::UserInput);
    }

    /// The legacy constructor still yields a `Shape`-category, `High`-severity error.
    #[test]
    fn test_backward_compatibility() {
        let legacy = TorshError::shape_mismatch(&[2, 3], &[3, 2]);
        assert_eq!(legacy.category(), ErrorCategory::Shape);
        assert_eq!(legacy.severity(), ErrorSeverity::High);
    }

    /// `with_context` produces a `WithContext` variant carrying the given message.
    #[test]
    fn test_error_context() {
        let wrapped = TorshError::InvalidArgument(String::from("test"))
            .with_context("During tensor operation");

        if let TorshError::WithContext { message, .. } = wrapped {
            assert_eq!(message, "During tensor operation");
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// The exported convenience macros build errors with the expected categories.
    #[test]
    fn test_convenience_macros() {
        assert_eq!(
            shape_mismatch_error!(&[2, 3], &[3, 2]).category(),
            ErrorCategory::Shape
        );
        assert_eq!(index_error!(5, 3).category(), ErrorCategory::UserInput);
    }

    /// Standard-library (and optionally serde_json) errors convert with sensible categories.
    #[test]
    fn test_standard_conversions() {
        let not_found = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        assert_eq!(TorshError::from(not_found).category(), ErrorCategory::Io);

        #[cfg(feature = "serialize")]
        {
            let bad_json = serde_json::from_str::<i32>("invalid json").unwrap_err();
            assert_eq!(TorshError::from(bad_json).category(), ErrorCategory::Internal);
        }
    }

    /// Severity values are ordered: a not-implemented error ranks below a shape mismatch.
    #[test]
    fn test_error_severity_ordering() {
        let minor = TorshError::NotImplemented(String::from("test"));
        let major = TorshError::shape_mismatch(&[2, 3], &[3, 2]);
        assert!(minor.severity() < major.severity());
    }
}