1use crate::{DType, Device};
2use thiserror::Error;
3
/// Error type covering every failure a tensor operation can report.
///
/// Each variant carries the name of the failing `operation` (retrievable via
/// [`TensorError::operation`]) and an optional [`ErrorContext`] that can be
/// attached with [`TensorError::with_context`].
///
/// The `Display` text of a variant contains only the fields named in its
/// `#[error(...)]` format string; the remaining fields are available for
/// programmatic inspection only.
#[derive(Error, Debug, Clone)]
pub enum TensorError {
    /// An operand's shape differs from the shape the operation expected.
    #[error("Shape mismatch in operation '{operation}': expected {expected}, got {got}")]
    ShapeMismatch {
        operation: String,
        /// Textual description of the expected shape.
        expected: String,
        /// Textual description of the shape actually received.
        got: String,
        context: Option<ErrorContext>,
    },

    /// Two operands are associated with devices that cannot be combined.
    #[error("Incompatible devices in operation '{operation}': {device1} and {device2}")]
    DeviceMismatch {
        operation: String,
        device1: String,
        device2: String,
        context: Option<ErrorContext>,
    },

    /// The operation is not supported on the named device.
    #[error("Operation '{operation}' not supported on device: {device}")]
    UnsupportedDevice {
        operation: String,
        device: String,
        /// Read by `supports_fallback`; when `true`, `recovery_strategy`
        /// returns `RecoveryStrategy::FallbackToCpu`.
        fallback_available: bool,
        context: Option<ErrorContext>,
    },

    /// A shape is invalid for the operation.
    #[error("Invalid shape in operation '{operation}': {reason}")]
    InvalidShape {
        operation: String,
        reason: String,
        /// The offending shape, when known (not part of the message).
        shape: Option<Vec<usize>>,
        context: Option<ErrorContext>,
    },

    /// An axis index is not valid for a tensor with `ndim` dimensions.
    #[error("Invalid axis {axis} in operation '{operation}' for tensor with {ndim} dimensions")]
    InvalidAxis {
        operation: String,
        // NOTE(review): signed, presumably so negative axes can be reported — confirm.
        axis: i32,
        ndim: usize,
        context: Option<ErrorContext>,
    },

    /// The operation needs gradient tracking that is not enabled on a tensor.
    #[error("Gradient computation not enabled for tensor in operation '{operation}'")]
    GradientNotEnabled {
        operation: String,
        /// Free-form hint for the caller (not part of the message).
        suggestion: String,
        context: Option<ErrorContext>,
    },

    /// An argument passed to the operation is invalid.
    #[error("Invalid argument in operation '{operation}': {reason}")]
    InvalidArgument {
        operation: String,
        reason: String,
        context: Option<ErrorContext>,
    },

    /// A memory allocation failed.
    ///
    /// `supports_fallback` reports `true` for this variant and
    /// `recovery_strategy` returns `RecoveryStrategy::FreeMemoryAndRetry`.
    #[error("Memory allocation failed in operation '{operation}': {details}")]
    AllocationError {
        operation: String,
        details: String,
        /// Size of the failed request, when known (not part of the message).
        requested_bytes: Option<usize>,
        /// Memory that was available, when known (not part of the message).
        available_bytes: Option<usize>,
        context: Option<ErrorContext>,
    },

    /// The operation is not supported at all.
    #[error("Operation '{operation}' not supported: {reason}")]
    UnsupportedOperation {
        operation: String,
        reason: String,
        /// Alternative suggestions for the caller (not part of the message).
        alternatives: Vec<String>,
        context: Option<ErrorContext>,
    },

    /// A GPU-side failure. Only compiled with the `gpu` feature.
    #[error("GPU error in operation '{operation}': {details}")]
    #[cfg(feature = "gpu")]
    GpuError {
        operation: String,
        details: String,
        /// Identifier of the GPU involved, when known (not part of the message).
        gpu_id: Option<usize>,
        /// When `false`, `recovery_strategy` returns
        /// `RecoveryStrategy::FallbackToCpu`.
        fallback_attempted: bool,
        context: Option<ErrorContext>,
    },

    /// A failure attributed to a device.
    #[error("Device error in operation '{operation}': {details}")]
    DeviceError {
        operation: String,
        details: String,
        /// Name of the device involved (not part of the message).
        device: String,
        context: Option<ErrorContext>,
    },

