Skip to main content

torsh_core/
error.rs

1//! Error types for ToRSh - Clean Modular Interface
2//!
3//! This module provides a unified interface to the ToRSh error system.
4//! All error implementations have been organized into specialized modules
5//! for better maintainability and categorization.
6//!
7//! # Architecture
8//!
9//! The error system is organized into specialized modules:
10//!
11//! - **core**: Error infrastructure, location tracking, debug context
12//! - **shape_errors**: Shape mismatches, broadcasting, tensor operations
13//! - **index_errors**: Index bounds checking and access violations
14//! - **general_errors**: I/O, configuration, runtime, and miscellaneous errors
15//!
16//! All error types are unified through the main `TorshError` enum which provides
17//! backward compatibility while enabling modular error handling.
18
19// Modular error system
20mod core;
21mod general_errors;
22mod index_errors;
23mod shape_errors;
24
25// Re-export the complete modular interface
26pub use core::{
27    capture_minimal_stack_trace, capture_stack_trace, format_shape, ErrorCategory,
28    ErrorDebugContext, ErrorLocation, ErrorSeverity, ShapeDisplay, ThreadInfo,
29};
30
31pub use general_errors::GeneralError;
32pub use index_errors::IndexError;
33pub use shape_errors::ShapeError;
34
35// Re-export the unified error type and result
36pub use thiserror::Error;
37
/// Main ToRSh error enum - unified interface to all error types
///
/// The first three variants wrap the modular error types (`ShapeError`,
/// `IndexError`, `GeneralError`) and display exactly like the wrapped error.
/// `WithContext` layers a message, classification and debug context on top of
/// an optional source error. The remaining variants are flat, mostly
/// string-carrying variants kept for backward compatibility.
#[derive(Error, Debug, Clone)]
pub enum TorshError {
    // Modular error variants
    /// Shape-related error; `From<ShapeError>` is derived via `#[from]`.
    #[error(transparent)]
    Shape(#[from] ShapeError),

    /// Index-related error; `From<IndexError>` is derived via `#[from]`.
    #[error(transparent)]
    Index(#[from] IndexError),

    /// General error; `From<GeneralError>` is derived via `#[from]`.
    #[error(transparent)]
    General(#[from] GeneralError),

    // Error with enhanced context information
    /// An error enriched with context. Only `message` is displayed; the
    /// wrapped error, if any, is exposed as the error source.
    #[error("{message}")]
    WithContext {
        /// Text shown when the error is displayed.
        message: String,
        /// Category reported by `category()` for this variant.
        error_category: ErrorCategory,
        /// Severity reported by `severity()` for this variant.
        severity: ErrorSeverity,
        /// Debug context; boxed, and holds the key/value metadata map.
        debug_context: Box<ErrorDebugContext>,
        /// The wrapped underlying error, if any.
        #[source]
        source: Option<Box<TorshError>>,
    },

    // Legacy compatibility variants (for backward compatibility)
    /// Expected and actual shapes differ.
    #[error(
        "Shape mismatch: expected {}, got {}",
        format_shape(expected),
        format_shape(got)
    )]
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },

