trustformers_core/errors/
standardization.rs1use super::{ErrorKind, TrustformersError};
7
8#[allow(deprecated)]
9use crate::error::CoreError;
10
11pub trait StandardError {
13 fn standardize(self) -> TrustformersError;
15
16 fn standardize_with_operation(self, operation: &str) -> TrustformersError;
18
19 fn standardize_with_component(self, component: &str) -> TrustformersError;
21
22 fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError;
24}
25
26#[allow(deprecated)]
27impl StandardError for CoreError {
28 fn standardize(self) -> TrustformersError {
29 self.into()
30 }
31
32 fn standardize_with_operation(self, operation: &str) -> TrustformersError {
33 TrustformersError::from(self).with_operation(operation)
34 }
35
36 fn standardize_with_component(self, component: &str) -> TrustformersError {
37 TrustformersError::from(self).with_component(component)
38 }
39
40 fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
41 TrustformersError::from(self)
42 .with_operation(operation)
43 .with_component(component)
44 }
45}
46
47impl StandardError for String {
48 fn standardize(self) -> TrustformersError {
49 TrustformersError::new(ErrorKind::Other(self))
50 }
51
52 fn standardize_with_operation(self, operation: &str) -> TrustformersError {
53 TrustformersError::new(ErrorKind::Other(self)).with_operation(operation)
54 }
55
56 fn standardize_with_component(self, component: &str) -> TrustformersError {
57 TrustformersError::new(ErrorKind::Other(self)).with_component(component)
58 }
59
60 fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
61 TrustformersError::new(ErrorKind::Other(self))
62 .with_operation(operation)
63 .with_component(component)
64 }
65}
66
67impl StandardError for &str {
68 fn standardize(self) -> TrustformersError {
69 TrustformersError::new(ErrorKind::Other(self.to_string()))
70 }
71
72 fn standardize_with_operation(self, operation: &str) -> TrustformersError {
73 TrustformersError::new(ErrorKind::Other(self.to_string())).with_operation(operation)
74 }
75
76 fn standardize_with_component(self, component: &str) -> TrustformersError {
77 TrustformersError::new(ErrorKind::Other(self.to_string())).with_component(component)
78 }
79
80 fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
81 TrustformersError::new(ErrorKind::Other(self.to_string()))
82 .with_operation(operation)
83 .with_component(component)
84 }
85}
86
87impl StandardError for std::io::Error {
88 fn standardize(self) -> TrustformersError {
89 TrustformersError::new(ErrorKind::IoError(self))
90 }
91
92 fn standardize_with_operation(self, operation: &str) -> TrustformersError {
93 TrustformersError::new(ErrorKind::IoError(self)).with_operation(operation)
94 }
95
96 fn standardize_with_component(self, component: &str) -> TrustformersError {
97 TrustformersError::new(ErrorKind::IoError(self)).with_component(component)
98 }
99
100 fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
101 TrustformersError::new(ErrorKind::IoError(self))
102 .with_operation(operation)
103 .with_component(component)
104 }
105}
106
/// Converts any value implementing `StandardError` into a `TrustformersError`,
/// optionally attaching operation and/or component context.
///
/// Accepted forms (each also accepts a trailing comma):
///
/// - `std_error!(err)`
/// - `std_error!(err, operation = "op")`
/// - `std_error!(err, component = "Comp")`
/// - `std_error!(err, operation = "op", component = "Comp")`
/// - `std_error!(err, component = "Comp", operation = "op")`
///
/// Both keyword orders expand to the same `standardize_with_context` call, so
/// the operation is always attached before the component.
#[macro_export]
macro_rules! std_error {
    ($err:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize($err)
    };

    ($err:expr, operation = $op:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_operation($err, $op)
    };

    ($err:expr, component = $comp:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_component($err, $comp)
    };

    ($err:expr, operation = $op:expr, component = $comp:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_context($err, $op, $comp)
    };

    // Reversed keyword order: accepted for convenience, same expansion as above.
    ($err:expr, component = $comp:expr, operation = $op:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_context($err, $op, $comp)
    };
}
126
127pub struct ErrorMigrationHelper;
129
130impl ErrorMigrationHelper {
131 pub fn shape_error(
133 expected: Vec<usize>,
134 actual: Vec<usize>,
135 operation: &str,
136 ) -> TrustformersError {
137 TrustformersError::new(ErrorKind::ShapeMismatch { expected, actual })
138 .with_operation(operation)
139 .with_suggestion("Check tensor dimensions before operations")
140 .with_suggestion("Use .reshape() or broadcasting to fix dimension mismatches")
141 }
142
143 pub fn tensor_operation_error(
145 operation: &str,
146 reason: &str,
147 component: &str,
148 ) -> TrustformersError {
149 TrustformersError::new(ErrorKind::TensorOpError {
150 operation: operation.to_string(),
151 reason: reason.to_string(),
152 })
153 .with_component(component)
154 .with_suggestion("Check tensor compatibility and data types")
155 .with_suggestion("Enable tensor debugging for more information")
156 }
157
158 pub fn memory_allocation_error(reason: &str, operation: &str) -> TrustformersError {
160 TrustformersError::new(ErrorKind::MemoryError {
161 reason: reason.to_string(),
162 })
163 .with_operation(operation)
164 .with_suggestion("Try reducing batch size or model complexity")
165 .with_suggestion("Enable memory optimization settings")
166 }
167
168 pub fn hardware_unavailable_error(
170 device: &str,
171 reason: &str,
172 component: &str,
173 ) -> TrustformersError {
174 TrustformersError::new(ErrorKind::HardwareError {
175 device: device.to_string(),
176 reason: reason.to_string(),
177 })
178 .with_component(component)
179 .with_suggestion("Check device drivers and installation")
180 .with_suggestion("Try falling back to CPU execution")
181 }
182
183 pub fn invalid_configuration_error(
185 field: &str,
186 reason: &str,
187 component: &str,
188 ) -> TrustformersError {
189 TrustformersError::new(ErrorKind::InvalidConfiguration {
190 field: field.to_string(),
191 reason: reason.to_string(),
192 })
193 .with_component(component)
194 .with_suggestion("Check configuration file syntax and values")
195 .with_suggestion("Refer to documentation for valid parameter ranges")
196 }
197}
198
199pub trait ResultStandardization<T> {
201 fn standardize_err(self) -> Result<T, TrustformersError>;
203
204 fn standardize_err_with_operation(self, operation: &str) -> Result<T, TrustformersError>;
206
207 fn standardize_err_with_component(self, component: &str) -> Result<T, TrustformersError>;
209
210 fn standardize_err_with_context(
212 self,
213 operation: &str,
214 component: &str,
215 ) -> Result<T, TrustformersError>;
216}
217
218impl<T, E> ResultStandardization<T> for Result<T, E>
219where
220 E: StandardError,
221{
222 fn standardize_err(self) -> Result<T, TrustformersError> {
223 self.map_err(|e| e.standardize())
224 }
225
226 fn standardize_err_with_operation(self, operation: &str) -> Result<T, TrustformersError> {
227 self.map_err(|e| e.standardize_with_operation(operation))
228 }
229
230 fn standardize_err_with_component(self, component: &str) -> Result<T, TrustformersError> {
231 self.map_err(|e| e.standardize_with_component(component))
232 }
233
234 fn standardize_err_with_context(
235 self,
236 operation: &str,
237 component: &str,
238 ) -> Result<T, TrustformersError> {
239 self.map_err(|e| e.standardize_with_context(operation, component))
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)] // exercises the deprecated `CoreError` bridge on purpose
    fn test_core_error_standardization() {
        let core_err = CoreError::InvalidInput("test".to_string());
        let std_err = core_err.standardize_with_operation("test_operation");

        assert_eq!(
            std_err.context.operation,
            Some("test_operation".to_string())
        );
    }

    #[test]
    fn test_string_error_standardization() {
        let str_err = "Something went wrong";
        let std_err = str_err.standardize_with_component("TestComponent");

        assert_eq!(std_err.context.component, Some("TestComponent".to_string()));
    }

    /// The owned-`String` impl must keep the message verbatim in `ErrorKind::Other`.
    #[test]
    fn test_owned_string_keeps_message() {
        let err = String::from("boom").standardize();

        match &err.kind {
            ErrorKind::Other(message) => assert_eq!(message, "boom"),
            _ => panic!("Wrong error kind"),
        }
    }

    /// `std::io::Error` must map to `ErrorKind::IoError` and carry both context fields.
    #[test]
    fn test_io_error_standardization() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let err = io_err.standardize_with_context("load_weights", "Loader");

        assert!(matches!(err.kind, ErrorKind::IoError(_)));
        assert_eq!(err.context.operation, Some("load_weights".to_string()));
        assert_eq!(err.context.component, Some("Loader".to_string()));
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated `CoreError` bridge on purpose
    fn test_result_standardization() {
        fn failing_function() -> Result<(), CoreError> {
            Err(CoreError::InvalidArgument("test".to_string()))
        }

        let result = failing_function().standardize_err_with_context("test_op", "test_component");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.context.operation, Some("test_op".to_string()));
        assert_eq!(err.context.component, Some("test_component".to_string()));
    }

    /// `Ok` values must pass through `ResultStandardization` untouched.
    #[test]
    fn test_result_standardization_passes_ok_through() {
        let ok: Result<u32, String> = Ok(7);

        assert!(matches!(ok.standardize_err_with_operation("noop"), Ok(7)));
    }

    #[test]
    fn test_migration_helper() {
        let err =
            ErrorMigrationHelper::shape_error(vec![2, 3, 4], vec![2, 3, 5], "matrix_multiply");

        match &err.kind {
            ErrorKind::ShapeMismatch { expected, actual } => {
                assert_eq!(expected, &vec![2, 3, 4]);
                assert_eq!(actual, &vec![2, 3, 5]);
            },
            _ => panic!("Wrong error kind"),
        }

        assert_eq!(err.context.operation, Some("matrix_multiply".to_string()));
        assert!(err.suggestions.len() >= 2);
    }

    #[test]
    fn test_migration_helper_memory_error() {
        let err =
            ErrorMigrationHelper::memory_allocation_error("out of memory", "allocate_kv_cache");

        match &err.kind {
            ErrorKind::MemoryError { reason } => assert_eq!(reason, "out of memory"),
            _ => panic!("Wrong error kind"),
        }

        assert_eq!(err.context.operation, Some("allocate_kv_cache".to_string()));
        assert!(err.suggestions.len() >= 2);
    }

    #[test]
    fn test_migration_helper_component_tagged_errors() {
        let hardware = ErrorMigrationHelper::hardware_unavailable_error("cuda:0", "no driver", "Gpu");
        assert!(matches!(hardware.kind, ErrorKind::HardwareError { .. }));
        assert_eq!(hardware.context.component, Some("Gpu".to_string()));

        let config =
            ErrorMigrationHelper::invalid_configuration_error("hidden_size", "must be > 0", "Cfg");
        assert!(matches!(config.kind, ErrorKind::InvalidConfiguration { .. }));
        assert_eq!(config.context.component, Some("Cfg".to_string()));

        let tensor = ErrorMigrationHelper::tensor_operation_error("matmul", "dtype mismatch", "Ops");
        assert!(matches!(tensor.kind, ErrorKind::TensorOpError { .. }));
        assert_eq!(tensor.context.component, Some("Ops".to_string()));
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated `CoreError` bridge on purpose
    fn test_std_error_macro() {
        let core_err = CoreError::TensorOpError {
            message: "test".to_string(),
            context: crate::error::ErrorContext::new(
                crate::error::ErrorCode::E2002,
                "test_operation".to_string(),
            ),
        };

        let err1 = std_error!(core_err);
        assert!(matches!(err1.kind, ErrorKind::ComputeError { .. }));

        let str_err = "test error";
        let err2 = std_error!(str_err, operation = "test_op");
        assert_eq!(err2.context.operation, Some("test_op".to_string()));
    }

    /// Covers the `component =` and combined `operation =, component =` macro arms.
    #[test]
    fn test_std_error_macro_component_forms() {
        let with_component = std_error!("oops", component = "Tokenizer");
        assert_eq!(
            with_component.context.component,
            Some("Tokenizer".to_string())
        );

        let with_both = std_error!("oops", operation = "encode", component = "Tokenizer");
        assert_eq!(with_both.context.operation, Some("encode".to_string()));
        assert_eq!(with_both.context.component, Some("Tokenizer".to_string()));
    }
}