Skip to main content

tenflowers_core/
error.rs

1use crate::{DType, Device};
2use thiserror::Error;
3
/// Error type for tensor operations.
///
/// Every variant records the name of the failing `operation` and carries an
/// optional [`ErrorContext`] (attached after construction with
/// `TensorError::with_context`). Some variants also carry recovery hints that
/// `TensorError::supports_fallback` / `TensorError::recovery_strategy` read.
/// Fields not interpolated into the `#[error(...)]` message are only visible
/// through pattern matching or `Debug`.
#[derive(Error, Debug, Clone)]
pub enum TensorError {
    /// Operand shapes did not match what `operation` required.
    /// `expected` and `got` are free-form descriptions shown verbatim in the message.
    #[error("Shape mismatch in operation '{operation}': expected {expected}, got {got}")]
    ShapeMismatch {
        operation: String,
        expected: String,
        got: String,
        context: Option<ErrorContext>,
    },

    /// `operation` received operands on two incompatible devices.
    #[error("Incompatible devices in operation '{operation}': {device1} and {device2}")]
    DeviceMismatch {
        operation: String,
        device1: String,
        device2: String,
        context: Option<ErrorContext>,
    },

    /// `operation` cannot run on `device`.
    /// When `fallback_available` is `true`, `recovery_strategy` suggests a CPU fallback.
    #[error("Operation '{operation}' not supported on device: {device}")]
    UnsupportedDevice {
        operation: String,
        device: String,
        fallback_available: bool,
        context: Option<ErrorContext>,
    },

    /// A shape was rejected by `operation` for the given `reason`.
    /// `shape` is optional; the constructors in this file always leave it `None`.
    #[error("Invalid shape in operation '{operation}': {reason}")]
    InvalidShape {
        operation: String,
        reason: String,
        shape: Option<Vec<usize>>,
        context: Option<ErrorContext>,
    },

    /// `axis` is not valid for a tensor with `ndim` dimensions.
    // NOTE(review): `axis` is signed, presumably to allow negative (from-the-end) axes — confirm.
    #[error("Invalid axis {axis} in operation '{operation}' for tensor with {ndim} dimensions")]
    InvalidAxis {
        operation: String,
        axis: i32,
        ndim: usize,
        context: Option<ErrorContext>,
    },

    /// A gradient was requested for a tensor that does not track gradients.
    /// `suggestion` is not part of the display message.
    #[error("Gradient computation not enabled for tensor in operation '{operation}'")]
    GradientNotEnabled {
        operation: String,
        suggestion: String,
        context: Option<ErrorContext>,
    },

    /// An argument passed to `operation` was rejected for `reason`.
    #[error("Invalid argument in operation '{operation}': {reason}")]
    InvalidArgument {
        operation: String,
        reason: String,
        context: Option<ErrorContext>,
    },

    /// Memory allocation failed. Byte counts are optional and not part of the message.
    #[error("Memory allocation failed in operation '{operation}': {details}")]
    AllocationError {
        operation: String,
        details: String,
        requested_bytes: Option<usize>,
        available_bytes: Option<usize>,
        context: Option<ErrorContext>,
    },

    /// `operation` is not supported for `reason`.
    /// `alternatives` is not part of the display message.
    #[error("Operation '{operation}' not supported: {reason}")]
    UnsupportedOperation {
        operation: String,
        reason: String,
        alternatives: Vec<String>,
        context: Option<ErrorContext>,
    },

    /// GPU failure; only exists when the `gpu` feature is enabled.
    /// With `fallback_attempted: false`, `recovery_strategy` suggests a CPU fallback.
    #[error("GPU error in operation '{operation}': {details}")]
    #[cfg(feature = "gpu")]
    GpuError {
        operation: String,
        details: String,
        gpu_id: Option<usize>,
        fallback_attempted: bool,
        context: Option<ErrorContext>,
    },

