// trustformers_core/errors/standardization.rs

//! Error standardization utilities for migrating modules to `TrustformersError`.
//!
//! This module provides utilities to help modules migrate from the legacy `CoreError`
//! system to the new `TrustformersError` system, which carries rich context and suggestions.
5
6use super::{ErrorKind, TrustformersError};
7
8#[allow(deprecated)]
9use crate::error::CoreError;
10
11/// Standard error interface that all modules should use
12pub trait StandardError {
13    /// Convert any error to a standardized TrustformersError
14    fn standardize(self) -> TrustformersError;
15
16    /// Convert with operation context
17    fn standardize_with_operation(self, operation: &str) -> TrustformersError;
18
19    /// Convert with component context
20    fn standardize_with_component(self, component: &str) -> TrustformersError;
21
22    /// Convert with full context
23    fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError;
24}
25
26#[allow(deprecated)]
27impl StandardError for CoreError {
28    fn standardize(self) -> TrustformersError {
29        self.into()
30    }
31
32    fn standardize_with_operation(self, operation: &str) -> TrustformersError {
33        TrustformersError::from(self).with_operation(operation)
34    }
35
36    fn standardize_with_component(self, component: &str) -> TrustformersError {
37        TrustformersError::from(self).with_component(component)
38    }
39
40    fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
41        TrustformersError::from(self)
42            .with_operation(operation)
43            .with_component(component)
44    }
45}
46
47impl StandardError for String {
48    fn standardize(self) -> TrustformersError {
49        TrustformersError::new(ErrorKind::Other(self))
50    }
51
52    fn standardize_with_operation(self, operation: &str) -> TrustformersError {
53        TrustformersError::new(ErrorKind::Other(self)).with_operation(operation)
54    }
55
56    fn standardize_with_component(self, component: &str) -> TrustformersError {
57        TrustformersError::new(ErrorKind::Other(self)).with_component(component)
58    }
59
60    fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
61        TrustformersError::new(ErrorKind::Other(self))
62            .with_operation(operation)
63            .with_component(component)
64    }
65}
66
67impl StandardError for &str {
68    fn standardize(self) -> TrustformersError {
69        TrustformersError::new(ErrorKind::Other(self.to_string()))
70    }
71
72    fn standardize_with_operation(self, operation: &str) -> TrustformersError {
73        TrustformersError::new(ErrorKind::Other(self.to_string())).with_operation(operation)
74    }
75
76    fn standardize_with_component(self, component: &str) -> TrustformersError {
77        TrustformersError::new(ErrorKind::Other(self.to_string())).with_component(component)
78    }
79
80    fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
81        TrustformersError::new(ErrorKind::Other(self.to_string()))
82            .with_operation(operation)
83            .with_component(component)
84    }
85}
86
87impl StandardError for std::io::Error {
88    fn standardize(self) -> TrustformersError {
89        TrustformersError::new(ErrorKind::IoError(self))
90    }
91
92    fn standardize_with_operation(self, operation: &str) -> TrustformersError {
93        TrustformersError::new(ErrorKind::IoError(self)).with_operation(operation)
94    }
95
96    fn standardize_with_component(self, component: &str) -> TrustformersError {
97        TrustformersError::new(ErrorKind::IoError(self)).with_component(component)
98    }
99
100    fn standardize_with_context(self, operation: &str, component: &str) -> TrustformersError {
101        TrustformersError::new(ErrorKind::IoError(self))
102            .with_operation(operation)
103            .with_component(component)
104    }
105}
106
/// Standardize an error, optionally attaching operation and/or component context.
///
/// Expands to a call to the matching `StandardError::standardize*` method in
/// `errors::standardization`, so it accepts any type implementing that trait.
/// Context is never inferred; it is attached only when passed explicitly.
///
/// Supported forms (a trailing comma is accepted in every form):
///
/// ```ignore
/// std_error!(err);
/// std_error!(err, operation = "matmul");
/// std_error!(err, component = "Linear");
/// std_error!(err, operation = "matmul", component = "Linear");
/// std_error!(err, component = "Linear", operation = "matmul");
/// ```
///
/// When both keys are given, either order yields the same error. The arguments
/// are always forwarded as `(operation, component)`, so the operation expression
/// is evaluated before the component expression.
#[macro_export]
macro_rules! std_error {
    ($err:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize($err)
    };

    ($err:expr, operation = $op:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_operation($err, $op)
    };

    ($err:expr, component = $comp:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_component($err, $comp)
    };

    ($err:expr, operation = $op:expr, component = $comp:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_context($err, $op, $comp)
    };

    // Same as the arm above with the keys swapped; normalized to (operation, component).
    ($err:expr, component = $comp:expr, operation = $op:expr $(,)?) => {
        $crate::errors::standardization::StandardError::standardize_with_context($err, $op, $comp)
    };
}
126
127/// Migration utilities for common error patterns
128pub struct ErrorMigrationHelper;
129
130impl ErrorMigrationHelper {
131    /// Convert legacy shape error pattern to new system
132    pub fn shape_error(
133        expected: Vec<usize>,
134        actual: Vec<usize>,
135        operation: &str,
136    ) -> TrustformersError {
137        TrustformersError::new(ErrorKind::ShapeMismatch { expected, actual })
138            .with_operation(operation)
139            .with_suggestion("Check tensor dimensions before operations")
140            .with_suggestion("Use .reshape() or broadcasting to fix dimension mismatches")
141    }
142
143    /// Convert legacy tensor operation error to new system
144    pub fn tensor_operation_error(
145        operation: &str,
146        reason: &str,
147        component: &str,
148    ) -> TrustformersError {
149        TrustformersError::new(ErrorKind::TensorOpError {
150            operation: operation.to_string(),
151            reason: reason.to_string(),
152        })
153        .with_component(component)
154        .with_suggestion("Check tensor compatibility and data types")
155        .with_suggestion("Enable tensor debugging for more information")
156    }
157
158    /// Convert legacy memory error to new system
159    pub fn memory_allocation_error(reason: &str, operation: &str) -> TrustformersError {
160        TrustformersError::new(ErrorKind::MemoryError {
161            reason: reason.to_string(),
162        })
163        .with_operation(operation)
164        .with_suggestion("Try reducing batch size or model complexity")
165        .with_suggestion("Enable memory optimization settings")
166    }
167
168    /// Convert legacy hardware error to new system
169    pub fn hardware_unavailable_error(
170        device: &str,
171        reason: &str,
172        component: &str,
173    ) -> TrustformersError {
174        TrustformersError::new(ErrorKind::HardwareError {
175            device: device.to_string(),
176            reason: reason.to_string(),
177        })
178        .with_component(component)
179        .with_suggestion("Check device drivers and installation")
180        .with_suggestion("Try falling back to CPU execution")
181    }
182
183    /// Convert legacy configuration error to new system
184    pub fn invalid_configuration_error(
185        field: &str,
186        reason: &str,
187        component: &str,
188    ) -> TrustformersError {
189        TrustformersError::new(ErrorKind::InvalidConfiguration {
190            field: field.to_string(),
191            reason: reason.to_string(),
192        })
193        .with_component(component)
194        .with_suggestion("Check configuration file syntax and values")
195        .with_suggestion("Refer to documentation for valid parameter ranges")
196    }
197}
198
199/// Extension trait for Result types to add standardization
200pub trait ResultStandardization<T> {
201    /// Standardize any error in a Result
202    fn standardize_err(self) -> Result<T, TrustformersError>;
203
204    /// Standardize with operation context
205    fn standardize_err_with_operation(self, operation: &str) -> Result<T, TrustformersError>;
206
207    /// Standardize with component context
208    fn standardize_err_with_component(self, component: &str) -> Result<T, TrustformersError>;
209
210    /// Standardize with full context
211    fn standardize_err_with_context(
212        self,
213        operation: &str,
214        component: &str,
215    ) -> Result<T, TrustformersError>;
216}
217
218impl<T, E> ResultStandardization<T> for Result<T, E>
219where
220    E: StandardError,
221{
222    fn standardize_err(self) -> Result<T, TrustformersError> {
223        self.map_err(|e| e.standardize())
224    }
225
226    fn standardize_err_with_operation(self, operation: &str) -> Result<T, TrustformersError> {
227        self.map_err(|e| e.standardize_with_operation(operation))
228    }
229
230    fn standardize_err_with_component(self, component: &str) -> Result<T, TrustformersError> {
231        self.map_err(|e| e.standardize_with_component(component))
232    }
233
234    fn standardize_err_with_context(
235        self,
236        operation: &str,
237        component: &str,
238    ) -> Result<T, TrustformersError> {
239        self.map_err(|e| e.standardize_with_context(operation, component))
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_core_error_standardization() {
        let core_err = CoreError::InvalidInput("test".to_string());
        let std_err = core_err.standardize_with_operation("test_operation");

        assert_eq!(
            std_err.context.operation,
            Some("test_operation".to_string())
        );
    }

    #[test]
    fn test_string_error_standardization() {
        let str_err = "Something went wrong";
        let std_err = str_err.standardize_with_component("TestComponent");

        assert_eq!(std_err.context.component, Some("TestComponent".to_string()));
    }

    /// Owned `String` errors map to `ErrorKind::Other` and keep both context fields.
    #[test]
    fn test_owned_string_standardization_with_context() {
        let err = String::from("owned failure").standardize_with_context("load", "Loader");

        assert!(matches!(&err.kind, ErrorKind::Other(msg) if msg == "owned failure"));
        assert_eq!(err.context.operation, Some("load".to_string()));
        assert_eq!(err.context.component, Some("Loader".to_string()));
    }

    /// `std::io::Error` is wrapped unchanged in `ErrorKind::IoError`.
    #[test]
    fn test_io_error_standardization() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing file");
        let err = io_err.standardize_with_context("load_weights", "ModelLoader");

        assert!(matches!(
            &err.kind,
            ErrorKind::IoError(inner) if inner.kind() == std::io::ErrorKind::NotFound
        ));
        assert_eq!(err.context.operation, Some("load_weights".to_string()));
        assert_eq!(err.context.component, Some("ModelLoader".to_string()));
    }

    #[test]
    #[allow(deprecated)]
    fn test_result_standardization() {
        fn failing_function() -> Result<(), CoreError> {
            Err(CoreError::InvalidArgument("test".to_string()))
        }

        let result = failing_function().standardize_err_with_context("test_op", "test_component");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.context.operation, Some("test_op".to_string()));
        assert_eq!(err.context.component, Some("test_component".to_string()));
    }

    /// `Ok` values pass through the `Result` extension trait untouched.
    #[test]
    fn test_result_standardization_passes_ok_through() {
        let ok: Result<u32, &str> = Ok(42);

        assert_eq!(ok.standardize_err_with_operation("noop").ok(), Some(42));
    }

    #[test]
    fn test_migration_helper() {
        let err =
            ErrorMigrationHelper::shape_error(vec![2, 3, 4], vec![2, 3, 5], "matrix_multiply");

        match &err.kind {
            ErrorKind::ShapeMismatch { expected, actual } => {
                assert_eq!(expected, &vec![2, 3, 4]);
                assert_eq!(actual, &vec![2, 3, 5]);
            },
            _ => panic!("Wrong error kind"),
        }

        assert_eq!(err.context.operation, Some("matrix_multiply".to_string()));
        assert!(err.suggestions.len() >= 2);
    }

    /// Component-based helpers record the component and carry their suggestions.
    #[test]
    fn test_hardware_migration_helper() {
        let err = ErrorMigrationHelper::hardware_unavailable_error(
            "cuda:0",
            "driver not found",
            "GpuBackend",
        );

        match &err.kind {
            ErrorKind::HardwareError { device, reason } => {
                assert_eq!(device, "cuda:0");
                assert_eq!(reason, "driver not found");
            },
            _ => panic!("Wrong error kind"),
        }

        assert_eq!(err.context.component, Some("GpuBackend".to_string()));
        assert!(err.suggestions.len() >= 2);
    }

    #[test]
    #[allow(deprecated)]
    fn test_std_error_macro() {
        let core_err = CoreError::TensorOpError {
            message: "test".to_string(),
            context: crate::error::ErrorContext::new(
                crate::error::ErrorCode::E2002,
                "test_operation".to_string(),
            ),
        };

        let err1 = std_error!(core_err);
        assert!(matches!(err1.kind, ErrorKind::ComputeError { .. }));

        let str_err = "test error";
        let err2 = std_error!(str_err, operation = "test_op");
        assert_eq!(err2.context.operation, Some("test_op".to_string()));
    }
}