Skip to main content

tensorlogic_scirs_backend/
error.rs

1//! Comprehensive error types for tensorlogic-scirs-backend.
2//!
3//! This module provides detailed error types for all failure modes in the SciRS2 backend,
4//! including shape mismatches, invalid operations, device errors, and numerical issues.
5
6use std::fmt;
7use thiserror::Error;
8
9/// Main error type for SciRS2 backend operations
10#[derive(Error, Debug)]
11pub enum TlBackendError {
12    /// Shape mismatch between tensors or operations
13    #[error("Shape mismatch: {0}")]
14    ShapeMismatch(ShapeMismatchError),
15
16    /// Invalid einsum specification
17    #[error("Invalid einsum spec: {0}")]
18    InvalidEinsumSpec(String),
19
20    /// Tensor not found in storage
21    #[error("Tensor not found: {0}")]
22    TensorNotFound(String),
23
24    /// Invalid operation or operation parameters
25    #[error("Invalid operation: {0}")]
26    InvalidOperation(String),
27
28    /// Device-related errors (GPU unavailable, memory, etc.)
29    #[error("Device error: {0}")]
30    DeviceError(DeviceError),
31
32    /// Out of memory errors
33    #[error("Out of memory: {0}")]
34    OutOfMemory(String),
35
36    /// Numerical stability issues (NaN, Inf, overflow)
37    #[error("Numerical error: {0}")]
38    NumericalError(NumericalError),
39
40    /// Gradient computation errors
41    #[error("Gradient error: {0}")]
42    GradientError(String),
43
44    /// Graph structure errors (cycles, missing nodes, etc.)
45    #[error("Graph error: {0}")]
46    GraphError(String),
47
48    /// Execution errors (runtime failures)
49    #[error("Execution error: {0}")]
50    ExecutionError(String),
51
52    /// Unsupported feature or operation
53    #[error("Unsupported: {0}")]
54    Unsupported(String),
55
56    /// Internal errors (should not happen)
57    #[error("Internal error: {0}")]
58    Internal(String),
59}
60
/// Detailed shape mismatch error with context.
///
/// `expected` and `actual` each hold one or more shapes, where a shape is a
/// `Vec<usize>` of dimension sizes. Equality is derived so errors can be compared
/// directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapeMismatchError {
    /// Description of the operation that failed.
    pub operation: String,
    /// Expected shape(s).
    pub expected: Vec<Vec<usize>>,
    /// Actual shape(s) that were provided.
    pub actual: Vec<Vec<usize>>,
    /// Additional free-form context, if any was attached.
    pub context: Option<String>,
}
73
74impl fmt::Display for ShapeMismatchError {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        write!(
77            f,
78            "Shape mismatch in {}: expected {:?}, got {:?}",
79            self.operation, self.expected, self.actual
80        )?;
81        if let Some(ctx) = &self.context {
82            write!(f, " ({})", ctx)?;
83        }
84        Ok(())
85    }
86}
87
88impl ShapeMismatchError {
89    /// Create a new shape mismatch error
90    pub fn new(
91        operation: impl Into<String>,
92        expected: Vec<Vec<usize>>,
93        actual: Vec<Vec<usize>>,
94    ) -> Self {
95        Self {
96            operation: operation.into(),
97            expected,
98            actual,
99            context: None,
100        }
101    }
102
103    /// Add context to the error
104    pub fn with_context(mut self, context: impl Into<String>) -> Self {
105        self.context = Some(context.into());
106        self
107    }
108}
109
110/// Device-related errors
111#[derive(Error, Debug, Clone)]
112pub enum DeviceError {
113    /// GPU is not available
114    #[error("GPU not available: {0}")]
115    GpuUnavailable(String),
116
117    /// Device memory allocation failed
118    #[error("Device memory allocation failed: {0}")]
119    AllocationFailed(String),
120
121    /// Device synchronization failed
122    #[error("Device synchronization failed: {0}")]
123    SyncFailed(String),
124
125    /// Unsupported device type
126    #[error("Unsupported device: {0}")]
127    UnsupportedDevice(String),
128}
129
130/// Numerical stability and correctness errors
131#[derive(Debug, Clone)]
132pub struct NumericalError {
133    /// Type of numerical issue
134    pub kind: NumericalErrorKind,
135    /// Location where the error occurred
136    pub location: String,
137    /// Values that caused the error (if available)
138    pub values: Option<Vec<f64>>,
139}
140
/// Categories of numerical failure reported by [`NumericalError`].
///
/// The kind is `Copy` and `Hash`, so it can be passed by value and used as a map or
/// set key (e.g. to tally how often each kind of issue occurs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NumericalErrorKind {
    /// Not-a-Number detected.
    NaN,
    /// Infinity detected.
    Infinity,
    /// Overflow in computation.
    Overflow,
    /// Underflow in computation.
    Underflow,
    /// Division by zero.
    DivisionByZero,
    /// Loss of precision.
    PrecisionLoss,
}
157
158impl fmt::Display for NumericalError {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        write!(f, "{:?} detected in {}", self.kind, self.location)?;
161        if let Some(vals) = &self.values {
162            write!(f, " (values: {:?})", vals)?;
163        }
164        Ok(())
165    }
166}
167
168impl NumericalError {
169    /// Create a new numerical error
170    pub fn new(kind: NumericalErrorKind, location: impl Into<String>) -> Self {
171        Self {
172            kind,
173            location: location.into(),
174            values: None,
175        }
176    }
177
178    /// Add values that caused the error
179    pub fn with_values(mut self, values: Vec<f64>) -> Self {
180        self.values = Some(values);
181        self
182    }
183}
184
185/// Result type using TlBackendError
186pub type TlBackendResult<T> = Result<T, TlBackendError>;
187
188/// Helper functions for creating common errors
189impl TlBackendError {
190    /// Create a shape mismatch error
191    pub fn shape_mismatch(
192        operation: impl Into<String>,
193        expected: Vec<Vec<usize>>,
194        actual: Vec<Vec<usize>>,
195    ) -> Self {
196        TlBackendError::ShapeMismatch(ShapeMismatchError::new(operation, expected, actual))
197    }
198
199    /// Create an invalid einsum spec error
200    pub fn invalid_einsum(spec: impl Into<String>) -> Self {
201        TlBackendError::InvalidEinsumSpec(spec.into())
202    }
203
204    /// Create a tensor not found error
205    pub fn tensor_not_found(name: impl Into<String>) -> Self {
206        TlBackendError::TensorNotFound(name.into())
207    }
208
209    /// Create an invalid operation error
210    pub fn invalid_operation(msg: impl Into<String>) -> Self {
211        TlBackendError::InvalidOperation(msg.into())
212    }
213
214    /// Create a numerical error
215    pub fn numerical(kind: NumericalErrorKind, location: impl Into<String>) -> Self {
216        TlBackendError::NumericalError(NumericalError::new(kind, location))
217    }
218
219    /// Create a GPU unavailable error
220    pub fn gpu_unavailable(msg: impl Into<String>) -> Self {
221        TlBackendError::DeviceError(DeviceError::GpuUnavailable(msg.into()))
222    }
223
224    /// Create an unsupported feature error
225    pub fn unsupported(msg: impl Into<String>) -> Self {
226        TlBackendError::Unsupported(msg.into())
227    }
228
229    /// Create an execution error
230    pub fn execution(msg: impl Into<String>) -> Self {
231        TlBackendError::ExecutionError(msg.into())
232    }
233
234    /// Create a gradient error
235    pub fn gradient(msg: impl Into<String>) -> Self {
236        TlBackendError::GradientError(msg.into())
237    }
238}
239
240/// Check if a value is numerically valid (not NaN or Inf)
241pub fn validate_numeric_value(value: f64, location: &str) -> TlBackendResult<()> {
242    if value.is_nan() {
243        Err(TlBackendError::numerical(NumericalErrorKind::NaN, location))
244    } else if value.is_infinite() {
245        Err(TlBackendError::numerical(
246            NumericalErrorKind::Infinity,
247            location,
248        ))
249    } else {
250        Ok(())
251    }
252}
253
254/// Check if all values in a slice are numerically valid
255pub fn validate_numeric_values(values: &[f64], location: &str) -> TlBackendResult<()> {
256    for &value in values.iter() {
257        if value.is_nan() {
258            return Err(TlBackendError::NumericalError(
259                NumericalError::new(NumericalErrorKind::NaN, location).with_values(vec![value]),
260            ));
261        }
262        if value.is_infinite() {
263            return Err(TlBackendError::NumericalError(
264                NumericalError::new(NumericalErrorKind::Infinity, location)
265                    .with_values(vec![value]),
266            ));
267        }
268    }
269    Ok(())
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_mismatch_error() {
        let err = TlBackendError::shape_mismatch(
            "matmul",
            vec![vec![2, 3], vec![3, 4]],
            vec![vec![2, 3], vec![2, 4]],
        );
        assert!(matches!(err, TlBackendError::ShapeMismatch(_)));
        assert!(err.to_string().contains("matmul"));
    }

    #[test]
    fn test_numerical_error() {
        let err = TlBackendError::numerical(NumericalErrorKind::NaN, "relu operation");
        assert!(matches!(err, TlBackendError::NumericalError(_)));
        assert!(err.to_string().contains("NaN"));
    }

    #[test]
    fn test_validate_numeric_value() {
        // Valid values
        assert!(validate_numeric_value(0.0, "test").is_ok());
        assert!(validate_numeric_value(1.5, "test").is_ok());
        assert!(validate_numeric_value(-10.0, "test").is_ok());

        // Invalid values
        assert!(validate_numeric_value(f64::NAN, "test").is_err());
        assert!(validate_numeric_value(f64::INFINITY, "test").is_err());
        assert!(validate_numeric_value(f64::NEG_INFINITY, "test").is_err());
    }

    #[test]
    fn test_validate_numeric_value_reports_kind_and_location() {
        match validate_numeric_value(f64::NAN, "softmax") {
            Err(TlBackendError::NumericalError(e)) => {
                assert_eq!(e.kind, NumericalErrorKind::NaN);
                assert_eq!(e.location, "softmax");
            }
            other => panic!("expected NumericalError, got {:?}", other),
        }
        match validate_numeric_value(f64::NEG_INFINITY, "log") {
            Err(TlBackendError::NumericalError(e)) => {
                assert_eq!(e.kind, NumericalErrorKind::Infinity);
                assert_eq!(e.location, "log");
            }
            other => panic!("expected NumericalError, got {:?}", other),
        }
    }

    #[test]
    fn test_validate_numeric_values() {
        // Valid values
        let valid = [0.0, 1.0, -1.0, 100.0];
        assert!(validate_numeric_values(&valid, "test").is_ok());

        // Invalid values
        let invalid_nan = [0.0, f64::NAN, 1.0];
        assert!(validate_numeric_values(&invalid_nan, "test").is_err());

        let invalid_inf = [0.0, 1.0, f64::INFINITY];
        assert!(validate_numeric_values(&invalid_inf, "test").is_err());
    }

    #[test]
    fn test_validate_numeric_values_edge_cases() {
        // An empty slice has nothing to reject.
        assert!(validate_numeric_values(&[], "empty").is_ok());

        // The first invalid value wins: NaN precedes +inf here.
        match validate_numeric_values(&[1.0, f64::NAN, f64::INFINITY], "loc") {
            Err(TlBackendError::NumericalError(e)) => {
                assert_eq!(e.kind, NumericalErrorKind::NaN);
                assert_eq!(e.location, "loc");
                let vals = e.values.expect("offending value should be recorded");
                assert_eq!(vals.len(), 1);
                assert!(vals[0].is_nan());
            }
            other => panic!("expected NumericalError, got {:?}", other),
        }

        // Negative infinity is rejected as Infinity and recorded as-is.
        match validate_numeric_values(&[f64::NEG_INFINITY], "loc") {
            Err(TlBackendError::NumericalError(e)) => {
                assert_eq!(e.kind, NumericalErrorKind::Infinity);
                assert_eq!(e.values, Some(vec![f64::NEG_INFINITY]));
            }
            other => panic!("expected NumericalError, got {:?}", other),
        }
    }

    #[test]
    fn test_error_display() {
        let err = TlBackendError::invalid_einsum("abc,def->xyz");
        assert_eq!(err.to_string(), "Invalid einsum spec: abc,def->xyz");

        let err = TlBackendError::tensor_not_found("tensor_x");
        assert_eq!(err.to_string(), "Tensor not found: tensor_x");
    }

    #[test]
    fn test_device_error() {
        let err = TlBackendError::gpu_unavailable("CUDA not installed");
        assert!(matches!(err, TlBackendError::DeviceError(_)));
        assert!(err.to_string().contains("GPU not available"));
    }

    #[test]
    fn test_shape_mismatch_with_context() {
        let err = ShapeMismatchError::new("einsum", vec![vec![2, 3]], vec![vec![3, 4]])
            .with_context("input tensor 'x'");
        let err_str = err.to_string();
        assert!(err_str.contains("einsum"));
        assert!(err_str.contains("input tensor 'x'"));
    }
}