Skip to main content

ruvector_cnn/
error.rs

//! Error types for ruvector-cnn.
//!
//! This module defines [`CnnError`], the error type shared by CNN operations
//! (forward passes, configuration, and weight loading), together with the
//! [`CnnResult`] alias.

6use thiserror::Error;
7
8/// Result type for CNN operations.
9pub type CnnResult<T> = Result<T, CnnError>;
10
11/// Errors that can occur during CNN operations.
12#[derive(Error, Debug, Clone)]
13pub enum CnnError {
14    /// Invalid input data.
15    #[error("Invalid input: {0}")]
16    InvalidInput(String),
17
18    /// Invalid configuration.
19    #[error("Invalid configuration: {0}")]
20    InvalidConfig(String),
21
22    /// Model loading error.
23    #[error("Model error: {0}")]
24    ModelError(String),
25
26    /// Dimension mismatch (generic).
27    #[error("Dimension mismatch: {0}")]
28    DimensionMismatch(String),
29
30    /// SIMD operation error.
31    #[error("SIMD error: {0}")]
32    SimdError(String),
33
34    /// Quantization error.
35    #[error("Quantization error: {0}")]
36    QuantizationError(String),
37
38    /// Invalid tensor shape for the operation.
39    #[error("Invalid shape: expected {expected}, got {got}")]
40    InvalidShape {
41        /// Expected shape description
42        expected: String,
43        /// Actual shape description
44        got: String,
45    },
46
47    /// Shape mismatch between tensors.
48    #[error("Shape mismatch: {0}")]
49    ShapeMismatch(String),
50
51    /// Invalid parameter value.
52    #[error("Invalid parameter: {0}")]
53    InvalidParameter(String),
54
55    /// Memory allocation error.
56    #[error("Memory allocation failed: {0}")]
57    AllocationError(String),
58
59    /// Invalid channel count.
60    #[error("Invalid channel count: expected {expected}, got {actual}")]
61    InvalidChannels {
62        /// Expected channels
63        expected: usize,
64        /// Actual channels
65        actual: usize,
66    },
67
68    /// Invalid convolution parameters.
69    #[error("Invalid convolution parameters: {0}")]
70    InvalidConvParams(String),
71
72    /// Weight loading error.
73    #[error("Weight loading error: {0}")]
74    WeightLoadError(String),
75
76    /// Empty input provided.
77    #[error("Empty input: {0}")]
78    EmptyInput(String),
79
80    /// Numerical instability detected.
81    #[error("Numerical instability: {0}")]
82    NumericalInstability(String),
83
84    /// Unsupported backbone type.
85    #[error("Unsupported backbone: {0}")]
86    UnsupportedBackbone(String),
87
88    /// Batch processing error.
89    #[error("Batch processing error: {0}")]
90    BatchError(String),
91
92    /// Error during convolution computation.
93    #[error("Convolution error: {0}")]
94    ConvolutionError(String),
95
96    /// Error during pooling computation.
97    #[error("Pooling error: {0}")]
98    PoolingError(String),
99
100    /// Error during normalization.
101    #[error("Normalization error: {0}")]
102    NormalizationError(String),
103
104    /// Invalid kernel configuration.
105    #[error("Invalid kernel: kernel_size={kernel_size}, but input spatial dims are ({height}, {width})")]
106    InvalidKernel {
107        /// Kernel size
108        kernel_size: usize,
109        /// Input height
110        height: usize,
111        /// Input width
112        width: usize,
113    },
114
115    /// IO error (for model loading).
116    #[error("IO error: {0}")]
117    IoError(String),
118
119    /// Image processing error.
120    #[error("Image error: {0}")]
121    ImageError(String),
122
123    /// Index out of bounds.
124    #[error("Index out of bounds: {index} >= {size}")]
125    IndexOutOfBounds {
126        /// The index that was accessed
127        index: usize,
128        /// The size of the container
129        size: usize,
130    },
131
132    /// Unsupported operation.
133    #[error("Unsupported operation: {0}")]
134    Unsupported(String),
135}
136
137impl From<std::io::Error> for CnnError {
138    fn from(err: std::io::Error) -> Self {
139        CnnError::IoError(err.to_string())
140    }
141}
142
143impl CnnError {
144    /// Create a dimension mismatch error with expected and actual values.
145    pub fn dim_mismatch(expected: usize, actual: usize) -> Self {
146        Self::DimensionMismatch(format!("expected {expected}, got {actual}"))
147    }
148
149    /// Create an invalid shape error.
150    pub fn invalid_shape(expected: impl Into<String>, got: impl Into<String>) -> Self {
151        Self::InvalidShape {
152            expected: expected.into(),
153            got: got.into(),
154        }
155    }
156
157    /// Create a shape mismatch error.
158    pub fn shape_mismatch(msg: impl Into<String>) -> Self {
159        Self::ShapeMismatch(msg.into())
160    }
161
162    /// Create an invalid parameter error.
163    pub fn invalid_parameter(msg: impl Into<String>) -> Self {
164        Self::InvalidParameter(msg.into())
165    }
166
167    /// Create an invalid config error.
168    pub fn invalid_config(msg: impl Into<String>) -> Self {
169        Self::InvalidConfig(msg.into())
170    }
171
172    /// Create a convolution error.
173    pub fn convolution_error(msg: impl Into<String>) -> Self {
174        Self::ConvolutionError(msg.into())
175    }
176
177    /// Create a pooling error.
178    pub fn pooling_error(msg: impl Into<String>) -> Self {
179        Self::PoolingError(msg.into())
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    // Messages are compared exactly: `contains` checks would still pass if,
    // for example, expected/actual were swapped in a format string.

    #[test]
    fn test_error_display() {
        let err = CnnError::DimensionMismatch("expected 64, got 32".to_string());
        assert_eq!(err.to_string(), "Dimension mismatch: expected 64, got 32");

        let err = CnnError::InvalidConfig("kernel_size must be positive".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid configuration: kernel_size must be positive"
        );
    }

    #[test]
    fn test_error_clone() {
        let err = CnnError::ConvolutionError("test".to_string());
        let cloned = err.clone();
        assert_eq!(err.to_string(), cloned.to_string());
        assert!(matches!(cloned, CnnError::ConvolutionError(ref msg) if msg == "test"));
    }

    #[test]
    fn test_invalid_kernel_error() {
        let err = CnnError::InvalidKernel {
            kernel_size: 7,
            height: 3,
            width: 3,
        };
        assert_eq!(
            err.to_string(),
            "Invalid kernel: kernel_size=7, but input spatial dims are (3, 3)"
        );
    }

    #[test]
    fn test_invalid_channels_error() {
        let err = CnnError::InvalidChannels {
            expected: 3,
            actual: 1,
        };
        assert_eq!(err.to_string(), "Invalid channel count: expected 3, got 1");
    }

    #[test]
    fn test_index_out_of_bounds_error() {
        let err = CnnError::IndexOutOfBounds { index: 5, size: 3 };
        assert_eq!(err.to_string(), "Index out of bounds: 5 >= 3");
    }

    #[test]
    fn test_io_error_conversion() {
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "weights.bin missing");
        let err = CnnError::from(io);
        assert!(matches!(err, CnnError::IoError(ref msg) if msg == "weights.bin missing"));
        assert_eq!(err.to_string(), "IO error: weights.bin missing");

        // `?` must route through the same conversion.
        fn propagate() -> CnnResult<()> {
            let r: std::io::Result<()> =
                Err(std::io::Error::new(std::io::ErrorKind::Other, "disk full"));
            r?;
            Ok(())
        }
        assert!(matches!(propagate(), Err(CnnError::IoError(ref msg)) if msg == "disk full"));
    }

    #[test]
    fn test_helper_methods() {
        let err = CnnError::invalid_shape("NCHW", "NHWC");
        assert_eq!(err.to_string(), "Invalid shape: expected NCHW, got NHWC");

        let err = CnnError::invalid_config("dropout must be in [0, 1]");
        assert_eq!(
            err.to_string(),
            "Invalid configuration: dropout must be in [0, 1]"
        );

        // Pins argument order: a swapped message would still contain both numbers.
        let err = CnnError::dim_mismatch(64, 32);
        assert_eq!(err.to_string(), "Dimension mismatch: expected 64, got 32");

        assert!(matches!(
            CnnError::shape_mismatch("a vs b"),
            CnnError::ShapeMismatch(ref m) if m == "a vs b"
        ));
        assert!(matches!(
            CnnError::invalid_parameter("stride=0"),
            CnnError::InvalidParameter(ref m) if m == "stride=0"
        ));
        assert!(matches!(
            CnnError::convolution_error("oob"),
            CnnError::ConvolutionError(ref m) if m == "oob"
        ));
        assert!(matches!(
            CnnError::pooling_error("empty window"),
            CnnError::PoolingError(ref m) if m == "empty window"
        ));
    }
}