1mod core;
21mod general_errors;
22mod index_errors;
23mod shape_errors;
24
25pub use core::{
27 capture_minimal_stack_trace, capture_stack_trace, format_shape, ErrorCategory,
28 ErrorDebugContext, ErrorLocation, ErrorSeverity, ShapeDisplay, ThreadInfo,
29};
30
31pub use general_errors::GeneralError;
32pub use index_errors::IndexError;
33pub use shape_errors::ShapeError;
34
35pub use thiserror::Error;
37
/// The crate-wide error type.
///
/// Variants fall into three groups:
/// - `Shape`, `Index` and `General` transparently wrap the categorized sub-error
///   types and can be produced from them with `into()` / `?` (via `#[from]`);
/// - `WithContext` wraps another `TorshError` together with a context message,
///   an explicit category/severity and debug information;
/// - the remaining variants are flat errors carrying a message or a few fields.
#[derive(Error, Debug, Clone)]
pub enum TorshError {
    /// A shape error; `Display` and `source()` are forwarded to the inner [`ShapeError`].
    #[error(transparent)]
    Shape(#[from] ShapeError),

    /// An indexing error; `Display` and `source()` are forwarded to the inner [`IndexError`].
    #[error(transparent)]
    Index(#[from] IndexError),

    /// A general error; `Display` and `source()` are forwarded to the inner [`GeneralError`].
    #[error(transparent)]
    General(#[from] GeneralError),

    /// Another error wrapped with a context message.
    ///
    /// `Display` shows only `message`; the wrapped error is reachable through
    /// `std::error::Error::source`. `category()` and `severity()` return the stored
    /// `error_category` / `severity` fields rather than recomputing them from `source`.
    #[error("{message}")]
    WithContext {
        /// Context message; this is the entire `Display` output of the wrapper.
        message: String,
        /// Category returned by `TorshError::category` for this wrapper.
        error_category: ErrorCategory,
        /// Severity returned by `TorshError::severity` for this wrapper.
        severity: ErrorSeverity,
        /// Debug information (key/value metadata, optional backtrace); presumably
        /// boxed to keep the enum small.
        debug_context: Box<ErrorDebugContext>,
        /// The wrapped error, if any.
        #[source]
        source: Option<Box<TorshError>>,
    },

    /// An expected shape and an actual shape differ; both are rendered with `format_shape`.
    #[error(
        "Shape mismatch: expected {}, got {}",
        format_shape(expected),
        format_shape(got)
    )]
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Two shapes could not be broadcast together; both are rendered with `format_shape`.
    #[error(
        "Broadcasting error: incompatible shapes {} and {}",
        format_shape(shape1),
        format_shape(shape2)
    )]
    BroadcastError {
        shape1: Vec<usize>,
        shape2: Vec<usize>,
    },

    /// `index` is out of range for a dimension of length `size`.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// An argument was rejected.
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// An I/O failure, carried as a message only.
    /// Note: `From<std::io::Error>` produces `General(GeneralError::IoError)`, not this variant.
    #[error("IO error: {0}")]
    IoError(String),

    /// Tensors that must share a device do not.
    #[error("Device mismatch: tensors must be on the same device")]
    DeviceMismatch,

    /// Requested functionality is not implemented (see also `Unimplemented`).
    #[error("Not implemented: {0}")]
    NotImplemented(String),

    /// A thread-synchronization failure.
    /// Note: `From<PoisonError<T>>` produces `General(GeneralError::SynchronizationError)`.
    #[error("Thread synchronization error: {0}")]
    SynchronizationError(String),

    /// A memory allocation failed.
    #[error("Memory allocation failed: {0}")]
    AllocationError(String),

    /// An operation was requested that is not valid in the current situation.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    /// A numeric conversion failed.
    /// Note: the `From` impls for `TryFromIntError`, `ParseIntError` and
    /// `ParseFloatError` produce `General(GeneralError::ConversionError)` instead.
    #[error("Numeric conversion error: {0}")]
    ConversionError(String),

    /// A backend reported an error (see also `Backend`).
    #[error("Backend error: {0}")]
    BackendError(String),

    /// A shape was invalid for the requested operation.
    #[error("Invalid shape: {0}")]
    InvalidShape(String),

    /// A generic runtime failure.
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// A device-level failure.
    #[error("Device error: {0}")]
    DeviceError(String),

    /// A configuration problem.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// An object was found in an unexpected state.
    #[error("Invalid state: {0}")]
    InvalidState(String),

    /// Operation `op` is not supported for data type `dtype`.
    #[error("Unsupported operation '{op}' for data type '{dtype}'")]
    UnsupportedOperation { op: String, dtype: String },

    /// An automatic-differentiation failure.
    #[error("Autograd error: {0}")]
    AutogradError(String),

    /// A computation failure.
    #[error("Compute error: {0}")]
    ComputeError(String),

    /// A (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// NOTE(review): renders exactly the same message as `IndexOutOfBounds` and has
    /// the same category, but `severity()` rates it lower. Presumably kept for
    /// backward compatibility; confirm before consolidating the two.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexError { index: usize, size: usize },

    /// Dimension `dim` does not exist in a tensor with `ndim` dimensions.
    #[error(
        "Invalid dimension: dimension {dim} is out of bounds for tensor with {ndim} dimensions"
    )]
    InvalidDimension { dim: usize, ndim: usize },

    /// An iteration failure.
    #[error("Iteration error: {0}")]
    IterationError(String),

    /// Catch-all for errors that fit no other variant.
    #[error("Other error: {0}")]
    Other(String),

    /// A "context" error; `category()` classifies it as `ErrorCategory::Device`
    /// (presumably a device/runtime context rather than an error-context wrapper).
    #[error("Context error: {message}")]
    Context { message: String },

    /// Device `device_id` is not valid.
    #[error("Invalid device: device {device_id}")]
    InvalidDevice { device_id: usize },

    /// NOTE(review): same category as `BackendError` but a different message and a
    /// lower severity; confirm whether both are still needed.
    #[error("Backend operation failed: {0}")]
    Backend(String),

    /// A value was rejected.
    #[error("Invalid value: {0}")]
    InvalidValue(String),

    /// A memory error other than a failed allocation.
    #[error("Memory error: {message}")]
    Memory { message: String },

    /// An error reported through cuDNN.
    #[error("cuDNN error: {0}")]
    CudnnError(String),

    /// NOTE(review): overlaps with `NotImplemented` (same category and severity,
    /// different message prefix).
    #[error("Unimplemented: {0}")]
    Unimplemented(String),

    /// A failure during initialization.
    #[error("Initialization error: {0}")]
    InitializationError(String),
}
180
181pub type Result<T> = std::result::Result<T, TorshError>;
183
184impl TorshError {
185 pub fn shape_mismatch(expected: &[usize], got: &[usize]) -> Self {
187 Self::Shape(ShapeError::shape_mismatch(expected, got))
188 }
189
190 pub fn dimension_error(msg: &str, operation: &str) -> Self {
192 Self::General(GeneralError::DimensionError(format!(
193 "{msg} during {operation}"
194 )))
195 }
196
197 pub fn index_error(index: usize, size: usize) -> Self {
199 Self::Index(IndexError::out_of_bounds(index, size))
200 }
201
202 pub fn type_mismatch(expected: &str, actual: &str) -> Self {
204 Self::General(GeneralError::TypeMismatch {
205 expected: expected.to_string(),
206 actual: actual.to_string(),
207 })
208 }
209
210 pub fn dimension_error_with_context(msg: &str, operation: &str) -> Self {
212 Self::General(GeneralError::DimensionError(format!(
213 "{msg} during {operation}"
214 )))
215 }
216
217 pub fn synchronization_error(msg: &str) -> Self {
219 Self::SynchronizationError(msg.to_string())
220 }
221
222 pub fn allocation_error(msg: &str) -> Self {
224 Self::AllocationError(msg.to_string())
225 }
226
227 pub fn invalid_operation(msg: &str) -> Self {
229 Self::InvalidOperation(msg.to_string())
230 }
231
232 pub fn conversion_error(msg: &str) -> Self {
234 Self::ConversionError(msg.to_string())
235 }
236
237 pub fn invalid_argument_with_context(msg: &str, context: &str) -> Self {
239 Self::InvalidArgument(format!("{msg} (context: {context})"))
240 }
241
242 pub fn config_error_with_context(msg: &str, context: &str) -> Self {
244 Self::ConfigError(format!("{msg} (context: {context})"))
245 }
246
247 pub fn dimension_error_simple(msg: String) -> Self {
249 Self::InvalidShape(msg)
250 }
251
252 pub fn shape_mismatch_formatted(expected: &str, got: &str) -> Self {
254 Self::InvalidShape(format!("Shape mismatch: expected {expected}, got {got}"))
255 }
256
257 pub fn operation_error(msg: &str) -> Self {
259 Self::InvalidOperation(msg.to_string())
260 }
261
262 pub fn wrap_with_location(self, location: String) -> Self {
264 self.with_context(&location)
266 }
267
268 pub fn category(&self) -> ErrorCategory {
270 match self {
271 Self::Shape(e) => e.category(),
272 Self::Index(e) => e.category(),
273 Self::General(e) => e.category(),
274 Self::WithContext { error_category, .. } => error_category.clone(),
275 Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorCategory::Shape,
276 Self::IndexOutOfBounds { .. } => ErrorCategory::UserInput,
277 Self::InvalidArgument(_) => ErrorCategory::UserInput,
278 Self::IoError(_) => ErrorCategory::Io,
279 Self::DeviceMismatch => ErrorCategory::Device,
280 Self::NotImplemented(_) => ErrorCategory::Internal,
281 Self::SynchronizationError(_) => ErrorCategory::Threading,
282 Self::AllocationError(_) => ErrorCategory::Memory,
283 Self::InvalidOperation(_) => ErrorCategory::UserInput,
284 Self::ConversionError(_) => ErrorCategory::DataType,
285 Self::BackendError(_) => ErrorCategory::Device,
286 Self::InvalidShape(_) => ErrorCategory::Shape,
287 Self::RuntimeError(_) => ErrorCategory::Internal,
288 Self::DeviceError(_) => ErrorCategory::Device,
289 Self::ConfigError(_) => ErrorCategory::Configuration,
290 Self::InvalidState(_) => ErrorCategory::Internal,
291 Self::UnsupportedOperation { .. } => ErrorCategory::UserInput,
292 Self::AutogradError(_) => ErrorCategory::Internal,
293 Self::ComputeError(_) => ErrorCategory::Internal,
294 Self::SerializationError(_) => ErrorCategory::Io,
295 Self::IndexError { .. } => ErrorCategory::UserInput,
296 Self::InvalidDimension { .. } => ErrorCategory::UserInput,
297 Self::IterationError(_) => ErrorCategory::Internal,
298 Self::Other(_) => ErrorCategory::Internal,
299 Self::Context { .. } => ErrorCategory::Device,
301 Self::InvalidDevice { .. } => ErrorCategory::Device,
302 Self::Backend(_) => ErrorCategory::Device,
303 Self::InvalidValue(_) => ErrorCategory::UserInput,
304 Self::Memory { .. } => ErrorCategory::Memory,
305 Self::CudnnError(_) => ErrorCategory::Device,
306 Self::Unimplemented(_) => ErrorCategory::Internal,
307 Self::InitializationError(_) => ErrorCategory::Internal,
308 }
309 }
310
311 pub fn severity(&self) -> ErrorSeverity {
313 match self {
314 Self::Shape(e) => e.severity(),
315 Self::Index(_) => ErrorSeverity::Medium,
316 Self::General(_) => ErrorSeverity::Low,
317 Self::WithContext { severity, .. } => severity.clone(),
318 Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorSeverity::High,
319 Self::IndexOutOfBounds { .. } => ErrorSeverity::Medium,
320 Self::InvalidDimension { .. } => ErrorSeverity::Medium,
321 Self::DeviceMismatch => ErrorSeverity::High,
322 Self::SynchronizationError(_) => ErrorSeverity::Medium,
323 Self::AllocationError(_) => ErrorSeverity::High,
324 Self::InvalidOperation(_) => ErrorSeverity::Medium,
325 Self::ConversionError(_) => ErrorSeverity::Medium,
326 Self::BackendError(_) => ErrorSeverity::High,
327 Self::InvalidShape(_) => ErrorSeverity::High,
328 Self::RuntimeError(_) => ErrorSeverity::Medium,
329 Self::DeviceError(_) => ErrorSeverity::High,
330 Self::ConfigError(_) => ErrorSeverity::Medium,
331 Self::InvalidState(_) => ErrorSeverity::Medium,
332 Self::UnsupportedOperation { .. } => ErrorSeverity::Medium,
333 Self::AutogradError(_) => ErrorSeverity::Medium,
334 _ => ErrorSeverity::Low,
335 }
336 }
337
338 pub fn with_context(self, message: &str) -> Self {
354 let category = self.category();
355 let severity = self.severity();
356
357 Self::WithContext {
358 message: message.to_string(),
359 error_category: category,
360 severity,
361 debug_context: Box::new(ErrorDebugContext::minimal()),
362 source: Some(Box::new(self)),
363 }
364 }
365
366 pub fn with_rich_context(self, message: &str) -> Self {
383 let category = self.category();
384 let severity = self.severity();
385
386 Self::WithContext {
387 message: message.to_string(),
388 error_category: category,
389 severity,
390 debug_context: Box::new(ErrorDebugContext::new()),
391 source: Some(Box::new(self)),
392 }
393 }
394
395 pub fn with_metadata(self, message: &str) -> Self {
413 let category = self.category();
414 let severity = self.severity();
415
416 Self::WithContext {
417 message: message.to_string(),
418 error_category: category,
419 severity,
420 debug_context: Box::new(ErrorDebugContext::minimal()),
421 source: Some(Box::new(self)),
422 }
423 }
424
425 pub fn add_metadata(self, key: &str, value: &str) -> Self {
443 match self {
444 Self::WithContext {
445 message,
446 error_category,
447 severity,
448 mut debug_context,
449 source,
450 } => {
451 debug_context
452 .metadata
453 .insert(key.to_string(), value.to_string());
454 Self::WithContext {
455 message,
456 error_category,
457 severity,
458 debug_context,
459 source,
460 }
461 }
462 other => {
463 let category = other.category();
464 let severity = other.severity();
465 let mut context = ErrorDebugContext::minimal();
466 context.metadata.insert(key.to_string(), value.to_string());
467
468 Self::WithContext {
469 message: format!("{other}"),
470 error_category: category,
471 severity,
472 debug_context: Box::new(context),
473 source: Some(Box::new(other)),
474 }
475 }
476 }
477 }
478
479 pub fn add_shape_metadata(self, key: &str, shape: &[usize]) -> Self {
495 self.add_metadata(key, &format!("{:?}", shape))
496 }
497
498 pub fn with_operation(self, operation: &str) -> Self {
515 self.add_metadata("operation", operation)
516 }
517
518 pub fn with_device(self, device_id: usize) -> Self {
534 self.add_metadata("device_id", &device_id.to_string())
535 }
536
537 pub fn with_dtype(self, dtype: &str) -> Self {
553 self.add_metadata("dtype", dtype)
554 }
555
556 pub fn metadata(&self) -> std::collections::HashMap<String, String> {
560 match self {
561 Self::WithContext { debug_context, .. } => debug_context.metadata.clone(),
562 _ => std::collections::HashMap::new(),
563 }
564 }
565
566 pub fn debug_context(&self) -> Option<&ErrorDebugContext> {
570 match self {
571 Self::WithContext { debug_context, .. } => Some(debug_context),
572 _ => None,
573 }
574 }
575
576 pub fn format_debug(&self) -> String {
580 let mut output = format!("Error: {self}\n");
581
582 if let Some(context) = self.debug_context() {
583 output.push_str("\n");
584 output.push_str(&context.format_debug_info());
585 }
586
587 if let Self::WithContext {
588 source: Some(source),
589 ..
590 } = self
591 {
592 output.push_str("\nCaused by:\n");
593 output.push_str(&format!(" {source}"));
594 }
595
596 output
597 }
598}
599
600impl From<std::io::Error> for TorshError {
602 fn from(err: std::io::Error) -> Self {
603 Self::General(GeneralError::IoError(err.to_string()))
604 }
605}
606
#[cfg(feature = "serialize")]
impl From<serde_json::Error> for TorshError {
    /// Wraps a JSON (de)serialization failure as `GeneralError::SerializationError`,
    /// keeping only its rendered message.
    fn from(err: serde_json::Error) -> Self {
        let rendered = err.to_string();
        TorshError::General(GeneralError::SerializationError(rendered))
    }
}
613
614impl<T> From<std::sync::PoisonError<T>> for TorshError {
615 fn from(err: std::sync::PoisonError<T>) -> Self {
616 Self::General(GeneralError::SynchronizationError(format!(
617 "Mutex poisoned: {err}"
618 )))
619 }
620}
621
622impl From<std::num::TryFromIntError> for TorshError {
623 fn from(err: std::num::TryFromIntError) -> Self {
624 Self::General(GeneralError::ConversionError(format!(
625 "Integer conversion failed: {err}"
626 )))
627 }
628}
629
630impl From<std::num::ParseIntError> for TorshError {
631 fn from(err: std::num::ParseIntError) -> Self {
632 Self::General(GeneralError::ConversionError(format!(
633 "Integer parsing failed: {err}"
634 )))
635 }
636}
637
638impl From<std::num::ParseFloatError> for TorshError {
639 fn from(err: std::num::ParseFloatError) -> Self {
640 Self::General(GeneralError::ConversionError(format!(
641 "Float parsing failed: {err}"
642 )))
643 }
644}
645
/// Builds a `TorshError::WithContext` with a minimal debug context.
///
/// Two forms are accepted:
/// - `torsh_error_with_location!(message = msg)`: a context-only error with message
///   `msg.to_string()`, category `Internal`, severity `Medium` and no source;
/// - `torsh_error_with_location!(err)`: wraps `err`, which must be `Display` and have
///   `category()`/`severity()` methods, and must convert `Into<TorshError>`. The message
///   is `err`'s `Display` text and `err` becomes the `source`. `err` is evaluated exactly once.
///
/// NOTE(review): despite the name, whether `ErrorDebugContext::minimal()` records a
/// source location is not visible here; confirm.
#[macro_export]
macro_rules! torsh_error_with_location {
    // Must precede the `expr` arm: `message = x` also parses as an assignment expression.
    (message = $message:expr) => {
        $crate::error::TorshError::WithContext {
            message: $message.to_string(),
            error_category: $crate::error::ErrorCategory::Internal,
            severity: $crate::error::ErrorSeverity::Medium,
            debug_context: Box::new($crate::error::ErrorDebugContext::minimal()),
            source: None,
        }
    };
    ($error_type:expr) => {{
        // Bind once so side-effecting or expensive expressions are not re-evaluated.
        let wrapped_error = $error_type;
        let message = format!("{}", wrapped_error);
        let error_category = wrapped_error.category();
        let severity = wrapped_error.severity();
        $crate::error::TorshError::WithContext {
            message,
            error_category,
            severity,
            debug_context: Box::new($crate::error::ErrorDebugContext::minimal()),
            source: Some(Box::new(wrapped_error.into())),
        }
    }};
}
668
/// Shorthand for `TorshError::shape_mismatch(expected, got)`.
///
/// Accepts an optional trailing comma, e.g. `shape_mismatch_error!(&a, &b,)`.
#[macro_export]
macro_rules! shape_mismatch_error {
    ($expected:expr, $got:expr $(,)?) => {
        $crate::error::TorshError::shape_mismatch($expected, $got)
    };
}
676
/// Shorthand for `TorshError::index_error(index, size)`.
///
/// Accepts an optional trailing comma, e.g. `index_error!(i, n,)`.
#[macro_export]
macro_rules! index_error {
    ($index:expr, $size:expr $(,)?) => {
        $crate::error::TorshError::index_error($index, $size)
    };
}
684
#[cfg(test)]
mod tests {
    use super::*;

