Skip to main content

trustformers_core/errors/
conversions.rs

1//! Error conversion utilities for seamless integration with existing code
2
3#![allow(deprecated)] // Backward compatibility conversions for deprecated CoreError
4#![allow(unused_variables)] // Error conversions
5
6use super::{ErrorKind, TrustformersError};
7use crate::error::CoreError;
8use anyhow::Error as AnyhowError;
9use scirs2_core::ndarray::ShapeError;
10use std::time::Instant;
11
12impl From<CoreError> for TrustformersError {
13    fn from(err: CoreError) -> Self {
14        match err {
15            CoreError::DimensionMismatch { context: _ } => {
16                TrustformersError::new(ErrorKind::DimensionMismatch {
17                    expected: "unknown".to_string(),
18                    actual: "unknown".to_string(),
19                })
20            },
21
22            CoreError::ShapeMismatch {
23                expected,
24                got,
25                context: _,
26            } => TrustformersError::new(ErrorKind::DimensionMismatch {
27                expected: format!("{:?}", expected),
28                actual: format!("{:?}", got),
29            }),
30
31            CoreError::InvalidArgument(msg) => {
32                TrustformersError::new(ErrorKind::InvalidConfiguration {
33                    field: "argument".to_string(),
34                    reason: msg,
35                })
36            },
37
38            CoreError::InvalidConfig(msg) => {
39                TrustformersError::new(ErrorKind::InvalidConfiguration {
40                    field: "config".to_string(),
41                    reason: msg,
42                })
43            },
44
45            CoreError::NotImplemented(msg) => TrustformersError::new(ErrorKind::Other(msg))
46                .with_suggestion("This feature is not yet implemented")
47                .with_suggestion("Check the roadmap for planned features"),
48
49            CoreError::Io(io_err) => TrustformersError::new(ErrorKind::IoError(io_err)),
50            CoreError::Serialization(serde_err) => {
51                TrustformersError::new(ErrorKind::SerializationError {
52                    format: "JSON".to_string(),
53                    reason: serde_err.to_string(),
54                })
55            },
56
57            CoreError::WeightLoadError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
58            CoreError::TensorOpError {
59                message,
60                context: _,
61            } => TrustformersError::new(ErrorKind::ComputeError {
62                operation: "tensor_op".to_string(),
63                reason: message,
64            }),
65            CoreError::ModelError(msg) => {
66                TrustformersError::new(ErrorKind::ModelNotFound { name: msg })
67            },
68            CoreError::ShapeError(msg) => TrustformersError::new(ErrorKind::DimensionMismatch {
69                expected: "valid_shape".to_string(),
70                actual: msg,
71            }),
72            CoreError::SafeTensorsError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
73            CoreError::InvalidInput(msg) => {
74                TrustformersError::new(ErrorKind::InvalidConfiguration {
75                    field: "input".to_string(),
76                    reason: msg,
77                })
78            },
79            CoreError::TokenizerError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
80            CoreError::RuntimeError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
81            CoreError::IoError(msg) => {
82                TrustformersError::new(ErrorKind::IoError(std::io::Error::other(msg)))
83            },
84            CoreError::ConfigError {
85                message,
86                context: _,
87            } => TrustformersError::new(ErrorKind::InvalidConfiguration {
88                field: "config".to_string(),
89                reason: message,
90            }),
91            CoreError::ComputationError(msg) => TrustformersError::new(ErrorKind::ComputeError {
92                operation: "computation".to_string(),
93                reason: msg,
94            }),
95            CoreError::SerializationError(msg) => {
96                TrustformersError::new(ErrorKind::SerializationError {
97                    format: "unknown".to_string(),
98                    reason: msg,
99                })
100            },
101            CoreError::QuantizationError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
102            CoreError::ResourceExhausted(msg) => TrustformersError::new(ErrorKind::OutOfMemory {
103                required: 0,
104                available: 0,
105            })
106            .with_context("details", msg),
107            CoreError::FormattingError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
108            CoreError::ImageProcessingError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
109            CoreError::LockError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
110            CoreError::PluginError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
111            CoreError::MemoryError {
112                message,
113                context: _,
114            } => TrustformersError::new(ErrorKind::OutOfMemory {
115                required: 0,
116                available: 0,
117            })
118            .with_context("details", message),
119            CoreError::HardwareError {
120                message,
121                context: _,
122            } => TrustformersError::new(ErrorKind::HardwareError {
123                device: "unknown".to_string(),
124                reason: message,
125            }),
126            CoreError::PerformanceError {
127                message,
128                context: _,
129            } => TrustformersError::new(ErrorKind::PerformanceError { reason: message }),
130            CoreError::Timeout(msg) => TrustformersError::new(ErrorKind::TimeoutError {
131                operation: "unknown".to_string(),
132                timeout_ms: 0,
133            })
134            .with_context("details", msg),
135            CoreError::NetworkError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
136            CoreError::FileNotFound(msg) => {
137                TrustformersError::new(ErrorKind::FileNotFound { path: msg })
138            },
139            CoreError::TensorNotFound(msg) => TrustformersError::new(ErrorKind::Other(msg)),
140            CoreError::InvalidFormat(msg) => TrustformersError::new(ErrorKind::InvalidFormat {
141                expected: "valid_format".to_string(),
142                actual: msg,
143            }),
144            CoreError::InvalidState(msg) => TrustformersError::new(ErrorKind::Other(msg)),
145            CoreError::UnsupportedFormat(msg) => {
146                TrustformersError::new(ErrorKind::UnsupportedOperation {
147                    operation: "format_parsing".to_string(),
148                    target: msg,
149                })
150            },
151            CoreError::AutodiffError(msg) => TrustformersError::new(ErrorKind::ComputeError {
152                operation: "autodiff".to_string(),
153                reason: msg,
154            }),
155            CoreError::InvalidOperation(msg) => {
156                TrustformersError::new(ErrorKind::UnsupportedOperation {
157                    operation: "tensor_operation".to_string(),
158                    target: msg,
159                })
160            },
161            CoreError::InternalError(msg) => TrustformersError::new(ErrorKind::Other(msg)),
162
163            CoreError::Other(msg) => TrustformersError::new(ErrorKind::Other(msg.to_string())),
164            CoreError::DeviceNotFound(device_id) => {
165                TrustformersError::new(ErrorKind::HardwareError {
166                    device: device_id,
167                    reason: "Device not found in registry".to_string(),
168                })
169            },
170        }
171    }
172}
173
174impl From<std::io::Error> for TrustformersError {
175    fn from(err: std::io::Error) -> Self {
176        TrustformersError::new(ErrorKind::IoError(err))
177            .with_suggestion("Check file permissions and path existence")
178            .with_suggestion("Ensure sufficient disk space")
179    }
180}
181
182impl From<std::fmt::Error> for TrustformersError {
183    fn from(err: std::fmt::Error) -> Self {
184        TrustformersError::new(ErrorKind::Other(format!("Format error: {}", err)))
185            .with_suggestion("Check string formatting operations")
186    }
187}
188
189impl From<serde_json::Error> for TrustformersError {
190    fn from(err: serde_json::Error) -> Self {
191        TrustformersError::new(ErrorKind::SerializationError {
192            format: "JSON".to_string(),
193            reason: err.to_string(),
194        })
195    }
196}
197
198// Backward compatibility: TrustformersError -> CoreError conversion
199impl From<TrustformersError> for CoreError {
200    fn from(err: TrustformersError) -> Self {
201        match err.kind {
202            ErrorKind::DimensionMismatch { expected, actual } => CoreError::DimensionMismatch {
203                context: crate::error::ErrorContext::new(
204                    crate::error::ErrorCode::E1002,
205                    "dimension_mismatch".to_string(),
206                ),
207            },
208            ErrorKind::ShapeMismatch { expected, actual } => CoreError::ShapeMismatch {
209                expected,
210                got: actual,
211                context: crate::error::ErrorContext::new(
212                    crate::error::ErrorCode::E1001,
213                    "shape_mismatch".to_string(),
214                ),
215            },
216            ErrorKind::OutOfMemory { .. } => CoreError::MemoryError {
217                message: "Out of memory".to_string(),
218                context: crate::error::ErrorContext::new(
219                    crate::error::ErrorCode::E3001,
220                    "memory_allocation".to_string(),
221                ),
222            },
223            ErrorKind::InvalidConfiguration { field, reason } => {
224                CoreError::InvalidConfig(format!("{}: {}", field, reason))
225            },
226            ErrorKind::ModelNotFound { name } => {
227                CoreError::ModelError(format!("Model not found: {}", name))
228            },
229            ErrorKind::TensorOpError { operation, reason } => CoreError::TensorOpError {
230                message: format!("{}: {}", operation, reason),
231                context: crate::error::ErrorContext::new(crate::error::ErrorCode::E2002, operation),
232            },
233            ErrorKind::IoError(io_err) => CoreError::Io(io_err),
234            _ => CoreError::Other(anyhow::anyhow!(err.to_string())),
235        }
236    }
237}
238
239impl From<ShapeError> for TrustformersError {
240    fn from(err: ShapeError) -> Self {
241        TrustformersError::new(ErrorKind::DimensionMismatch {
242            expected: "valid shape".to_string(),
243            actual: format!("invalid shape: {}", err),
244        })
245        .with_suggestion("Check tensor dimensions and shape compatibility")
246        .with_suggestion("Ensure tensor shapes match operation requirements")
247    }
248}
249
250impl From<AnyhowError> for TrustformersError {
251    fn from(err: AnyhowError) -> Self {
252        // Try to downcast to known error types
253        if let Some(core_err) = err.downcast_ref::<CoreError>() {
254            // Convert without cloning by matching on the error type
255            return match core_err {
256                CoreError::DimensionMismatch { context: _ } => {
257                    TrustformersError::new(ErrorKind::DimensionMismatch {
258                        expected: "unknown".to_string(),
259                        actual: "unknown".to_string(),
260                    })
261                },
262                _ => TrustformersError::new(ErrorKind::Other(core_err.to_string())),
263            };
264        }
265
266        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
267            return TrustformersError::new(ErrorKind::IoError(io_err.kind().into()))
268                .with_context("source", err.to_string());
269        }
270
271        // Generic conversion
272        TrustformersError::new(ErrorKind::Other(err.to_string()))
273    }
274}
275
276/// Extension trait for adding context to Results
277pub trait ResultExt<T> {
278    /// Add operation context to an error
279    fn with_operation(self, operation: impl Into<String>) -> Result<T, TrustformersError>;
280
281    /// Add component context to an error
282    fn with_component(self, component: impl Into<String>) -> Result<T, TrustformersError>;
283
284    /// Add arbitrary context to an error
285    fn with_context_key(self, key: &str, value: impl Into<String>) -> Result<T, TrustformersError>;
286}
287
288impl<T, E> ResultExt<T> for Result<T, E>
289where
290    E: Into<TrustformersError>,
291{
292    fn with_operation(self, operation: impl Into<String>) -> Result<T, TrustformersError> {
293        self.map_err(|e| e.into().with_operation(operation))
294    }
295
296    fn with_component(self, component: impl Into<String>) -> Result<T, TrustformersError> {
297        self.map_err(|e| e.into().with_component(component))
298    }
299
300    fn with_context_key(self, key: &str, value: impl Into<String>) -> Result<T, TrustformersError> {
301        self.map_err(|e| e.into().with_context(key, value.into()))
302    }
303}
304
305/// Helper function for dimension mismatch errors
306pub fn dimension_mismatch(expected: impl ToString, actual: impl ToString) -> TrustformersError {
307    TrustformersError::new(ErrorKind::DimensionMismatch {
308        expected: expected.to_string(),
309        actual: actual.to_string(),
310    })
311}
312
313/// Helper function for OOM errors
314pub fn out_of_memory(required: usize, available: usize) -> TrustformersError {
315    TrustformersError::new(ErrorKind::OutOfMemory {
316        required,
317        available,
318    })
319}
320
321/// Helper function for configuration errors
322pub fn invalid_config(field: impl Into<String>, reason: impl Into<String>) -> TrustformersError {
323    TrustformersError::new(ErrorKind::InvalidConfiguration {
324        field: field.into(),
325        reason: reason.into(),
326    })
327}
328
329/// Helper function for model not found errors
330pub fn model_not_found(name: impl Into<String>) -> TrustformersError {
331    TrustformersError::new(ErrorKind::ModelNotFound { name: name.into() })
332}
333
334/// Helper function for compute errors
335pub fn compute_error(operation: impl Into<String>, reason: impl Into<String>) -> TrustformersError {
336    TrustformersError::new(ErrorKind::ComputeError {
337        operation: operation.into(),
338        reason: reason.into(),
339    })
340}
341
342/// Helper function for shape mismatch errors
343pub fn shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> TrustformersError {
344    TrustformersError::new(ErrorKind::ShapeMismatch { expected, actual })
345}
346
347/// Helper function for tensor operation errors
348pub fn tensor_op_error(
349    operation: impl Into<String>,
350    reason: impl Into<String>,
351) -> TrustformersError {
352    TrustformersError::new(ErrorKind::TensorOpError {
353        operation: operation.into(),
354        reason: reason.into(),
355    })
356}
357
358/// Helper function for memory errors
359pub fn memory_error(reason: impl Into<String>) -> TrustformersError {
360    TrustformersError::new(ErrorKind::MemoryError {
361        reason: reason.into(),
362    })
363}
364
365/// Helper function for hardware errors
366pub fn hardware_error(device: impl Into<String>, reason: impl Into<String>) -> TrustformersError {
367    TrustformersError::new(ErrorKind::HardwareError {
368        device: device.into(),
369        reason: reason.into(),
370    })
371}
372
373/// Helper function for performance errors
374pub fn performance_error(reason: impl Into<String>) -> TrustformersError {
375    TrustformersError::new(ErrorKind::PerformanceError {
376        reason: reason.into(),
377    })
378}
379
380/// Helper function for invalid input errors
381pub fn invalid_input(reason: impl Into<String>) -> TrustformersError {
382    TrustformersError::new(ErrorKind::InvalidInput {
383        reason: reason.into(),
384    })
385}
386
387/// Helper function for runtime errors
388pub fn runtime_error(reason: impl Into<String>) -> TrustformersError {
389    TrustformersError::new(ErrorKind::RuntimeError {
390        reason: reason.into(),
391    })
392}
393
394/// Helper function for resource exhausted errors
395pub fn resource_exhausted(
396    resource: impl Into<String>,
397    reason: impl Into<String>,
398) -> TrustformersError {
399    TrustformersError::new(ErrorKind::ResourceExhausted {
400        resource: resource.into(),
401        reason: reason.into(),
402    })
403}
404
405/// Helper function for timeout errors
406pub fn timeout_error(operation: impl Into<String>, timeout_ms: u64) -> TrustformersError {
407    TrustformersError::new(ErrorKind::TimeoutError {
408        operation: operation.into(),
409        timeout_ms,
410    })
411}
412
413/// Helper function for file not found errors
414pub fn file_not_found(path: impl Into<String>) -> TrustformersError {
415    TrustformersError::new(ErrorKind::FileNotFound { path: path.into() })
416}
417
418/// Helper function for invalid format errors
419pub fn invalid_format(expected: impl Into<String>, actual: impl Into<String>) -> TrustformersError {
420    TrustformersError::new(ErrorKind::InvalidFormat {
421        expected: expected.into(),
422        actual: actual.into(),
423    })
424}
425
426/// Helper function for unsupported operation errors
427pub fn unsupported_operation(
428    operation: impl Into<String>,
429    target: impl Into<String>,
430) -> TrustformersError {
431    TrustformersError::new(ErrorKind::UnsupportedOperation {
432        operation: operation.into(),
433        target: target.into(),
434    })
435}
436
437/// Helper function for not implemented errors
438pub fn not_implemented(feature: impl Into<String>) -> TrustformersError {
439    TrustformersError::new(ErrorKind::NotImplemented {
440        feature: feature.into(),
441    })
442}
443
444/// Helper function for model compatibility errors
445pub fn model_compatibility_error(
446    model_type: impl Into<String>,
447    required_version: impl Into<String>,
448) -> TrustformersError {
449    TrustformersError::new(ErrorKind::InvalidConfiguration {
450        field: "model_compatibility".to_string(),
451        reason: format!(
452            "Model type '{}' requires version '{}'",
453            model_type.into(),
454            required_version.into()
455        ),
456    })
457    .with_suggestion("Update to a compatible model version")
458    .with_suggestion("Check the model documentation for compatibility requirements")
459}
460
461/// Helper function for quantization errors
462pub fn quantization_error(
463    operation: impl Into<String>,
464    reason: impl Into<String>,
465) -> TrustformersError {
466    TrustformersError::new(ErrorKind::ComputeError {
467        operation: operation.into(),
468        reason: reason.into(),
469    })
470    .with_suggestion("Try a different quantization scheme")
471    .with_suggestion("Check if the model supports the requested quantization")
472    .with_suggestion("Verify quantization parameters are within valid ranges")
473}
474
475/// Helper function for hardware acceleration errors
476pub fn acceleration_error(
477    backend: impl Into<String>,
478    reason: impl Into<String>,
479) -> TrustformersError {
480    TrustformersError::new(ErrorKind::HardwareError {
481        device: backend.into(),
482        reason: reason.into(),
483    })
484    .with_suggestion("Check hardware drivers are installed and up to date")
485    .with_suggestion("Verify hardware compatibility with the operation")
486    .with_suggestion("Try falling back to CPU execution")
487}
488
489/// Helper function for checkpoint loading errors
490pub fn checkpoint_error(path: impl Into<String>, reason: impl Into<String>) -> TrustformersError {
491    TrustformersError::new(ErrorKind::IoError(std::io::Error::new(
492        std::io::ErrorKind::InvalidData,
493        format!("Checkpoint error at {}: {}", path.into(), reason.into()),
494    )))
495    .with_suggestion("Verify the checkpoint file is not corrupted")
496    .with_suggestion("Check if the checkpoint format is supported")
497    .with_suggestion("Ensure sufficient disk space and permissions")
498}
499
500/// Helper function for creating errors with timing information
501pub fn timed_error(
502    kind: ErrorKind,
503    operation_start: Instant,
504    operation_name: impl Into<String>,
505) -> TrustformersError {
506    let duration = operation_start.elapsed();
507    TrustformersError::new(kind)
508        .with_operation(operation_name)
509        .with_context("duration_ms", duration.as_millis().to_string())
510        .with_suggestion(format!(
511            "Operation took {:.2}ms - consider optimization if this is slow",
512            duration.as_millis()
513        ))
514}
515
516/// Result extension trait for adding error context with timing
517pub trait TimedResultExt<T> {
518    /// Add timing context to an error result
519    fn with_timing(
520        self,
521        operation_start: Instant,
522        operation_name: impl Into<String>,
523    ) -> Result<T, TrustformersError>;
524}
525
526impl<T, E> TimedResultExt<T> for Result<T, E>
527where
528    E: Into<TrustformersError>,
529{
530    fn with_timing(
531        self,
532        operation_start: Instant,
533        operation_name: impl Into<String>,
534    ) -> Result<T, TrustformersError> {
535        self.map_err(|err| {
536            let duration = operation_start.elapsed();
537            err.into()
538                .with_operation(operation_name)
539                .with_context("duration_ms", duration.as_millis().to_string())
540        })
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// The associated constructor produces a `DimensionMismatch` kind.
    #[test]
    fn test_dimension_mismatch_constructor() {
        let tf_err =
            TrustformersError::dimension_mismatch("expected".to_string(), "actual".to_string());

        assert!(
            matches!(tf_err.kind, ErrorKind::DimensionMismatch { .. }),
            "Wrong error kind"
        );
    }

    /// A legacy `CoreError` is mapped onto the matching structured kind.
    #[test]
    fn test_core_error_conversion() {
        let tf_err = TrustformersError::from(CoreError::InvalidConfig("bad value".to_string()));

        match &tf_err.kind {
            ErrorKind::InvalidConfiguration { field, reason } => {
                assert_eq!(field, "config");
                assert_eq!(reason, "bad value");
            },
            _ => panic!("Wrong error kind"),
        }
    }

    /// An anyhow error that wraps no known type falls back to `Other` with its message.
    #[test]
    fn test_anyhow_generic_conversion() {
        let tf_err = TrustformersError::from(anyhow::anyhow!("boom"));

        match &tf_err.kind {
            ErrorKind::Other(msg) => assert_eq!(msg, "boom"),
            _ => panic!("Wrong error kind"),
        }
    }

    #[test]
    fn test_result_extension() {
        fn failing_operation() -> Result<(), TrustformersError> {
            Err(TrustformersError::invalid_argument("test".to_string()))
        }

        let result = failing_operation()
            .with_operation("test_operation")
            .with_component("TestComponent");

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.context.operation, Some("test_operation".to_string()));
        assert_eq!(err.context.component, Some("TestComponent".to_string()));
    }

    #[test]
    fn test_timed_result_extension() {
        let start = std::time::Instant::now();
        let result: Result<(), TrustformersError> =
            Err(TrustformersError::invalid_argument("test".to_string()));

        let err = result.with_timing(start, "timed_operation").unwrap_err();
        assert_eq!(err.context.operation, Some("timed_operation".to_string()));
    }
}