    /// A failure during computation.
    #[error("Compute error in operation '{operation}': {details}")]
    ComputeError {
        operation: String,
        details: String,
        /// Read by `supports_fallback`; when `true`, `recovery_strategy`
        /// returns `RecoveryStrategy::RetryWithParams`.
        retry_possible: bool,
        context: Option<ErrorContext>,
    },

    /// A BLAS failure. Only compiled with the `blas` feature.
    #[error("BLAS error in operation '{operation}': {details}")]
    #[cfg(feature = "blas")]
    BlasError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// A serialization failure.
    #[error("Serialization error in operation '{operation}': {details}")]
    SerializationError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// The operation has not been implemented.
    #[error("Operation '{operation}' not implemented: {details}")]
    NotImplemented {
        operation: String,
        details: String,
        /// Version in which support is planned, if any (not part of the message).
        planned_version: Option<String>,
        context: Option<ErrorContext>,
    },

    /// The operation is invalid for the reason given.
    #[error("Invalid operation '{operation}': {reason}")]
    InvalidOperation {
        operation: String,
        reason: String,
        context: Option<ErrorContext>,
    },

    /// A failure reported from benchmarking code.
    #[error("Benchmark error in '{operation}': {details}")]
    BenchmarkError {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },

    /// An I/O failure.
    #[error("IO error in operation '{operation}': {details}")]
    IoError {
        operation: String,
        details: String,
        /// Path involved, when known (not part of the message).
        path: Option<String>,
        context: Option<ErrorContext>,
    },

    /// A numerical failure.
    ///
    /// `recovery_strategy` returns `RecoveryStrategy::ReducePrecision` for
    /// this variant.
    #[error("Numerical error in operation '{operation}': {details}")]
    NumericalError {
        operation: String,
        details: String,
        /// Hints for the caller (not part of the message).
        suggestions: Vec<String>,
        context: Option<ErrorContext>,
    },

    /// A resource has been exhausted.
    #[error("Resource exhaustion in operation '{operation}': {resource}")]
    ResourceExhausted {
        operation: String,
        /// Name of the exhausted resource.
        resource: String,
        /// Usage at the time of failure, when known (not part of the message).
        current_usage: Option<usize>,
        /// Configured limit, when known (not part of the message).
        limit: Option<usize>,
        context: Option<ErrorContext>,
    },

    /// The operation timed out.
    #[error("Timeout in operation '{operation}' after {duration_ms}ms")]
    Timeout {
        operation: String,
        /// Elapsed time in milliseconds, as rendered in the message.
        duration_ms: u64,
        context: Option<ErrorContext>,
    },

    /// A cache operation failed.
    #[error("Cache operation failed in '{operation}': {details}")]
    CacheError {
        operation: String,
        details: String,
        /// Not part of the message and not read by any method in this file.
        recoverable: bool,
        context: Option<ErrorContext>,
    },