    /// Sub-error types convert into `TorshError` and keep their category.
    #[test]
    fn test_modular_error_system() {
        let cases: Vec<(TorshError, ErrorCategory)> = vec![
            (ShapeError::shape_mismatch(&[2, 3], &[3, 2]).into(), ErrorCategory::Shape),
            (IndexError::out_of_bounds(5, 3).into(), ErrorCategory::UserInput),
            (
                GeneralError::InvalidArgument("test".to_string()).into(),
                ErrorCategory::UserInput,
            ),
        ];
        for (converted, expected) in cases {
            assert_eq!(converted.category(), expected);
        }
    }

    /// The legacy constructor still yields a high-severity shape error.
    #[test]
    fn test_backward_compatibility() {
        let mismatch = TorshError::shape_mismatch(&[2, 3], &[3, 2]);
        assert_eq!(
            (mismatch.category(), mismatch.severity()),
            (ErrorCategory::Shape, ErrorSeverity::High)
        );
    }

    /// `with_context` replaces the displayed message with the context message.
    #[test]
    fn test_error_context() {
        let wrapped = TorshError::InvalidArgument("test".to_string())
            .with_context("During tensor operation");

        if let TorshError::WithContext { message, .. } = wrapped {
            assert_eq!(message, "During tensor operation");
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// The convenience macros forward to the matching constructors.
    #[test]
    fn test_convenience_macros() {
        assert_eq!(
            shape_mismatch_error!(&[2, 3], &[3, 2]).category(),
            ErrorCategory::Shape
        );
        assert_eq!(index_error!(5, 3).category(), ErrorCategory::UserInput);
    }

    /// Standard-library (and optional serde) errors convert with sensible categories.
    #[test]
    fn test_standard_conversions() {
        let from_io = TorshError::from(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "file not found",
        ));
        assert_eq!(from_io.category(), ErrorCategory::Io);

        #[cfg(feature = "serialize")]
        {
            let json_failure = serde_json::from_str::<i32>("invalid json").unwrap_err();
            assert_eq!(
                TorshError::from(json_failure).category(),
                ErrorCategory::Internal
            );
        }
    }

    /// Severities are ordered from low to high.
    #[test]
    fn test_error_severity_ordering() {
        let low = TorshError::NotImplemented("test".to_string()).severity();
        let high = TorshError::shape_mismatch(&[2, 3], &[3, 2]).severity();
        assert!(low < high);
    }

    /// `with_rich_context` keeps the message and captures a backtrace.
    #[test]
    fn test_rich_context() {
        let enriched = TorshError::InvalidShape("test error".to_string())
            .with_rich_context("during tensor operation");

        if let TorshError::WithContext {
            message,
            debug_context,
            ..
        } = enriched
        {
            assert_eq!(message, "during tensor operation");
            assert!(debug_context.backtrace.is_some());
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// Successive `add_metadata` calls accumulate entries.
    #[test]
    fn test_add_metadata() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .add_metadata("key1", "value1")
            .add_metadata("key2", "value2")
            .metadata();

        for (key, value) in [("key1", "value1"), ("key2", "value2")] {
            assert_eq!(metadata.get(key), Some(&value.to_string()));
        }
    }

    /// Shapes are stored as readable strings.
    #[test]
    fn test_add_shape_metadata() {
        let first = vec![2, 3, 4];
        let second = vec![4, 5];

        let metadata = TorshError::shape_mismatch(&first, &second)
            .add_shape_metadata("tensor_a", &first)
            .add_shape_metadata("tensor_b", &second)
            .metadata();

        assert!(metadata.contains_key("tensor_a"));
        assert!(metadata.contains_key("tensor_b"));
        assert!(metadata["tensor_a"].contains('2'));
        assert!(metadata["tensor_b"].contains('4'));
    }

    #[test]
    fn test_with_operation() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .with_operation("matmul")
            .metadata();
        assert_eq!(metadata.get("operation"), Some(&"matmul".to_string()));
    }

    #[test]
    fn test_with_device() {
        let metadata = TorshError::DeviceError("allocation failed".to_string())
            .with_device(42)
            .metadata();
        assert_eq!(metadata.get("device_id"), Some(&"42".to_string()));
    }

    #[test]
    fn test_with_dtype() {
        let metadata = TorshError::ConversionError("unsupported".to_string())
            .with_dtype("f32")
            .metadata();
        assert_eq!(metadata.get("dtype"), Some(&"f32".to_string()));
    }

    /// All metadata helpers compose on a single context layer.
    #[test]
    fn test_chained_metadata() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .with_operation("conv2d")
            .add_metadata("batch_size", "32")
            .add_shape_metadata("input_shape", &[32, 3, 224, 224])
            .with_device(0)
            .with_dtype("f32")
            .metadata();

        let expected = [
            ("operation", "conv2d"),
            ("batch_size", "32"),
            ("device_id", "0"),
            ("dtype", "f32"),
        ];
        for (key, value) in expected {
            assert_eq!(metadata.get(key), Some(&value.to_string()));
        }
        assert!(metadata.contains_key("input_shape"));
    }

    /// The debug report includes the message and every metadata entry.
    #[test]
    fn test_format_debug() {
        let report = TorshError::InvalidShape("test error".to_string())
            .with_operation("test_op")
            .add_metadata("key", "value")
            .format_debug();

        for needle in ["Error:", "test error", "operation: test_op", "key: value"] {
            assert!(report.contains(needle));
        }
    }

    /// A plain error gains metadata by being wrapped on demand.
    #[test]
    fn test_metadata_on_non_context_error() {
        let plain = TorshError::InvalidArgument("test".to_string());
        assert!(plain.metadata().is_empty());

        let annotated = plain.add_metadata("new_key", "new_value");
        assert_eq!(
            annotated.metadata().get("new_key"),
            Some(&"new_value".to_string())
        );
    }

    /// Only context-wrapped errors expose a debug context.
    #[test]
    fn test_debug_context_availability() {
        assert!(TorshError::InvalidShape("test".to_string())
            .debug_context()
            .is_none());
        assert!(TorshError::InvalidShape("test".to_string())
            .with_context("during operation")
            .debug_context()
            .is_some());
    }
}