1use thiserror::Error;
7
8pub type CnnResult<T> = Result<T, CnnError>;
10
11#[derive(Error, Debug, Clone)]
13pub enum CnnError {
14 #[error("Invalid input: {0}")]
16 InvalidInput(String),
17
18 #[error("Invalid configuration: {0}")]
20 InvalidConfig(String),
21
22 #[error("Model error: {0}")]
24 ModelError(String),
25
26 #[error("Dimension mismatch: {0}")]
28 DimensionMismatch(String),
29
30 #[error("SIMD error: {0}")]
32 SimdError(String),
33
34 #[error("Quantization error: {0}")]
36 QuantizationError(String),
37
38 #[error("Invalid shape: expected {expected}, got {got}")]
40 InvalidShape {
41 expected: String,
43 got: String,
45 },
46
47 #[error("Shape mismatch: {0}")]
49 ShapeMismatch(String),
50
51 #[error("Invalid parameter: {0}")]
53 InvalidParameter(String),
54
55 #[error("Memory allocation failed: {0}")]
57 AllocationError(String),
58
59 #[error("Invalid channel count: expected {expected}, got {actual}")]
61 InvalidChannels {
62 expected: usize,
64 actual: usize,
66 },
67
68 #[error("Invalid convolution parameters: {0}")]
70 InvalidConvParams(String),
71
72 #[error("Weight loading error: {0}")]
74 WeightLoadError(String),
75
76 #[error("Empty input: {0}")]
78 EmptyInput(String),
79
80 #[error("Numerical instability: {0}")]
82 NumericalInstability(String),
83
84 #[error("Unsupported backbone: {0}")]
86 UnsupportedBackbone(String),
87
88 #[error("Batch processing error: {0}")]
90 BatchError(String),
91
92 #[error("Convolution error: {0}")]
94 ConvolutionError(String),
95
96 #[error("Pooling error: {0}")]
98 PoolingError(String),
99
100 #[error("Normalization error: {0}")]
102 NormalizationError(String),
103
104 #[error("Invalid kernel: kernel_size={kernel_size}, but input spatial dims are ({height}, {width})")]
106 InvalidKernel {
107 kernel_size: usize,
109 height: usize,
111 width: usize,
113 },
114
115 #[error("IO error: {0}")]
117 IoError(String),
118
119 #[error("Image error: {0}")]
121 ImageError(String),
122
123 #[error("Index out of bounds: {index} >= {size}")]
125 IndexOutOfBounds {
126 index: usize,
128 size: usize,
130 },
131
132 #[error("Unsupported operation: {0}")]
134 Unsupported(String),
135}
136
137impl From<std::io::Error> for CnnError {
138 fn from(err: std::io::Error) -> Self {
139 CnnError::IoError(err.to_string())
140 }
141}
142
143impl CnnError {
144 pub fn dim_mismatch(expected: usize, actual: usize) -> Self {
146 Self::DimensionMismatch(format!("expected {expected}, got {actual}"))
147 }
148
149 pub fn invalid_shape(expected: impl Into<String>, got: impl Into<String>) -> Self {
151 Self::InvalidShape {
152 expected: expected.into(),
153 got: got.into(),
154 }
155 }
156
157 pub fn shape_mismatch(msg: impl Into<String>) -> Self {
159 Self::ShapeMismatch(msg.into())
160 }
161
162 pub fn invalid_parameter(msg: impl Into<String>) -> Self {
164 Self::InvalidParameter(msg.into())
165 }
166
167 pub fn invalid_config(msg: impl Into<String>) -> Self {
169 Self::InvalidConfig(msg.into())
170 }
171
172 pub fn convolution_error(msg: impl Into<String>) -> Self {
174 Self::ConvolutionError(msg.into())
175 }
176
177 pub fn pooling_error(msg: impl Into<String>) -> Self {
179 Self::PoolingError(msg.into())
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    // Errors are compared through `to_string()` so these tests do not depend on
    // which traits `CnnError` derives beyond `Display` and `Clone`.

    #[test]
    fn test_error_display() {
        let err = CnnError::DimensionMismatch("expected 64, got 32".to_string());
        assert_eq!(err.to_string(), "Dimension mismatch: expected 64, got 32");

        let err = CnnError::InvalidConfig("kernel_size must be positive".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid configuration: kernel_size must be positive"
        );
    }

    #[test]
    fn test_error_clone() {
        let err = CnnError::ConvolutionError("test".to_string());
        let cloned = err.clone();
        assert_eq!(err.to_string(), cloned.to_string());
    }

    #[test]
    fn test_invalid_kernel_error() {
        let err = CnnError::InvalidKernel {
            kernel_size: 7,
            height: 3,
            width: 3,
        };
        assert_eq!(
            err.to_string(),
            "Invalid kernel: kernel_size=7, but input spatial dims are (3, 3)"
        );
    }

    #[test]
    fn test_invalid_channels_error() {
        let err = CnnError::InvalidChannels {
            expected: 3,
            actual: 1,
        };
        assert_eq!(err.to_string(), "Invalid channel count: expected 3, got 1");
    }

    #[test]
    fn test_index_out_of_bounds_error() {
        let err = CnnError::IndexOutOfBounds { index: 5, size: 3 };
        assert_eq!(err.to_string(), "Index out of bounds: 5 >= 3");
    }

    #[test]
    fn test_helper_methods() {
        assert_eq!(
            CnnError::invalid_shape("NCHW", "NHWC").to_string(),
            "Invalid shape: expected NCHW, got NHWC"
        );
        assert_eq!(
            CnnError::invalid_config("dropout must be in [0, 1]").to_string(),
            "Invalid configuration: dropout must be in [0, 1]"
        );
        assert_eq!(
            CnnError::dim_mismatch(64, 32).to_string(),
            "Dimension mismatch: expected 64, got 32"
        );
        assert_eq!(
            CnnError::shape_mismatch("[1, 3] vs [1, 4]").to_string(),
            "Shape mismatch: [1, 3] vs [1, 4]"
        );
        assert_eq!(
            CnnError::invalid_parameter("stride must be > 0").to_string(),
            "Invalid parameter: stride must be > 0"
        );
        assert_eq!(
            CnnError::convolution_error("padding too large").to_string(),
            "Convolution error: padding too large"
        );
        assert_eq!(
            CnnError::pooling_error("window exceeds input").to_string(),
            "Pooling error: window exceeds input"
        );
    }

    #[test]
    fn test_io_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing weights");
        let err: CnnError = io_err.into();
        assert!(matches!(err, CnnError::IoError(_)));
        assert_eq!(err.to_string(), "IO error: missing weights");
    }

    #[test]
    fn test_io_error_propagates_with_question_mark() {
        fn load() -> CnnResult<()> {
            let io_result: std::io::Result<()> = Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "missing weights",
            ));
            io_result?;
            Ok(())
        }

        let err = load().unwrap_err();
        assert_eq!(err.to_string(), "IO error: missing weights");
    }
}