    /// Device-level failure on `device` (the device name is not part of the message).
    #[error("Device error in operation '{operation}': {details}")]
    DeviceError {
        operation: String,
        details: String,
        device: String,
        context: Option<ErrorContext>,
    },

    /// Computation failure. `retry_possible` is what `supports_fallback` reports for this variant.
    #[error("Compute error in operation '{operation}': {details}")]
    ComputeError {
        operation: String,
        details: String,
        retry_possible: bool,
        context: Option<ErrorContext>,
    },

    /// BLAS failure; only exists when the `blas` feature is enabled.
    #[error("BLAS error in operation '{operation}': {details}")]
    #[cfg(feature = "blas")]
    BlasError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// Serialization or deserialization failure.
    #[error("Serialization error in operation '{operation}': {details}")]
    SerializationError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// `operation` is not implemented. `planned_version` is not part of the message.
    #[error("Operation '{operation}' not implemented: {details}")]
    NotImplemented {
        operation: String,
        details: String,
        planned_version: Option<String>,
        context: Option<ErrorContext>,
    },

    /// `operation` was invoked in an invalid way, explained by `reason`.
    #[error("Invalid operation '{operation}': {reason}")]
    InvalidOperation {
        operation: String,
        reason: String,
        context: Option<ErrorContext>,
    },

    /// Failure while benchmarking `operation`.
    #[error("Benchmark error in '{operation}': {details}")]
    BenchmarkError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// I/O failure. `path` is optional and not part of the message.
    #[error("IO error in operation '{operation}': {details}")]
    IoError {
        operation: String,
        details: String,
        path: Option<String>,
        context: Option<ErrorContext>,
    },

    /// Numerical failure. `suggestions` is not part of the display message.
    #[error("Numerical error in operation '{operation}': {details}")]
    NumericalError {
        operation: String,
        details: String,
        suggestions: Vec<String>,
        context: Option<ErrorContext>,
    },

    /// A named `resource` ran out. Usage/limit figures are optional and not in the message.
    #[error("Resource exhaustion in operation '{operation}': {resource}")]
    ResourceExhausted {
        operation: String,
        resource: String,
        current_usage: Option<usize>,
        limit: Option<usize>,
        context: Option<ErrorContext>,
    },

    /// `operation` timed out after `duration_ms` milliseconds.
    #[error("Timeout in operation '{operation}' after {duration_ms}ms")]
    Timeout {
        operation: String,
        duration_ms: u64,
        context: Option<ErrorContext>,
    },

    /// Cache failure. `recoverable` is stored but not consulted by `supports_fallback`.
    #[error("Cache operation failed in '{operation}': {details}")]
    CacheError {
        operation: String,
        details: String,
        recoverable: bool,
        context: Option<ErrorContext>,
    },