    /// Catch-all for failures that fit no other variant.
    #[error("Other error in operation '{operation}': {details}")]
    Other {
        operation: String,
        details: String,
        context: Option<ErrorContext>,
    },
}
188
/// Diagnostic details that can be attached to a [`TensorError`].
///
/// Built with [`ErrorContext::new`] and the chained `with_*` methods.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Shapes of the input tensors; `with_input_tensor` pushes one entry here
    /// together with one entry in `input_devices` and `input_dtypes`.
    pub input_shapes: Vec<Vec<usize>>,
    /// Devices of the input tensors, pushed alongside `input_shapes`.
    pub input_devices: Vec<Device>,
    /// Element types of the input tensors, pushed alongside `input_shapes`.
    pub input_dtypes: Vec<DType>,
    /// Output shape, set by `with_output_shape`; `None` until then.
    pub output_shape: Option<Vec<usize>>,
    /// `Debug` rendering of the id of the thread that called `new`.
    pub thread_id: String,
    /// Stack trace text; `new` leaves this as `None` and no method in this
    /// file fills it in.
    pub stack_trace: Option<String>,
    /// Free-form key/value pairs added through `with_metadata`.
    pub metadata: std::collections::HashMap<String, String>,
}
207
/// Recovery action suggested for a [`TensorError`].
///
/// Produced by [`TensorError::recovery_strategy`].
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// No recovery is suggested.
    None,
    /// Suggested for `UnsupportedDevice` with `fallback_available == true`
    /// and (with the `gpu` feature) `GpuError` with
    /// `fallback_attempted == false`.
    FallbackToCpu,
    /// Retry with the given parameters. Suggested for `ComputeError` with
    /// `retry_possible == true`, carrying `"reduce_precision" -> "true"`.
    RetryWithParams(std::collections::HashMap<String, String>),
    /// Not produced by any code in this file.
    // NOTE(review): presumably names an alternative operation to use — confirm.
    UseAlternative(String),
    /// Suggested for `NumericalError`.
    ReducePrecision,
    /// Suggested for `AllocationError`.
    FreeMemoryAndRetry,
}
224
225impl ErrorContext {
226 pub fn new() -> Self {
228 Self {
229 input_shapes: Vec::new(),
230 input_devices: Vec::new(),
231 input_dtypes: Vec::new(),
232 output_shape: None,
233 thread_id: format!("{:?}", std::thread::current().id()),
234 stack_trace: None,
235 metadata: std::collections::HashMap::new(),
236 }
237 }
238
239 pub fn with_input_tensor(mut self, shape: &[usize], device: Device, dtype: DType) -> Self {
241 self.input_shapes.push(shape.to_vec());
242 self.input_devices.push(device);
243 self.input_dtypes.push(dtype);
244 self
245 }
246
247 pub fn with_output_shape(mut self, shape: &[usize]) -> Self {
249 self.output_shape = Some(shape.to_vec());
250 self
251 }
252
253 pub fn with_metadata(mut self, key: String, value: String) -> Self {
255 self.metadata.insert(key, value);
256 self
257 }
258}
259
260impl Default for ErrorContext {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266impl TensorError {
267 pub fn shape_mismatch(operation: &str, expected: &str, got: &str) -> Self {
269 Self::ShapeMismatch {
270 operation: operation.to_string(),
271 expected: expected.to_string(),
272 got: got.to_string(),
273 context: None,
274 }
275 }
276
277 pub fn device_mismatch(operation: &str, device1: &str, device2: &str) -> Self {
279 Self::DeviceMismatch {
280 operation: operation.to_string(),
281 device1: device1.to_string(),
282 device2: device2.to_string(),
283 context: None,
284 }
285 }
286
287 pub fn unsupported_device(operation: &str, device: &str, fallback_available: bool) -> Self {
289 Self::UnsupportedDevice {
290 operation: operation.to_string(),
291 device: device.to_string(),
292 fallback_available,
293 context: None,
294 }
295 }
296
297 #[cfg(feature = "gpu")]
299 pub fn gpu_error(
300 operation: &str,
301 details: &str,
302 gpu_id: Option<usize>,
303 fallback_attempted: bool,
304 ) -> Self {
305 Self::GpuError {
306 operation: operation.to_string(),
307 details: details.to_string(),
308 gpu_id,
309 fallback_attempted,
310 context: None,
311 }
312 }
313
314 pub fn allocation_error(
316 operation: &str,
317 details: &str,
318 requested: Option<usize>,
319 available: Option<usize>,
320 ) -> Self {
321 Self::AllocationError {
322 operation: operation.to_string(),
323 details: details.to_string(),
324 requested_bytes: requested,
325 available_bytes: available,
326 context: None,
327 }
328 }
329
330 pub fn numerical_error(operation: &str, details: &str, suggestions: Vec<String>) -> Self {
332 Self::NumericalError {
333 operation: operation.to_string(),
334 details: details.to_string(),
335 suggestions,
336 context: None,
337 }
338 }
339
340 pub fn invalid_argument(reason: String) -> Self {
342 Self::InvalidArgument {
343 operation: "unknown".to_string(),
344 reason,
345 context: None,
346 }
347 }
348
349 pub fn invalid_argument_op(operation: &str, reason: &str) -> Self {
351 Self::InvalidArgument {
352 operation: operation.to_string(),
353 reason: reason.to_string(),
354 context: None,
355 }
356 }
357
358 pub fn other(details: String) -> Self {
360 Self::Other {
361 operation: "unknown".to_string(),
362 details,
363 context: None,
364 }
365 }
366
367 pub fn other_op(operation: &str, details: &str) -> Self {
369 Self::Other {
370 operation: operation.to_string(),
371 details: details.to_string(),
372 context: None,
373 }
374 }
375
376 pub fn allocation_error_simple(details: String) -> Self {
378 Self::AllocationError {
379 operation: "unknown".to_string(),
380 details,
381 requested_bytes: None,
382 available_bytes: None,
383 context: None,
384 }
385 }
386
387 pub fn unsupported_operation_simple(reason: String) -> Self {
389 Self::UnsupportedOperation {
390 operation: "unknown".to_string(),
391 reason,
392 alternatives: Vec::new(),
393 context: None,
394 }
395 }
396
397 pub fn invalid_shape(operation: &str, expected: &str, got: &str) -> Self {
399 Self::InvalidShape {
400 operation: operation.to_string(),
401 reason: format!("Expected {}, got {}", expected, got),
402 shape: None,
403 context: None,
404 }
405 }
406
407 pub fn invalid_shape_simple(reason: String) -> Self {
409 Self::InvalidShape {
410 operation: "unknown".to_string(),
411 reason,
412 shape: None,
413 context: None,
414 }
415 }
416
417 pub fn device_error_simple(details: String) -> Self {
419 Self::DeviceError {
420 operation: "unknown".to_string(),
421 details,
422 device: "unknown".to_string(),
423 context: None,
424 }
425 }
426
427 pub fn compute_error_simple(details: String) -> Self {
429 Self::ComputeError {
430 operation: "unknown".to_string(),
431 details,
432 retry_possible: false,
433 context: None,
434 }
435 }
436
437 pub fn serialization_error_simple(details: String) -> Self {
439 Self::SerializationError {
440 operation: "unknown".to_string(),
441 details,
442 context: None,
443 }
444 }
445
446 pub fn not_implemented_simple(details: String) -> Self {
448 Self::NotImplemented {
449 operation: "unknown".to_string(),
450 details,
451 planned_version: None,
452 context: None,
453 }
454 }
455
456 pub fn invalid_operation_simple(reason: String) -> Self {
458 Self::InvalidOperation {
459 operation: "unknown".to_string(),
460 reason,
461 context: None,
462 }
463 }
464
465 pub fn benchmark_error_simple(details: String) -> Self {
467 Self::BenchmarkError {
468 operation: "unknown".to_string(),
469 details,
470 context: None,
471 }
472 }
473
474 pub fn io_error_simple(details: String) -> Self {
476 Self::IoError {
477 operation: "unknown".to_string(),
478 details,
479 path: None,
480 context: None,
481 }
482 }
483
484 pub fn resource_exhausted_simple(resource: String) -> Self {
486 Self::ResourceExhausted {
487 operation: "unknown".to_string(),
488 resource,
489 current_usage: None,
490 limit: None,
491 context: None,
492 }
493 }
494
495 pub fn timeout_simple(duration_ms: u64) -> Self {
497 Self::Timeout {
498 operation: "unknown".to_string(),
499 duration_ms,
500 context: None,
501 }
502 }
503
504 pub fn with_context(mut self, context: ErrorContext) -> Self {
506 match &mut self {
507 Self::ShapeMismatch { context: ctx, .. } => *ctx = Some(context),
508 Self::DeviceMismatch { context: ctx, .. } => *ctx = Some(context),
509 Self::UnsupportedDevice { context: ctx, .. } => *ctx = Some(context),
510 Self::InvalidShape { context: ctx, .. } => *ctx = Some(context),
511 Self::InvalidAxis { context: ctx, .. } => *ctx = Some(context),
512 Self::GradientNotEnabled { context: ctx, .. } => *ctx = Some(context),
513 Self::InvalidArgument { context: ctx, .. } => *ctx = Some(context),
514 Self::AllocationError { context: ctx, .. } => *ctx = Some(context),
515 Self::UnsupportedOperation { context: ctx, .. } => *ctx = Some(context),
516 #[cfg(feature = "gpu")]
517 Self::GpuError { context: ctx, .. } => *ctx = Some(context),
518 Self::DeviceError { context: ctx, .. } => *ctx = Some(context),
519 Self::ComputeError { context: ctx, .. } => *ctx = Some(context),
520 #[cfg(feature = "blas")]
521 Self::BlasError { context: ctx, .. } => *ctx = Some(context),
522 Self::SerializationError { context: ctx, .. } => *ctx = Some(context),
523 Self::NotImplemented { context: ctx, .. } => *ctx = Some(context),
524 Self::InvalidOperation { context: ctx, .. } => *ctx = Some(context),
525 Self::BenchmarkError { context: ctx, .. } => *ctx = Some(context),
526 Self::IoError { context: ctx, .. } => *ctx = Some(context),
527 Self::NumericalError { context: ctx, .. } => *ctx = Some(context),
528 Self::ResourceExhausted { context: ctx, .. } => *ctx = Some(context),
529 Self::Timeout { context: ctx, .. } => *ctx = Some(context),
530 Self::CacheError { context: ctx, .. } => *ctx = Some(context),
531 Self::Other { context: ctx, .. } => *ctx = Some(context),
532 }
533 self
534 }
535
536 pub fn operation(&self) -> &str {
538 match self {
539 Self::ShapeMismatch { operation, .. } => operation,
540 Self::DeviceMismatch { operation, .. } => operation,
541 Self::UnsupportedDevice { operation, .. } => operation,
542 Self::InvalidShape { operation, .. } => operation,
543 Self::InvalidAxis { operation, .. } => operation,
544 Self::GradientNotEnabled { operation, .. } => operation,
545 Self::InvalidArgument { operation, .. } => operation,
546 Self::AllocationError { operation, .. } => operation,
547 Self::UnsupportedOperation { operation, .. } => operation,
548 #[cfg(feature = "gpu")]
549 Self::GpuError { operation, .. } => operation,
550 Self::DeviceError { operation, .. } => operation,
551 Self::ComputeError { operation, .. } => operation,
552 #[cfg(feature = "blas")]
553 Self::BlasError { operation, .. } => operation,
554 Self::SerializationError { operation, .. } => operation,
555 Self::NotImplemented { operation, .. } => operation,
556 Self::InvalidOperation { operation, .. } => operation,
557 Self::BenchmarkError { operation, .. } => operation,
558 Self::IoError { operation, .. } => operation,
559 Self::NumericalError { operation, .. } => operation,
560 Self::ResourceExhausted { operation, .. } => operation,
561 Self::Timeout { operation, .. } => operation,
562 Self::CacheError { operation, .. } => operation,
563 Self::Other { operation, .. } => operation,
564 }
565 }
566
567 pub fn supports_fallback(&self) -> bool {
569 match self {
570 Self::UnsupportedDevice {
571 fallback_available, ..
572 } => *fallback_available,
573 #[cfg(feature = "gpu")]
574 Self::GpuError { .. } => true,
575 Self::AllocationError { .. } => true,
576 Self::ComputeError { retry_possible, .. } => *retry_possible,
577 _ => false,
578 }
579 }
580
581 pub fn recovery_strategy(&self) -> RecoveryStrategy {
583 match self {
584 Self::UnsupportedDevice {
585 fallback_available: true,
586 ..
587 } => RecoveryStrategy::FallbackToCpu,
588 #[cfg(feature = "gpu")]
589 Self::GpuError {
590 fallback_attempted: false,
591 ..
592 } => RecoveryStrategy::FallbackToCpu,
593 Self::AllocationError { .. } => RecoveryStrategy::FreeMemoryAndRetry,
594 Self::ComputeError {
595 retry_possible: true,
596 ..
597 } => {
598 let mut params = std::collections::HashMap::new();
599 params.insert("reduce_precision".to_string(), "true".to_string());
600 RecoveryStrategy::RetryWithParams(params)
601 }
602 Self::NumericalError { .. } => RecoveryStrategy::ReducePrecision,
603 _ => RecoveryStrategy::None,
604 }
605 }
606}
607
/// Recovery hooks for a fallible value that yields a `T` on success.
pub trait ErrorRecovery<T> {
    /// Consumes `self` and attempts recovery using the supplied `strategy`,
    /// returning the resulting [`Result`].
    fn recover_with_strategy(self, strategy: RecoveryStrategy) -> Result<T>;

    /// Consumes `self` and attempts recovery using a strategy chosen by the
    /// implementation, returning the resulting [`Result`].
    fn auto_recover(self) -> Result<T>;
}
616
617impl<T> ErrorRecovery<T> for Result<T> {
618 fn recover_with_strategy(self, _strategy: RecoveryStrategy) -> Result<T> {
619 self
622 }
623
624 fn auto_recover(self) -> Result<T> {
625 match &self {
626 Err(error) if error.supports_fallback() => {
627 let strategy = error.recovery_strategy();
628 self.recover_with_strategy(strategy)
629 }
630 _ => self,
631 }
632 }
633}
634
/// Shorthand for `std::result::Result` with [`TensorError`] as the error type.
pub type Result<T> = std::result::Result<T, TensorError>;
636
637impl From<scirs2_core::ndarray::ShapeError> for TensorError {
639 fn from(err: scirs2_core::ndarray::ShapeError) -> Self {
640 Self::InvalidShape {
641 operation: "tensor_creation".to_string(),
642 reason: format!("Shape error: {err}"),
643 shape: None,
644 context: None,
645 }
646 }
647}
648
649impl From<std::fmt::Error> for TensorError {
651 fn from(err: std::fmt::Error) -> Self {
652 Self::Other {
653 operation: "formatting".to_string(),
654 details: format!("Formatting error: {err}"),
655 context: None,
656 }
657 }
658}