    /// Two shapes cannot be broadcast together.
    #[error(
        "Broadcasting error: incompatible shapes {} and {}",
        format_shape(shape1),
        format_shape(shape2)
    )]
    BroadcastError {
        shape1: Vec<usize>,
        shape2: Vec<usize>,
    },

    /// `index` is outside a dimension of length `size`.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// Invalid argument, described by the message.
    #[error("Invalid argument: {0}")]
    InvalidArgument(String),

    /// I/O failure, described by the message.
    #[error("IO error: {0}")]
    IoError(String),

    /// Tensors involved in an operation are on different devices.
    #[error("Device mismatch: tensors must be on the same device")]
    DeviceMismatch,

    /// Functionality that is not implemented.
    #[error("Not implemented: {0}")]
    NotImplemented(String),

    // Additional legacy compatibility variants
    /// Thread synchronization failure.
    #[error("Thread synchronization error: {0}")]
    SynchronizationError(String),

    /// Memory allocation failure.
    #[error("Memory allocation failed: {0}")]
    AllocationError(String),

    /// Invalid operation, described by the message.
    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    /// Numeric conversion failure.
    #[error("Numeric conversion error: {0}")]
    ConversionError(String),

    /// Backend failure (see also `Backend`, which has a different message).
    #[error("Backend error: {0}")]
    BackendError(String),

    /// Invalid shape, described by the message.
    #[error("Invalid shape: {0}")]
    InvalidShape(String),

    /// Runtime failure, described by the message.
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// Device failure, described by the message.
    #[error("Device error: {0}")]
    DeviceError(String),

    /// Configuration failure, described by the message.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// Invalid state, described by the message.
    #[error("Invalid state: {0}")]
    InvalidState(String),

    /// Operation `op` is not supported for data type `dtype`.
    #[error("Unsupported operation '{op}' for data type '{dtype}'")]
    UnsupportedOperation { op: String, dtype: String },

    /// Autograd failure, described by the message.
    #[error("Autograd error: {0}")]
    AutogradError(String),

    /// Compute failure, described by the message.
    #[error("Compute error: {0}")]
    ComputeError(String),

    /// Serialization failure, described by the message.
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// Same message as `IndexOutOfBounds`; a separate legacy variant.
    #[error("Index out of bounds: index {index} is out of bounds for dimension with size {size}")]
    IndexError { index: usize, size: usize },

    /// Dimension `dim` does not exist on a tensor with `ndim` dimensions.
    #[error(
        "Invalid dimension: dimension {dim} is out of bounds for tensor with {ndim} dimensions"
    )]
    InvalidDimension { dim: usize, ndim: usize },

    /// Iteration failure, described by the message.
    #[error("Iteration error: {0}")]
    IterationError(String),

    /// Any other failure, described by the message.
    #[error("Other error: {0}")]
    Other(String),

    // CUDA/GPU Backend compatibility variants
    /// Context failure, described by `message`.
    #[error("Context error: {message}")]
    Context { message: String },

    /// The device identified by `device_id` is invalid.
    #[error("Invalid device: device {device_id}")]
    InvalidDevice { device_id: usize },

    /// Backend operation failure (see also `BackendError`).
    #[error("Backend operation failed: {0}")]
    Backend(String),

    /// Invalid value, described by the message.
    #[error("Invalid value: {0}")]
    InvalidValue(String),

    /// Memory failure, described by `message`.
    #[error("Memory error: {message}")]
    Memory { message: String },

    /// cuDNN failure, described by the message.
    #[error("cuDNN error: {0}")]
    CudnnError(String),

    /// Unimplemented functionality (see also `NotImplemented`).
    #[error("Unimplemented: {0}")]
    Unimplemented(String),

    /// Initialization failure, described by the message.
    #[error("Initialization error: {0}")]
    InitializationError(String),
}
180
/// Result type alias for ToRSh operations
///
/// Shorthand for `std::result::Result<T, TorshError>`.
pub type Result<T> = std::result::Result<T, TorshError>;
183
184impl TorshError {
185    /// Create a shape mismatch error (backward compatibility)
186    pub fn shape_mismatch(expected: &[usize], got: &[usize]) -> Self {
187        Self::Shape(ShapeError::shape_mismatch(expected, got))
188    }
189
190    /// Create a dimension error during operation
191    pub fn dimension_error(msg: &str, operation: &str) -> Self {
192        Self::General(GeneralError::DimensionError(format!(
193            "{msg} during {operation}"
194        )))
195    }
196
197    /// Create an index error
198    pub fn index_error(index: usize, size: usize) -> Self {
199        Self::Index(IndexError::out_of_bounds(index, size))
200    }
201
202    /// Create a type mismatch error
203    pub fn type_mismatch(expected: &str, actual: &str) -> Self {
204        Self::General(GeneralError::TypeMismatch {
205            expected: expected.to_string(),
206            actual: actual.to_string(),
207        })
208    }
209
210    /// Create a dimension error with context (backward compatibility)
211    pub fn dimension_error_with_context(msg: &str, operation: &str) -> Self {
212        Self::General(GeneralError::DimensionError(format!(
213            "{msg} during {operation}"
214        )))
215    }
216
217    /// Create a synchronization error (backward compatibility)
218    pub fn synchronization_error(msg: &str) -> Self {
219        Self::SynchronizationError(msg.to_string())
220    }
221
222    /// Create an allocation error (backward compatibility)
223    pub fn allocation_error(msg: &str) -> Self {
224        Self::AllocationError(msg.to_string())
225    }
226
227    /// Create an invalid operation error (backward compatibility)
228    pub fn invalid_operation(msg: &str) -> Self {
229        Self::InvalidOperation(msg.to_string())
230    }
231
232    /// Create a conversion error (backward compatibility)
233    pub fn conversion_error(msg: &str) -> Self {
234        Self::ConversionError(msg.to_string())
235    }
236
237    /// Create an invalid argument error with context (backward compatibility)
238    pub fn invalid_argument_with_context(msg: &str, context: &str) -> Self {
239        Self::InvalidArgument(format!("{msg} (context: {context})"))
240    }
241
242    /// Create a config error with context (backward compatibility)
243    pub fn config_error_with_context(msg: &str, context: &str) -> Self {
244        Self::ConfigError(format!("{msg} (context: {context})"))
245    }
246
247    /// Create a dimension error (backward compatibility)
248    pub fn dimension_error_simple(msg: String) -> Self {
249        Self::InvalidShape(msg)
250    }
251
252    /// Create a formatted shape mismatch error (backward compatibility)
253    pub fn shape_mismatch_formatted(expected: &str, got: &str) -> Self {
254        Self::InvalidShape(format!("Shape mismatch: expected {expected}, got {got}"))
255    }
256
257    /// Create an operation error (backward compatibility)
258    pub fn operation_error(msg: &str) -> Self {
259        Self::InvalidOperation(msg.to_string())
260    }
261
262    /// Wrap an error with location information (backward compatibility)
263    pub fn wrap_with_location(self, location: String) -> Self {
264        // For backward compatibility, just add context
265        self.with_context(&location)
266    }
267
268    /// Get the error category
269    pub fn category(&self) -> ErrorCategory {
270        match self {
271            Self::Shape(e) => e.category(),
272            Self::Index(e) => e.category(),
273            Self::General(e) => e.category(),
274            Self::WithContext { error_category, .. } => error_category.clone(),
275            Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorCategory::Shape,
276            Self::IndexOutOfBounds { .. } => ErrorCategory::UserInput,
277            Self::InvalidArgument(_) => ErrorCategory::UserInput,
278            Self::IoError(_) => ErrorCategory::Io,
279            Self::DeviceMismatch => ErrorCategory::Device,
280            Self::NotImplemented(_) => ErrorCategory::Internal,
281            Self::SynchronizationError(_) => ErrorCategory::Threading,
282            Self::AllocationError(_) => ErrorCategory::Memory,
283            Self::InvalidOperation(_) => ErrorCategory::UserInput,
284            Self::ConversionError(_) => ErrorCategory::DataType,
285            Self::BackendError(_) => ErrorCategory::Device,
286            Self::InvalidShape(_) => ErrorCategory::Shape,
287            Self::RuntimeError(_) => ErrorCategory::Internal,
288            Self::DeviceError(_) => ErrorCategory::Device,
289            Self::ConfigError(_) => ErrorCategory::Configuration,
290            Self::InvalidState(_) => ErrorCategory::Internal,
291            Self::UnsupportedOperation { .. } => ErrorCategory::UserInput,
292            Self::AutogradError(_) => ErrorCategory::Internal,
293            Self::ComputeError(_) => ErrorCategory::Internal,
294            Self::SerializationError(_) => ErrorCategory::Io,
295            Self::IndexError { .. } => ErrorCategory::UserInput,
296            Self::InvalidDimension { .. } => ErrorCategory::UserInput,
297            Self::IterationError(_) => ErrorCategory::Internal,
298            Self::Other(_) => ErrorCategory::Internal,
299            // CUDA/GPU Backend compatibility variants
300            Self::Context { .. } => ErrorCategory::Device,
301            Self::InvalidDevice { .. } => ErrorCategory::Device,
302            Self::Backend(_) => ErrorCategory::Device,
303            Self::InvalidValue(_) => ErrorCategory::UserInput,
304            Self::Memory { .. } => ErrorCategory::Memory,
305            Self::CudnnError(_) => ErrorCategory::Device,
306            Self::Unimplemented(_) => ErrorCategory::Internal,
307            Self::InitializationError(_) => ErrorCategory::Internal,
308        }
309    }
310
311    /// Get the error severity
312    pub fn severity(&self) -> ErrorSeverity {
313        match self {
314            Self::Shape(e) => e.severity(),
315            Self::Index(_) => ErrorSeverity::Medium,
316            Self::General(_) => ErrorSeverity::Low,
317            Self::WithContext { severity, .. } => severity.clone(),
318            Self::ShapeMismatch { .. } | Self::BroadcastError { .. } => ErrorSeverity::High,
319            Self::IndexOutOfBounds { .. } => ErrorSeverity::Medium,
320            Self::InvalidDimension { .. } => ErrorSeverity::Medium,
321            Self::DeviceMismatch => ErrorSeverity::High,
322            Self::SynchronizationError(_) => ErrorSeverity::Medium,
323            Self::AllocationError(_) => ErrorSeverity::High,
324            Self::InvalidOperation(_) => ErrorSeverity::Medium,
325            Self::ConversionError(_) => ErrorSeverity::Medium,
326            Self::BackendError(_) => ErrorSeverity::High,
327            Self::InvalidShape(_) => ErrorSeverity::High,
328            Self::RuntimeError(_) => ErrorSeverity::Medium,
329            Self::DeviceError(_) => ErrorSeverity::High,
330            Self::ConfigError(_) => ErrorSeverity::Medium,
331            Self::InvalidState(_) => ErrorSeverity::Medium,
332            Self::UnsupportedOperation { .. } => ErrorSeverity::Medium,
333            Self::AutogradError(_) => ErrorSeverity::Medium,
334            _ => ErrorSeverity::Low,
335        }
336    }
337
338    /// Add minimal context to an error (lightweight, no backtrace)
339    ///
340    /// Use this for performance-critical paths where error context
341    /// is helpful but backtrace overhead is not justified.
342    ///
343    /// # Example
344    /// ```
345    /// use torsh_core::error::{TorshError, Result};
346    ///
347    /// fn tensor_operation() -> Result<()> {
348    ///     let error = TorshError::InvalidShape("invalid dimensions".to_string())
349    ///         .with_context("during tensor reshape");
350    ///     Err(error)
351    /// }
352    /// ```
353    pub fn with_context(self, message: &str) -> Self {
354        let category = self.category();
355        let severity = self.severity();
356
357        Self::WithContext {
358            message: message.to_string(),
359            error_category: category,
360            severity,
361            debug_context: Box::new(ErrorDebugContext::minimal()),
362            source: Some(Box::new(self)),
363        }
364    }
365
366    /// Add rich context to an error (includes full backtrace)
367    ///
368    /// Use this for debugging and development environments where
369    /// detailed error information is valuable. Captures full backtrace
370    /// and thread information.
371    ///
372    /// # Example
373    /// ```
374    /// use torsh_core::error::{TorshError, Result};
375    ///
376    /// fn critical_operation() -> Result<()> {
377    ///     let error = TorshError::InvalidShape("invalid dimensions".to_string())
378    ///         .with_rich_context("during critical tensor operation");
379    ///     Err(error)
380    /// }
381    /// ```
382    pub fn with_rich_context(self, message: &str) -> Self {
383        let category = self.category();
384        let severity = self.severity();
385
386        Self::WithContext {
387            message: message.to_string(),
388            error_category: category,
389            severity,
390            debug_context: Box::new(ErrorDebugContext::new()),
391            source: Some(Box::new(self)),
392        }
393    }
394
395    /// Add context with custom metadata (minimal backtrace)
396    ///
397    /// Use this to add structured metadata without the overhead
398    /// of a full backtrace. Ideal for operation tracking and debugging.
399    ///
400    /// # Example
401    /// ```
402    /// use torsh_core::error::{TorshError, Result};
403    ///
404    /// fn tensor_add(shape1: &[usize], shape2: &[usize]) -> Result<()> {
405    ///     let error = TorshError::InvalidShape("incompatible shapes".to_string())
406    ///         .with_metadata("during tensor addition")
407    ///         .add_metadata("shape1", &format!("{:?}", shape1))
408    ///         .add_metadata("shape2", &format!("{:?}", shape2));
409    ///     Err(error)
410    /// }
411    /// ```
412    pub fn with_metadata(self, message: &str) -> Self {
413        let category = self.category();
414        let severity = self.severity();
415
416        Self::WithContext {
417            message: message.to_string(),
418            error_category: category,
419            severity,
420            debug_context: Box::new(ErrorDebugContext::minimal()),
421            source: Some(Box::new(self)),
422        }
423    }
424
425    /// Add metadata to an existing error
426    ///
427    /// This method allows adding key-value metadata to enrich error context
428    /// without creating a new error wrapper. If the error is not already
429    /// a `WithContext` variant, it will be converted to one.
430    ///
431    /// # Example
432    /// ```
433    /// use torsh_core::error::{TorshError, Result};
434    ///
435    /// fn process_tensor(name: &str, size: usize) -> Result<()> {
436    ///     let error = TorshError::AllocationError("out of memory".to_string())
437    ///         .add_metadata("tensor_name", name)
438    ///         .add_metadata("requested_size", &size.to_string());
439    ///     Err(error)
440    /// }
441    /// ```
442    pub fn add_metadata(self, key: &str, value: &str) -> Self {
443        match self {
444            Self::WithContext {
445                message,
446                error_category,
447                severity,
448                mut debug_context,
449                source,
450            } => {
451                debug_context
452                    .metadata
453                    .insert(key.to_string(), value.to_string());
454                Self::WithContext {
455                    message,
456                    error_category,
457                    severity,
458                    debug_context,
459                    source,
460                }
461            }
462            other => {
463                let category = other.category();
464                let severity = other.severity();
465                let mut context = ErrorDebugContext::minimal();
466                context.metadata.insert(key.to_string(), value.to_string());
467
468                Self::WithContext {
469                    message: format!("{other}"),
470                    error_category: category,
471                    severity,
472                    debug_context: Box::new(context),
473                    source: Some(Box::new(other)),
474                }
475            }
476        }
477    }
478
479    /// Add shape information as metadata
480    ///
481    /// Convenience method for adding tensor shape information to errors.
482    ///
483    /// # Example
484    /// ```
485    /// use torsh_core::error::{TorshError, Result};
486    ///
487    /// fn validate_shape(actual: &[usize], expected: &[usize]) -> Result<()> {
488    ///     let error = TorshError::shape_mismatch(expected, actual)
489    ///         .add_shape_metadata("actual_shape", actual)
490    ///         .add_shape_metadata("expected_shape", expected);
491    ///     Err(error)
492    /// }
493    /// ```
494    pub fn add_shape_metadata(self, key: &str, shape: &[usize]) -> Self {
495        self.add_metadata(key, &format!("{:?}", shape))
496    }
497
498    /// Add operation name as metadata
499    ///
500    /// Convenience method for tracking which operation caused the error.
501    ///
502    /// # Example
503    /// ```
504    /// use torsh_core::error::{TorshError, Result};
505    ///
506    /// fn matmul(a_shape: &[usize], b_shape: &[usize]) -> Result<()> {
507    ///     let error = TorshError::shape_mismatch(a_shape, b_shape)
508    ///         .with_operation("matmul")
509    ///         .add_shape_metadata("lhs_shape", a_shape)
510    ///         .add_shape_metadata("rhs_shape", b_shape);
511    ///     Err(error)
512    /// }
513    /// ```
514    pub fn with_operation(self, operation: &str) -> Self {
515        self.add_metadata("operation", operation)
516    }
517
518    /// Add device information as metadata
519    ///
520    /// Convenience method for tracking device-related errors.
521    ///
522    /// # Example
523    /// ```
524    /// use torsh_core::error::{TorshError, Result};
525    ///
526    /// fn allocate_on_device(device_id: usize) -> Result<()> {
527    ///     let error = TorshError::DeviceError("allocation failed".to_string())
528    ///         .with_device(device_id)
529    ///         .add_metadata("allocation_type", "tensor");
530    ///     Err(error)
531    /// }
532    /// ```
533    pub fn with_device(self, device_id: usize) -> Self {
534        self.add_metadata("device_id", &device_id.to_string())
535    }
536
537    /// Add dtype information as metadata
538    ///
539    /// Convenience method for tracking data type-related errors.
540    ///
541    /// # Example
542    /// ```
543    /// use torsh_core::error::{TorshError, Result};
544    ///
545    /// fn convert_dtype(from: &str, to: &str) -> Result<()> {
546    ///     let error = TorshError::ConversionError("unsupported conversion".to_string())
547    ///         .add_metadata("from_dtype", from)
548    ///         .add_metadata("to_dtype", to);
549    ///     Err(error)
550    /// }
551    /// ```
552    pub fn with_dtype(self, dtype: &str) -> Self {
553        self.add_metadata("dtype", dtype)
554    }
555
556    /// Get all metadata from the error
557    ///
558    /// Returns an empty map if the error doesn't have metadata.
559    pub fn metadata(&self) -> std::collections::HashMap<String, String> {
560        match self {
561            Self::WithContext { debug_context, .. } => debug_context.metadata.clone(),
562            _ => std::collections::HashMap::new(),
563        }
564    }
565
566    /// Get the error's debug context if available
567    ///
568    /// Returns None if the error is not a `WithContext` variant.
569    pub fn debug_context(&self) -> Option<&ErrorDebugContext> {
570        match self {
571            Self::WithContext { debug_context, .. } => Some(debug_context),
572            _ => None,
573        }
574    }
575
576    /// Format the error with full debug information
577    ///
578    /// This includes metadata, backtrace, and thread information when available.
579    pub fn format_debug(&self) -> String {
580        let mut output = format!("Error: {self}\n");
581
582        if let Some(context) = self.debug_context() {
583            output.push_str("\n");
584            output.push_str(&context.format_debug_info());
585        }
586
587        if let Self::WithContext {
588            source: Some(source),
589            ..
590        } = self
591        {
592            output.push_str("\nCaused by:\n");
593            output.push_str(&format!("  {source}"));
594        }
595
596        output
597    }
598}
599
600// Standard library error conversions
601impl From<std::io::Error> for TorshError {
602    fn from(err: std::io::Error) -> Self {
603        Self::General(GeneralError::IoError(err.to_string()))
604    }
605}
606
#[cfg(feature = "serialize")]
impl From<serde_json::Error> for TorshError {
    /// Convert a `serde_json` error into `GeneralError::SerializationError`,
    /// keeping only its display text. Compiled only with the `serialize`
    /// feature.
    fn from(err: serde_json::Error) -> Self {
        let description = err.to_string();
        Self::General(GeneralError::SerializationError(description))
    }
}
613
614impl<T> From<std::sync::PoisonError<T>> for TorshError {
615    fn from(err: std::sync::PoisonError<T>) -> Self {
616        Self::General(GeneralError::SynchronizationError(format!(
617            "Mutex poisoned: {err}"
618        )))
619    }
620}
621
622impl From<std::num::TryFromIntError> for TorshError {
623    fn from(err: std::num::TryFromIntError) -> Self {
624        Self::General(GeneralError::ConversionError(format!(
625            "Integer conversion failed: {err}"
626        )))
627    }
628}
629
630impl From<std::num::ParseIntError> for TorshError {
631    fn from(err: std::num::ParseIntError) -> Self {
632        Self::General(GeneralError::ConversionError(format!(
633            "Integer parsing failed: {err}"
634        )))
635    }
636}
637
638impl From<std::num::ParseFloatError> for TorshError {
639    fn from(err: std::num::ParseFloatError) -> Self {
640        Self::General(GeneralError::ConversionError(format!(
641            "Float parsing failed: {err}"
642        )))
643    }
644}
645
/// Convenience macro that builds a `TorshError::WithContext` with a minimal
/// debug context.
///
/// Two forms are accepted:
///
/// * `torsh_error_with_location!("some message")` — a **literal** message.
///   The result has category `Internal`, severity `Medium` and no source.
/// * `torsh_error_with_location!(error_expr)` — any expression convertible
///   into `TorshError`. The expression is evaluated once; the result copies
///   the converted error's display text, category and severity, and keeps it
///   as the source.
///
/// The literal arm must come first: both arms used to take a bare `expr`, so
/// the message arm could never be selected. The debug context is boxed here
/// because the `debug_context` field is a `Box<ErrorDebugContext>`.
#[macro_export]
macro_rules! torsh_error_with_location {
    ($message:literal) => {
        $crate::error::TorshError::WithContext {
            message: $message.to_string(),
            error_category: $crate::error::ErrorCategory::Internal,
            severity: $crate::error::ErrorSeverity::Medium,
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::core::option::Option::None,
        }
    };
    ($error:expr) => {{
        // Evaluate the argument once and normalise it to `TorshError` before
        // reading message, category and severity from it.
        let source: $crate::error::TorshError = ::core::convert::Into::into($error);
        $crate::error::TorshError::WithContext {
            message: source.to_string(),
            error_category: source.category(),
            severity: source.severity(),
            debug_context: ::std::boxed::Box::new($crate::error::ErrorDebugContext::minimal()),
            source: ::core::option::Option::Some(::std::boxed::Box::new(source)),
        }
    }};
}
668
/// Convenience macro for shape mismatch errors.
///
/// Expands to `TorshError::shape_mismatch(expected, got)`; both arguments
/// must be usable as `&[usize]`. A trailing comma is accepted.
#[macro_export]
macro_rules! shape_mismatch_error {
    ($expected:expr, $got:expr $(,)?) => {
        $crate::error::TorshError::shape_mismatch($expected, $got)
    };
}
676
/// Convenience macro for index errors.
///
/// Expands to `TorshError::index_error(index, size)`; both arguments are
/// `usize` values. A trailing comma is accepted.
#[macro_export]
macro_rules! index_error {
    ($index:expr, $size:expr $(,)?) => {
        $crate::error::TorshError::index_error($index, $size)
    };
}
684
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Look up a metadata value as `&str` to keep the assertions short.
    fn meta<'a>(map: &'a HashMap<String, String>, key: &str) -> Option<&'a str> {
        map.get(key).map(String::as_str)
    }

    /// Each modular error converts into `TorshError` and keeps its category.
    #[test]
    fn test_modular_error_system() {
        let from_shape = TorshError::from(ShapeError::shape_mismatch(&[2, 3], &[3, 2]));
        assert_eq!(from_shape.category(), ErrorCategory::Shape);

        let from_index = TorshError::from(IndexError::out_of_bounds(5, 3));
        assert_eq!(from_index.category(), ErrorCategory::UserInput);

        let from_general = TorshError::from(GeneralError::InvalidArgument("test".to_string()));
        assert_eq!(from_general.category(), ErrorCategory::UserInput);
    }

    /// The legacy constructor still reports a high-severity shape error.
    #[test]
    fn test_backward_compatibility() {
        let mismatch = TorshError::shape_mismatch(&[2, 3], &[3, 2]);
        assert_eq!(mismatch.category(), ErrorCategory::Shape);
        assert_eq!(mismatch.severity(), ErrorSeverity::High);
    }

    /// `with_context` wraps the error and stores the context message.
    #[test]
    fn test_error_context() {
        let wrapped =
            TorshError::InvalidArgument("test".to_string()).with_context("During tensor operation");

        if let TorshError::WithContext { message, .. } = wrapped {
            assert_eq!(message, "During tensor operation");
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// The exported macros forward to the matching constructors.
    #[test]
    fn test_convenience_macros() {
        let via_shape_macro = shape_mismatch_error!(&[2, 3], &[3, 2]);
        assert_eq!(via_shape_macro.category(), ErrorCategory::Shape);

        let via_index_macro = index_error!(5, 3);
        assert_eq!(via_index_macro.category(), ErrorCategory::UserInput);
    }

    /// Standard-library (and optionally serde) errors convert as expected.
    #[test]
    fn test_standard_conversions() {
        let not_found = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let converted = TorshError::from(not_found);
        assert_eq!(converted.category(), ErrorCategory::Io);

        #[cfg(feature = "serialize")]
        {
            let parse_failure = serde_json::from_str::<i32>("invalid json").unwrap_err();
            let converted = TorshError::from(parse_failure);
            assert_eq!(converted.category(), ErrorCategory::Internal);
        }
    }

    /// Severities are ordered: a `Low` error compares below a `High` one.
    #[test]
    fn test_error_severity_ordering() {
        let low = TorshError::NotImplemented("test".to_string()).severity();
        let high = TorshError::shape_mismatch(&[2, 3], &[3, 2]).severity();

        assert!(low < high);
    }

    /// `with_rich_context` stores the message and a backtrace entry.
    #[test]
    fn test_rich_context() {
        let wrapped = TorshError::InvalidShape("test error".to_string())
            .with_rich_context("during tensor operation");

        if let TorshError::WithContext {
            message,
            debug_context,
            ..
        } = wrapped
        {
            assert_eq!(message, "during tensor operation");
            // Rich context should have backtrace (or a message about it)
            assert!(debug_context.backtrace.is_some());
        } else {
            panic!("Expected WithContext error");
        }
    }

    /// Repeated `add_metadata` calls accumulate in one metadata map.
    #[test]
    fn test_add_metadata() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .add_metadata("key1", "value1")
            .add_metadata("key2", "value2")
            .metadata();

        assert_eq!(meta(&metadata, "key1"), Some("value1"));
        assert_eq!(meta(&metadata, "key2"), Some("value2"));
    }

    /// Shapes are stored as text containing their dimensions.
    #[test]
    fn test_add_shape_metadata() {
        let lhs = vec![2, 3, 4];
        let rhs = vec![4, 5];

        let metadata = TorshError::shape_mismatch(&lhs, &rhs)
            .add_shape_metadata("tensor_a", &lhs)
            .add_shape_metadata("tensor_b", &rhs)
            .metadata();

        assert!(metadata.contains_key("tensor_a"));
        assert!(metadata.contains_key("tensor_b"));
        assert!(metadata["tensor_a"].contains("2"));
        assert!(metadata["tensor_b"].contains("4"));
    }

    /// `with_operation` records the name under the `operation` key.
    #[test]
    fn test_with_operation() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .with_operation("matmul")
            .metadata();

        assert_eq!(meta(&metadata, "operation"), Some("matmul"));
    }

    /// `with_device` records the id under the `device_id` key.
    #[test]
    fn test_with_device() {
        let metadata = TorshError::DeviceError("allocation failed".to_string())
            .with_device(42)
            .metadata();

        assert_eq!(meta(&metadata, "device_id"), Some("42"));
    }

    /// `with_dtype` records the type name under the `dtype` key.
    #[test]
    fn test_with_dtype() {
        let metadata = TorshError::ConversionError("unsupported".to_string())
            .with_dtype("f32")
            .metadata();

        assert_eq!(meta(&metadata, "dtype"), Some("f32"));
    }

    /// All metadata helpers can be chained on a single error.
    #[test]
    fn test_chained_metadata() {
        let metadata = TorshError::InvalidShape("test".to_string())
            .with_operation("conv2d")
            .add_metadata("batch_size", "32")
            .add_shape_metadata("input_shape", &[32, 3, 224, 224])
            .with_device(0)
            .with_dtype("f32")
            .metadata();

        assert_eq!(meta(&metadata, "operation"), Some("conv2d"));
        assert_eq!(meta(&metadata, "batch_size"), Some("32"));
        assert_eq!(meta(&metadata, "device_id"), Some("0"));
        assert_eq!(meta(&metadata, "dtype"), Some("f32"));
        assert!(metadata.contains_key("input_shape"));
    }

    /// The debug rendering includes the message and every metadata pair.
    #[test]
    fn test_format_debug() {
        let rendered = TorshError::InvalidShape("test error".to_string())
            .with_operation("test_op")
            .add_metadata("key", "value")
            .format_debug();

        for expected in ["Error:", "test error", "operation: test_op", "key: value"] {
            assert!(rendered.contains(expected));
        }
    }

    /// Adding metadata to a plain error promotes it to a contextual one.
    #[test]
    fn test_metadata_on_non_context_error() {
        let plain = TorshError::InvalidArgument("test".to_string());
        assert!(plain.metadata().is_empty());

        let promoted = plain.add_metadata("new_key", "new_value");
        let metadata = promoted.metadata();
        assert_eq!(meta(&metadata, "new_key"), Some("new_value"));
    }

    /// Only contextual errors expose a debug context.
    #[test]
    fn test_debug_context_availability() {
        let plain = TorshError::InvalidShape("test".to_string());
        assert!(plain.debug_context().is_none());

        let contextual =
            TorshError::InvalidShape("test".to_string()).with_context("during operation");
        assert!(contextual.debug_context().is_some());
    }
}