    /// Catch-all for errors that fit no other variant.
    #[error("Other error in operation '{operation}': {details}")]
    Other {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },
}
188
/// Additional diagnostic information that can be attached to a [`TensorError`]
/// with `TensorError::with_context`.
///
/// Usually built with `ErrorContext::new` followed by the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Input tensor shapes, one entry per recorded input.
    /// `with_input_tensor` pushes to all three `input_*` vectors together, so they
    /// stay index-aligned as long as only that builder is used to fill them.
    pub input_shapes: Vec<Vec<usize>>,
    /// Input tensor devices (see the alignment note on `input_shapes`).
    pub input_devices: Vec<Device>,
    /// Input tensor data types (see the alignment note on `input_shapes`).
    pub input_dtypes: Vec<DType>,
    /// Output shape, if one was recorded.
    pub output_shape: Option<Vec<usize>>,
    /// Thread that created the context, stored as the `Debug` rendering of its `ThreadId`.
    pub thread_id: String,
    /// Stack trace, if available. `new` leaves it `None`, and nothing in this file sets it.
    pub stack_trace: Option<String>,
    /// Free-form key/value annotations.
    pub metadata: std::collections::HashMap<String, String>,
}
207
/// Recovery strategy suggested for a [`TensorError`].
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare the result of
/// `TensorError::recovery_strategy` directly instead of pattern-matching it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// No recovery possible
    None,
    /// Fallback to CPU execution
    FallbackToCpu,
    /// Retry with different parameters, given as string key/value pairs
    RetryWithParams(std::collections::HashMap<String, String>),
    /// Use an alternative algorithm, described by the contained string
    UseAlternative(String),
    /// Reduce precision
    ReducePrecision,
    /// Free memory and retry
    FreeMemoryAndRetry,
}
224
225impl ErrorContext {
226    /// Create a new error context
227    pub fn new() -> Self {
228        Self {
229            input_shapes: Vec::new(),
230            input_devices: Vec::new(),
231            input_dtypes: Vec::new(),
232            output_shape: None,
233            thread_id: format!("{:?}", std::thread::current().id()),
234            stack_trace: None,
235            metadata: std::collections::HashMap::new(),
236        }
237    }
238
239    /// Add input tensor information
240    pub fn with_input_tensor(mut self, shape: &[usize], device: Device, dtype: DType) -> Self {
241        self.input_shapes.push(shape.to_vec());
242        self.input_devices.push(device);
243        self.input_dtypes.push(dtype);
244        self
245    }
246
247    /// Add output shape information
248    pub fn with_output_shape(mut self, shape: &[usize]) -> Self {
249        self.output_shape = Some(shape.to_vec());
250        self
251    }
252
253    /// Add metadata
254    pub fn with_metadata(mut self, key: String, value: String) -> Self {
255        self.metadata.insert(key, value);
256        self
257    }
258}
259
260impl Default for ErrorContext {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266impl TensorError {
267    /// Create a shape mismatch error with context
268    pub fn shape_mismatch(operation: &str, expected: &str, got: &str) -> Self {
269        Self::ShapeMismatch {
270            operation: operation.to_string(),
271            expected: expected.to_string(),
272            got: got.to_string(),
273            context: None,
274        }
275    }
276
277    /// Create a device mismatch error with context
278    pub fn device_mismatch(operation: &str, device1: &str, device2: &str) -> Self {
279        Self::DeviceMismatch {
280            operation: operation.to_string(),
281            device1: device1.to_string(),
282            device2: device2.to_string(),
283            context: None,
284        }
285    }
286
287    /// Create an unsupported device error with fallback information
288    pub fn unsupported_device(operation: &str, device: &str, fallback_available: bool) -> Self {
289        Self::UnsupportedDevice {
290            operation: operation.to_string(),
291            device: device.to_string(),
292            fallback_available,
293            context: None,
294        }
295    }
296
297    /// Create a GPU error with fallback information
298    #[cfg(feature = "gpu")]
299    pub fn gpu_error(
300        operation: &str,
301        details: &str,
302        gpu_id: Option<usize>,
303        fallback_attempted: bool,
304    ) -> Self {
305        Self::GpuError {
306            operation: operation.to_string(),
307            details: details.to_string(),
308            gpu_id,
309            fallback_attempted,
310            context: None,
311        }
312    }
313
314    /// Create an allocation error with memory information
315    pub fn allocation_error(
316        operation: &str,
317        details: &str,
318        requested: Option<usize>,
319        available: Option<usize>,
320    ) -> Self {
321        Self::AllocationError {
322            operation: operation.to_string(),
323            details: details.to_string(),
324            requested_bytes: requested,
325            available_bytes: available,
326            context: None,
327        }
328    }
329
330    /// Create a numerical error with suggestions
331    pub fn numerical_error(operation: &str, details: &str, suggestions: Vec<String>) -> Self {
332        Self::NumericalError {
333            operation: operation.to_string(),
334            details: details.to_string(),
335            suggestions,
336            context: None,
337        }
338    }
339
340    /// Create an invalid argument error (for backward compatibility)
341    pub fn invalid_argument(reason: String) -> Self {
342        Self::InvalidArgument {
343            operation: "unknown".to_string(),
344            reason,
345            context: None,
346        }
347    }
348
349    /// Create an invalid argument error with operation context
350    pub fn invalid_argument_op(operation: &str, reason: &str) -> Self {
351        Self::InvalidArgument {
352            operation: operation.to_string(),
353            reason: reason.to_string(),
354            context: None,
355        }
356    }
357
358    /// Create a generic "other" error (for backward compatibility)
359    pub fn other(details: String) -> Self {
360        Self::Other {
361            operation: "unknown".to_string(),
362            details,
363            context: None,
364        }
365    }
366
367    /// Create a generic "other" error with operation context
368    pub fn other_op(operation: &str, details: &str) -> Self {
369        Self::Other {
370            operation: operation.to_string(),
371            details: details.to_string(),
372            context: None,
373        }
374    }
375
376    /// Create an allocation error (for backward compatibility)
377    pub fn allocation_error_simple(details: String) -> Self {
378        Self::AllocationError {
379            operation: "unknown".to_string(),
380            details,
381            requested_bytes: None,
382            available_bytes: None,
383            context: None,
384        }
385    }
386
387    /// Create an unsupported operation error (for backward compatibility)
388    pub fn unsupported_operation_simple(reason: String) -> Self {
389        Self::UnsupportedOperation {
390            operation: "unknown".to_string(),
391            reason,
392            alternatives: Vec::new(),
393            context: None,
394        }
395    }
396
397    /// Create an invalid shape error with operation context
398    pub fn invalid_shape(operation: &str, expected: &str, got: &str) -> Self {
399        Self::InvalidShape {
400            operation: operation.to_string(),
401            reason: format!("Expected {}, got {}", expected, got),
402            shape: None,
403            context: None,
404        }
405    }
406
407    /// Create an invalid shape error (for backward compatibility)
408    pub fn invalid_shape_simple(reason: String) -> Self {
409        Self::InvalidShape {
410            operation: "unknown".to_string(),
411            reason,
412            shape: None,
413            context: None,
414        }
415    }
416
417    /// Create a device error (for backward compatibility)
418    pub fn device_error_simple(details: String) -> Self {
419        Self::DeviceError {
420            operation: "unknown".to_string(),
421            details,
422            device: "unknown".to_string(),
423            context: None,
424        }
425    }
426
427    /// Create a compute error (for backward compatibility)
428    pub fn compute_error_simple(details: String) -> Self {
429        Self::ComputeError {
430            operation: "unknown".to_string(),
431            details,
432            retry_possible: false,
433            context: None,
434        }
435    }
436
437    /// Create a serialization error (for backward compatibility)
438    pub fn serialization_error_simple(details: String) -> Self {
439        Self::SerializationError {
440            operation: "unknown".to_string(),
441            details,
442            context: None,
443        }
444    }
445
446    /// Create a not implemented error (for backward compatibility)
447    pub fn not_implemented_simple(details: String) -> Self {
448        Self::NotImplemented {
449            operation: "unknown".to_string(),
450            details,
451            planned_version: None,
452            context: None,
453        }
454    }
455
456    /// Create an invalid operation error (for backward compatibility)
457    pub fn invalid_operation_simple(reason: String) -> Self {
458        Self::InvalidOperation {
459            operation: "unknown".to_string(),
460            reason,
461            context: None,
462        }
463    }
464
465    /// Create a benchmark error (for backward compatibility)
466    pub fn benchmark_error_simple(details: String) -> Self {
467        Self::BenchmarkError {
468            operation: "unknown".to_string(),
469            details,
470            context: None,
471        }
472    }
473
474    /// Create an IO error (for backward compatibility)
475    pub fn io_error_simple(details: String) -> Self {
476        Self::IoError {
477            operation: "unknown".to_string(),
478            details,
479            path: None,
480            context: None,
481        }
482    }
483
484    /// Create a resource exhausted error (for backward compatibility)
485    pub fn resource_exhausted_simple(resource: String) -> Self {
486        Self::ResourceExhausted {
487            operation: "unknown".to_string(),
488            resource,
489            current_usage: None,
490            limit: None,
491            context: None,
492        }
493    }
494
495    /// Create a timeout error (for backward compatibility)
496    pub fn timeout_simple(duration_ms: u64) -> Self {
497        Self::Timeout {
498            operation: "unknown".to_string(),
499            duration_ms,
500            context: None,
501        }
502    }
503
504    /// Add context to an existing error
505    pub fn with_context(mut self, context: ErrorContext) -> Self {
506        match &mut self {
507            Self::ShapeMismatch { context: ctx, .. } => *ctx = Some(context),
508            Self::DeviceMismatch { context: ctx, .. } => *ctx = Some(context),
509            Self::UnsupportedDevice { context: ctx, .. } => *ctx = Some(context),
510            Self::InvalidShape { context: ctx, .. } => *ctx = Some(context),
511            Self::InvalidAxis { context: ctx, .. } => *ctx = Some(context),
512            Self::GradientNotEnabled { context: ctx, .. } => *ctx = Some(context),
513            Self::InvalidArgument { context: ctx, .. } => *ctx = Some(context),
514            Self::AllocationError { context: ctx, .. } => *ctx = Some(context),
515            Self::UnsupportedOperation { context: ctx, .. } => *ctx = Some(context),
516            #[cfg(feature = "gpu")]
517            Self::GpuError { context: ctx, .. } => *ctx = Some(context),
518            Self::DeviceError { context: ctx, .. } => *ctx = Some(context),
519            Self::ComputeError { context: ctx, .. } => *ctx = Some(context),
520            #[cfg(feature = "blas")]
521            Self::BlasError { context: ctx, .. } => *ctx = Some(context),
522            Self::SerializationError { context: ctx, .. } => *ctx = Some(context),
523            Self::NotImplemented { context: ctx, .. } => *ctx = Some(context),
524            Self::InvalidOperation { context: ctx, .. } => *ctx = Some(context),
525            Self::BenchmarkError { context: ctx, .. } => *ctx = Some(context),
526            Self::IoError { context: ctx, .. } => *ctx = Some(context),
527            Self::NumericalError { context: ctx, .. } => *ctx = Some(context),
528            Self::ResourceExhausted { context: ctx, .. } => *ctx = Some(context),
529            Self::Timeout { context: ctx, .. } => *ctx = Some(context),
530            Self::CacheError { context: ctx, .. } => *ctx = Some(context),
531            Self::Other { context: ctx, .. } => *ctx = Some(context),
532        }
533        self
534    }
535
536    /// Get the operation name for this error
537    pub fn operation(&self) -> &str {
538        match self {
539            Self::ShapeMismatch { operation, .. } => operation,
540            Self::DeviceMismatch { operation, .. } => operation,
541            Self::UnsupportedDevice { operation, .. } => operation,
542            Self::InvalidShape { operation, .. } => operation,
543            Self::InvalidAxis { operation, .. } => operation,
544            Self::GradientNotEnabled { operation, .. } => operation,
545            Self::InvalidArgument { operation, .. } => operation,
546            Self::AllocationError { operation, .. } => operation,
547            Self::UnsupportedOperation { operation, .. } => operation,
548            #[cfg(feature = "gpu")]
549            Self::GpuError { operation, .. } => operation,
550            Self::DeviceError { operation, .. } => operation,
551            Self::ComputeError { operation, .. } => operation,
552            #[cfg(feature = "blas")]
553            Self::BlasError { operation, .. } => operation,
554            Self::SerializationError { operation, .. } => operation,
555            Self::NotImplemented { operation, .. } => operation,
556            Self::InvalidOperation { operation, .. } => operation,
557            Self::BenchmarkError { operation, .. } => operation,
558            Self::IoError { operation, .. } => operation,
559            Self::NumericalError { operation, .. } => operation,
560            Self::ResourceExhausted { operation, .. } => operation,
561            Self::Timeout { operation, .. } => operation,
562            Self::CacheError { operation, .. } => operation,
563            Self::Other { operation, .. } => operation,
564        }
565    }
566
567    /// Check if this error supports fallback recovery
568    pub fn supports_fallback(&self) -> bool {
569        match self {
570            Self::UnsupportedDevice {
571                fallback_available, ..
572            } => *fallback_available,
573            #[cfg(feature = "gpu")]
574            Self::GpuError { .. } => true,
575            Self::AllocationError { .. } => true,
576            Self::ComputeError { retry_possible, .. } => *retry_possible,
577            _ => false,
578        }
579    }
580
581    /// Get suggested recovery strategy
582    pub fn recovery_strategy(&self) -> RecoveryStrategy {
583        match self {
584            Self::UnsupportedDevice {
585                fallback_available: true,
586                ..
587            } => RecoveryStrategy::FallbackToCpu,
588            #[cfg(feature = "gpu")]
589            Self::GpuError {
590                fallback_attempted: false,
591                ..
592            } => RecoveryStrategy::FallbackToCpu,
593            Self::AllocationError { .. } => RecoveryStrategy::FreeMemoryAndRetry,
594            Self::ComputeError {
595                retry_possible: true,
596                ..
597            } => {
598                let mut params = std::collections::HashMap::new();
599                params.insert("reduce_precision".to_string(), "true".to_string());
600                RecoveryStrategy::RetryWithParams(params)
601            }
602            Self::NumericalError { .. } => RecoveryStrategy::ReducePrecision,
603            _ => RecoveryStrategy::None,
604        }
605    }
606}
607
/// Extension trait for attempting recovery on a `Result<T>` whose error is a [`TensorError`].
pub trait ErrorRecovery<T> {
    /// Attempts to recover from an error using the given `strategy` and returns
    /// the (possibly recovered) result.
    // NOTE(review): the implementation for `Result<T>` in this file currently
    // ignores `strategy` and returns `self` unchanged.
    fn recover_with_strategy(self, strategy: RecoveryStrategy) -> Result<T>;

    /// Attempts recovery only when the contained error reports
    /// `supports_fallback()`, using its `recovery_strategy()`. Otherwise returns
    /// the result unchanged.
    fn auto_recover(self) -> Result<T>;
}
616
617impl<T> ErrorRecovery<T> for Result<T> {
618    fn recover_with_strategy(self, _strategy: RecoveryStrategy) -> Result<T> {
619        // For now, just return the original result
620        // In a full implementation, this would attempt recovery based on the strategy
621        self
622    }
623
624    fn auto_recover(self) -> Result<T> {
625        match &self {
626            Err(error) if error.supports_fallback() => {
627                let strategy = error.recovery_strategy();
628                self.recover_with_strategy(strategy)
629            }
630            _ => self,
631        }
632    }
633}
634
635pub type Result<T> = std::result::Result<T, TensorError>;
636
637/// Convert from scirs2_core::ndarray::ShapeError to TensorError
638impl From<scirs2_core::ndarray::ShapeError> for TensorError {
639    fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
640        Self::InvalidShape {
641            operation: "tensor_creation".to_string(),
642            reason: format!("Shape error: {err}"),
643            shape: None,
644            context: None,
645        }
646    }
647}
648
649/// Convert from std::fmt::Error to TensorError
650impl From<std::fmt::Error> for TensorError {
651    fn from(err: std::fmt::Error) -> Self {
652        Self::Other {
653            operation: "formatting".to_string(),
654            details: format!("Formatting error: {err}"),
655            context: None,
656        }
657    }
658}