Skip to main content

trustformers_core/errors/
mod.rs

//! Enhanced error handling with contextual information and recovery suggestions
//!
//! This module provides comprehensive error types with rich context and actionable
//! recovery suggestions to improve developer experience.
5
6use std::fmt;
7use thiserror::Error;
8
9mod conversions;
10mod standardization;
11
12pub use conversions::{
13    acceleration_error, checkpoint_error, compute_error, dimension_mismatch, file_not_found,
14    hardware_error, invalid_config, invalid_format, invalid_input, memory_error,
15    model_compatibility_error, model_not_found, not_implemented, out_of_memory, performance_error,
16    quantization_error, resource_exhausted, runtime_error, shape_mismatch, tensor_op_error,
17    timed_error, timeout_error, unsupported_operation, ResultExt, TimedResultExt,
18};
19
20pub use standardization::{ErrorMigrationHelper, ResultStandardization, StandardError};
21
22/// Core error type with context and recovery suggestions
23#[derive(Debug, Error)]
24pub struct TrustformersError {
25    /// The underlying error kind
26    #[source]
27    pub kind: ErrorKind,
28
29    /// Contextual information about where the error occurred
30    pub context: ErrorContext,
31
32    /// Suggested recovery actions
33    pub suggestions: Vec<String>,
34
35    /// Error code for documentation lookup
36    pub code: ErrorCode,
37}
38
39impl TrustformersError {
40    /// Create a new error with context
41    pub fn new(kind: ErrorKind) -> Self {
42        let code = ErrorCode::from_kind(&kind);
43        let suggestions = Self::default_suggestions(&kind);
44
45        Self {
46            kind,
47            context: ErrorContext::default(),
48            suggestions,
49            code,
50        }
51    }
52
53    /// Add contextual information
54    pub fn with_context(mut self, key: &str, value: String) -> Self {
55        self.context.add(key, value);
56        self
57    }
58
59    /// Add a recovery suggestion
60    pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
61        self.suggestions.push(suggestion.into());
62        self
63    }
64
65    /// Set the operation that failed
66    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
67        self.context.operation = Some(operation.into());
68        self
69    }
70
71    /// Set the model or component involved
72    pub fn with_component(mut self, component: impl Into<String>) -> Self {
73        self.context.component = Some(component.into());
74        self
75    }
76
77    /// Get default suggestions based on error kind
78    fn default_suggestions(kind: &ErrorKind) -> Vec<String> {
79        match kind {
80            ErrorKind::DimensionMismatch { expected, actual } => vec![
81                format!(
82                    "Check that input tensors have shape {}, not {}",
83                    expected, actual
84                ),
85                "Verify the model configuration matches your input dimensions".to_string(),
86                "Use .view() or .reshape() to adjust tensor dimensions".to_string(),
87            ],
88
89            ErrorKind::ShapeMismatch { expected, actual } => vec![
90                format!(
91                    "Check that tensor shapes match: expected {:?}, got {:?}",
92                    expected, actual
93                ),
94                "Use .reshape() to adjust tensor dimensions".to_string(),
95                "Verify input data shapes match expected model dimensions".to_string(),
96                "Consider using broadcasting operations if appropriate".to_string(),
97            ],
98
99            ErrorKind::OutOfMemory {
100                required,
101                available: _available,
102            } => vec![
103                "Try reducing batch size".to_string(),
104                "Enable gradient checkpointing to trade compute for memory".to_string(),
105                "Use mixed precision training (fp16/bf16) to reduce memory usage".to_string(),
106                format!(
107                    "Consider using model parallelism if model requires >{}GB",
108                    required / 1_000_000_000
109                ),
110            ],
111
112            ErrorKind::InvalidConfiguration { field, reason } => vec![
113                format!("Check the '{}' field in your configuration", field),
114                format!("Reason: {}", reason),
115                "Refer to the model's configuration documentation".to_string(),
116                "Use Model::from_pretrained() for validated configurations".to_string(),
117            ],
118
119            ErrorKind::ModelNotFound { name } => vec![
120                format!("Verify the model name '{}' is correct", name),
121                "Check available models with Model::list_available()".to_string(),
122                "Ensure you have internet connectivity for downloading".to_string(),
123                "Try specifying a revision if the model was recently updated".to_string(),
124            ],
125
126            ErrorKind::QuantizationError { reason } => vec![
127                "Ensure the model supports the requested quantization type".to_string(),
128                format!("Issue: {}", reason),
129                "Try a different quantization method (int8, int4, gptq, awq)".to_string(),
130                "Check if calibration data is required for this quantization".to_string(),
131            ],
132
133            ErrorKind::DeviceError { device, reason } => vec![
134                format!("Check that {} is available and properly configured", device),
135                format!("Error: {}", reason),
136                "Try running on CPU as a fallback".to_string(),
137                "Verify driver installation and versions".to_string(),
138            ],
139
140            ErrorKind::SerializationError { format, reason } => vec![
141                format!("Check the {} file format", format),
142                format!("Issue: {}", reason),
143                "Ensure the file is not corrupted".to_string(),
144                "Try converting to a different format".to_string(),
145            ],
146
147            ErrorKind::ComputeError { operation, reason } => vec![
148                format!("The {} operation failed: {}", operation, reason),
149                "Check for numerical instability (NaN/Inf values)".to_string(),
150                "Try using different precision (fp32 instead of fp16)".to_string(),
151                "Enable debug mode for detailed tensor information".to_string(),
152            ],
153
154            ErrorKind::TensorOpError { operation, reason } => vec![
155                format!("Tensor operation '{}' failed: {}", operation, reason),
156                "Check tensor dimensions and compatibility".to_string(),
157                "Verify data types are compatible".to_string(),
158                "Enable tensor debugging to see intermediate values".to_string(),
159            ],
160
161            ErrorKind::MemoryError { reason } => vec![
162                format!("Memory operation failed: {}", reason),
163                "Try reducing memory usage by clearing unused tensors".to_string(),
164                "Enable memory optimization settings".to_string(),
165                "Consider using CPU offloading for large tensors".to_string(),
166            ],
167
168            ErrorKind::HardwareError { device, reason } => vec![
169                format!("Hardware error on {}: {}", device, reason),
170                "Check device drivers and installation".to_string(),
171                "Verify hardware is properly connected".to_string(),
172                "Try falling back to CPU execution".to_string(),
173            ],
174
175            ErrorKind::PerformanceError { reason } => vec![
176                format!("Performance issue: {}", reason),
177                "Try optimizing batch size or model parameters".to_string(),
178                "Enable performance profiling to identify bottlenecks".to_string(),
179                "Consider using more efficient operations".to_string(),
180            ],
181
182            ErrorKind::InvalidInput { reason } => vec![
183                format!("Invalid input: {}", reason),
184                "Check input data format and types".to_string(),
185                "Verify input shapes match model expectations".to_string(),
186                "Ensure input data is properly preprocessed".to_string(),
187            ],
188
189            ErrorKind::RuntimeError { reason } => vec![
190                format!("Runtime error: {}", reason),
191                "Check system resources and dependencies".to_string(),
192                "Verify configuration settings".to_string(),
193                "Try restarting the operation".to_string(),
194            ],
195
196            ErrorKind::ResourceExhausted { resource, reason } => vec![
197                format!("Resource '{}' exhausted: {}", resource, reason),
198                "Reduce resource usage by optimizing operations".to_string(),
199                "Consider using resource pooling or management".to_string(),
200                "Check system resource limits".to_string(),
201            ],
202
203            ErrorKind::TimeoutError {
204                operation,
205                timeout_ms,
206            } => vec![
207                format!("Operation '{}' timed out after {}ms", operation, timeout_ms),
208                "Increase timeout duration if operation is expected to take longer".to_string(),
209                "Optimize the operation for better performance".to_string(),
210                "Check for deadlocks or infinite loops".to_string(),
211            ],
212
213            ErrorKind::FileNotFound { path } => vec![
214                format!("File not found: {}", path),
215                "Check that the file path is correct".to_string(),
216                "Verify file permissions".to_string(),
217                "Ensure the file exists in the expected location".to_string(),
218            ],
219
220            ErrorKind::InvalidFormat { expected, actual } => vec![
221                format!("Invalid format: expected {}, got {}", expected, actual),
222                "Check the file format and conversion requirements".to_string(),
223                "Verify the data is in the expected format".to_string(),
224                "Try using format conversion utilities".to_string(),
225            ],
226
227            ErrorKind::UnsupportedOperation { operation, target } => vec![
228                format!("Operation '{}' not supported on {}", operation, target),
229                "Check if the operation is available for this target".to_string(),
230                "Try using an alternative operation or target".to_string(),
231                "Verify feature compatibility".to_string(),
232            ],
233
234            ErrorKind::NotImplemented { feature } => vec![
235                format!("Feature '{}' is not yet implemented", feature),
236                "Check the roadmap for planned features".to_string(),
237                "Consider using alternative approaches".to_string(),
238                "Submit a feature request if needed".to_string(),
239            ],
240
241            ErrorKind::AutodiffError { reason } => vec![
242                format!("Automatic differentiation failed: {}", reason),
243                "Check that all operations support gradient computation".to_string(),
244                "Verify the computational graph is correctly built".to_string(),
245                "Enable gradient checking to validate gradients".to_string(),
246            ],
247
248            _ => vec!["Check the error details and context for more information".to_string()],
249        }
250    }
251
252    // Convenience methods for common error patterns
253    pub fn hardware_error(message: &str, operation: &str) -> Self {
254        TrustformersError::new(ErrorKind::HardwareError {
255            device: "unknown".to_string(),
256            reason: message.to_string(),
257        })
258        .with_operation(operation)
259    }
260
261    pub fn tensor_op_error(message: &str, operation: &str) -> Self {
262        TrustformersError::new(ErrorKind::TensorOpError {
263            operation: operation.to_string(),
264            reason: message.to_string(),
265        })
266        .with_operation(operation)
267    }
268
269    pub fn autodiff_error(message: String) -> Self {
270        TrustformersError::new(ErrorKind::AutodiffError { reason: message })
271    }
272
273    pub fn invalid_input(message: String) -> Self {
274        TrustformersError::new(ErrorKind::InvalidInput { reason: message })
275    }
276
277    pub fn config_error(message: &str, field: &str) -> Self {
278        TrustformersError::new(ErrorKind::InvalidConfiguration {
279            field: field.to_string(),
280            reason: message.to_string(),
281        })
282    }
283
284    pub fn invalid_config(message: String) -> Self {
285        TrustformersError::new(ErrorKind::InvalidConfiguration {
286            field: "config".to_string(),
287            reason: message,
288        })
289    }
290
291    pub fn model_error(message: String) -> Self {
292        TrustformersError::new(ErrorKind::ModelNotFound { name: message })
293    }
294
295    pub fn weight_load_error(message: String) -> Self {
296        TrustformersError::new(ErrorKind::WeightLoadingError { reason: message })
297    }
298
299    pub fn runtime_error(message: String) -> Self {
300        TrustformersError::new(ErrorKind::RuntimeError { reason: message })
301    }
302
303    pub fn io_error(message: String) -> Self {
304        TrustformersError::new(ErrorKind::IoError(std::io::Error::other(message)))
305    }
306
307    pub fn shape_error(message: String) -> Self {
308        TrustformersError::new(ErrorKind::ShapeError { reason: message })
309    }
310
311    pub fn safe_tensors_error(message: String) -> Self {
312        TrustformersError::new(ErrorKind::SafeTensorsError { reason: message })
313    }
314
315    pub fn dimension_mismatch(expected: String, actual: String) -> Self {
316        TrustformersError::new(ErrorKind::DimensionMismatch { expected, actual })
317    }
318
319    pub fn invalid_format(expected: String, actual: String) -> Self {
320        TrustformersError::new(ErrorKind::InvalidFormat { expected, actual })
321    }
322
323    pub fn invalid_format_simple(message: String) -> Self {
324        TrustformersError::new(ErrorKind::InvalidFormat {
325            expected: "valid format".to_string(),
326            actual: message,
327        })
328    }
329
330    pub fn not_implemented(feature: String) -> Self {
331        TrustformersError::new(ErrorKind::NotImplemented { feature })
332    }
333
334    pub fn invalid_input_simple(reason: String) -> Self {
335        TrustformersError::new(ErrorKind::InvalidInput { reason })
336    }
337
338    pub fn invalid_state(reason: String) -> Self {
339        TrustformersError::new(ErrorKind::InvalidState { reason })
340    }
341
342    pub fn invalid_operation(message: String) -> Self {
343        TrustformersError::new(ErrorKind::InvalidInput { reason: message })
344    }
345
346    pub fn other(message: String) -> Self {
347        TrustformersError::new(ErrorKind::Other(message))
348    }
349
350    pub fn resource_exhausted(message: String) -> Self {
351        TrustformersError::new(ErrorKind::ResourceExhausted {
352            resource: "memory".to_string(),
353            reason: message,
354        })
355    }
356
357    pub fn lock_error(message: String) -> Self {
358        TrustformersError::new(ErrorKind::Other(format!("Lock error: {}", message)))
359    }
360
361    pub fn serialization_error(message: String) -> Self {
362        TrustformersError::new(ErrorKind::SerializationError {
363            format: "unknown".to_string(),
364            reason: message,
365        })
366    }
367
368    pub fn plugin_error(message: String) -> Self {
369        TrustformersError::new(ErrorKind::Other(format!("Plugin error: {}", message)))
370    }
371
372    pub fn quantization_error(message: String) -> Self {
373        TrustformersError::new(ErrorKind::Other(format!("Quantization error: {}", message)))
374    }
375
376    pub fn invalid_argument(message: String) -> Self {
377        TrustformersError::new(ErrorKind::InvalidInput { reason: message })
378    }
379
380    pub fn file_not_found(message: String) -> Self {
381        TrustformersError::new(ErrorKind::FileNotFound { path: message })
382    }
383}
384
385impl fmt::Display for TrustformersError {
386    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387        // Header with error code
388        writeln!(f, "\nāŒ Error [{}]", self.code)?;
389        writeln!(f, "{}", "─".repeat(60))?;
390
391        // Error description
392        writeln!(f, "šŸ“ {}", self.kind)?;
393
394        // Context information
395        if self.context.has_info() {
396            writeln!(f, "\nšŸ“‹ Context:")?;
397            if let Some(op) = &self.context.operation {
398                writeln!(f, "   Operation: {}", op)?;
399            }
400            if let Some(comp) = &self.context.component {
401                writeln!(f, "   Component: {}", comp)?;
402            }
403            for (key, value) in &self.context.info {
404                writeln!(f, "   {}: {}", key, value)?;
405            }
406        }
407
408        // Recovery suggestions
409        if !self.suggestions.is_empty() {
410            writeln!(f, "\nšŸ’” Suggestions:")?;
411            for (i, suggestion) in self.suggestions.iter().enumerate() {
412                writeln!(f, "   {}. {}", i + 1, suggestion)?;
413            }
414        }
415
416        // Documentation link
417        writeln!(
418            f,
419            "\nšŸ“š For more information, see: https://docs.trustformers.ai/errors/{}",
420            self.code
421        )?;
422        writeln!(f, "{}", "─".repeat(60))?;
423
424        Ok(())
425    }
426}
427
428/// Specific error kinds
429#[derive(Debug, Error)]
430pub enum ErrorKind {
431    #[error("Dimension mismatch: expected {expected}, got {actual}")]
432    DimensionMismatch { expected: String, actual: String },
433
434    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
435    ShapeMismatch {
436        expected: Vec<usize>,
437        actual: Vec<usize>,
438    },
439
440    #[error("Out of memory: required {required} bytes, available {available} bytes")]
441    OutOfMemory { required: usize, available: usize },
442
443    #[error("Invalid configuration: field '{field}' {reason}")]
444    InvalidConfiguration { field: String, reason: String },
445
446    #[error("Model not found: '{name}'")]
447    ModelNotFound { name: String },
448
449    #[error("Weight loading failed: {reason}")]
450    WeightLoadingError { reason: String },
451
452    #[error("Tokenization error: {reason}")]
453    TokenizationError { reason: String },
454
455    #[error("Quantization error: {reason}")]
456    QuantizationError { reason: String },
457
458    #[error("Device error on {device}: {reason}")]
459    DeviceError { device: String, reason: String },
460
461    #[error("Serialization error for {format}: {reason}")]
462    SerializationError { format: String, reason: String },
463
464    #[error("Compute error in {operation}: {reason}")]
465    ComputeError { operation: String, reason: String },
466
467    #[error("Training error: {reason}")]
468    TrainingError { reason: String },
469
470    #[error("Pipeline error: {reason}")]
471    PipelineError { reason: String },
472
473    #[error("Attention error: {reason}")]
474    AttentionError { reason: String },
475
476    #[error("Optimization error: {reason}")]
477    OptimizationError { reason: String },
478
479    #[error("Autodiff error: {reason}")]
480    AutodiffError { reason: String },
481
482    #[error("Tensor operation error: {operation} failed with {reason}")]
483    TensorOpError { operation: String, reason: String },
484
485    #[error("Memory allocation error: {reason}")]
486    MemoryError { reason: String },
487
488    #[error("Hardware error: {device} - {reason}")]
489    HardwareError { device: String, reason: String },
490
491    #[error("Performance error: {reason}")]
492    PerformanceError { reason: String },
493
494    #[error("Invalid input: {reason}")]
495    InvalidInput { reason: String },
496
497    #[error("Image processing error: {reason}")]
498    ImageProcessingError { reason: String },
499
500    #[error("Runtime error: {reason}")]
501    RuntimeError { reason: String },
502
503    #[error("Resource exhausted: {resource} - {reason}")]
504    ResourceExhausted { resource: String, reason: String },
505
506    #[error("Plugin error: {plugin} - {reason}")]
507    PluginError { plugin: String, reason: String },
508
509    #[error("Timeout error: operation '{operation}' exceeded {timeout_ms}ms")]
510    TimeoutError { operation: String, timeout_ms: u64 },
511
512    #[error("Network error: {reason}")]
513    NetworkError { reason: String },
514
515    #[error("File not found: {path}")]
516    FileNotFound { path: String },
517
518    #[error("Invalid format: expected {expected}, got {actual}")]
519    InvalidFormat { expected: String, actual: String },
520
521    #[error("Invalid state: {reason}")]
522    InvalidState { reason: String },
523
524    #[error("Unsupported operation: {operation} on {target}")]
525    UnsupportedOperation { operation: String, target: String },
526
527    #[error("IO error: {0}")]
528    IoError(#[from] std::io::Error),
529
530    #[error("Not implemented: {feature}")]
531    NotImplemented { feature: String },
532
533    #[error("Shape error: {reason}")]
534    ShapeError { reason: String },
535
536    #[error("SafeTensors error: {reason}")]
537    SafeTensorsError { reason: String },
538
539    #[error("Other error: {0}")]
540    Other(String),
541}
542
/// Contextual information describing where an error occurred.
///
/// Plain data, so it derives `Clone`/`PartialEq`/`Eq`: context can be copied
/// onto another error and compared directly in tests.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ErrorContext {
    /// The operation being performed, if known.
    pub operation: Option<String>,

    /// The component or model involved, if known.
    pub component: Option<String>,

    /// Additional key/value pairs, kept in insertion order (duplicates allowed).
    pub info: Vec<(String, String)>,
}
555
556impl ErrorContext {
557    /// Add contextual information
558    pub fn add(&mut self, key: &str, value: String) {
559        self.info.push((key.to_string(), value));
560    }
561
562    /// Check if context has any information
563    pub fn has_info(&self) -> bool {
564        self.operation.is_some() || self.component.is_some() || !self.info.is_empty()
565    }
566}
567
/// Stable error codes used to build documentation links.
///
/// Each code corresponds to one `ErrorKind` variant (see
/// `ErrorCode::from_kind`); `E9999` is the fallback for kinds without a
/// dedicated code. `PartialEq`/`Eq`/`Hash` allow codes to be compared and
/// used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorCode {
    /// `ErrorKind::DimensionMismatch`
    E0001,
    /// `ErrorKind::ShapeMismatch`
    E0002,
    /// `ErrorKind::OutOfMemory`
    E0003,
    /// `ErrorKind::InvalidConfiguration`
    E0004,
    /// `ErrorKind::ModelNotFound`
    E0005,
    /// `ErrorKind::WeightLoadingError`
    E0006,
    /// `ErrorKind::TokenizationError`
    E0007,
    /// `ErrorKind::QuantizationError`
    E0008,
    /// `ErrorKind::DeviceError`
    E0009,
    /// `ErrorKind::SerializationError`
    E0010,
    /// `ErrorKind::ComputeError`
    E0011,
    /// `ErrorKind::TrainingError`
    E0012,
    /// `ErrorKind::PipelineError`
    E0013,
    /// `ErrorKind::AttentionError`
    E0014,
    /// `ErrorKind::OptimizationError`
    E0015,
    /// `ErrorKind::TensorOpError`
    E0016,
    /// `ErrorKind::MemoryError`
    E0017,
    /// `ErrorKind::HardwareError`
    E0018,
    /// `ErrorKind::PerformanceError`
    E0019,
    /// `ErrorKind::InvalidInput`
    E0020,
    /// `ErrorKind::ImageProcessingError`
    E0021,
    /// `ErrorKind::RuntimeError`
    E0022,
    /// `ErrorKind::ResourceExhausted`
    E0023,
    /// `ErrorKind::PluginError`
    E0024,
    /// `ErrorKind::TimeoutError`
    E0025,
    /// `ErrorKind::NetworkError`
    E0026,
    /// `ErrorKind::FileNotFound`
    E0027,
    /// `ErrorKind::InvalidFormat`
    E0028,
    /// `ErrorKind::InvalidState`
    E0029,
    /// `ErrorKind::UnsupportedOperation`
    E0030,
    /// `ErrorKind::IoError`
    E0031,
    /// `ErrorKind::NotImplemented`
    E0032,
    /// `ErrorKind::AutodiffError`
    E0033,
    /// `ErrorKind::Other` and any kind without a dedicated code.
    E9999,
}
606
607impl ErrorCode {
608    /// Get error code from error kind
609    pub fn from_kind(kind: &ErrorKind) -> Self {
610        match kind {
611            ErrorKind::DimensionMismatch { .. } => ErrorCode::E0001,
612            ErrorKind::ShapeMismatch { .. } => ErrorCode::E0002,
613            ErrorKind::OutOfMemory { .. } => ErrorCode::E0003,
614            ErrorKind::InvalidConfiguration { .. } => ErrorCode::E0004,
615            ErrorKind::ModelNotFound { .. } => ErrorCode::E0005,
616            ErrorKind::WeightLoadingError { .. } => ErrorCode::E0006,
617            ErrorKind::TokenizationError { .. } => ErrorCode::E0007,
618            ErrorKind::QuantizationError { .. } => ErrorCode::E0008,
619            ErrorKind::DeviceError { .. } => ErrorCode::E0009,
620            ErrorKind::SerializationError { .. } => ErrorCode::E0010,
621            ErrorKind::ComputeError { .. } => ErrorCode::E0011,
622            ErrorKind::TrainingError { .. } => ErrorCode::E0012,
623            ErrorKind::PipelineError { .. } => ErrorCode::E0013,
624            ErrorKind::AttentionError { .. } => ErrorCode::E0014,
625            ErrorKind::OptimizationError { .. } => ErrorCode::E0015,
626            ErrorKind::AutodiffError { .. } => ErrorCode::E0033,
627            ErrorKind::TensorOpError { .. } => ErrorCode::E0016,
628            ErrorKind::MemoryError { .. } => ErrorCode::E0017,
629            ErrorKind::HardwareError { .. } => ErrorCode::E0018,
630            ErrorKind::PerformanceError { .. } => ErrorCode::E0019,
631            ErrorKind::InvalidInput { .. } => ErrorCode::E0020,
632            ErrorKind::ImageProcessingError { .. } => ErrorCode::E0021,
633            ErrorKind::RuntimeError { .. } => ErrorCode::E0022,
634            ErrorKind::ResourceExhausted { .. } => ErrorCode::E0023,
635            ErrorKind::PluginError { .. } => ErrorCode::E0024,
636            ErrorKind::TimeoutError { .. } => ErrorCode::E0025,
637            ErrorKind::NetworkError { .. } => ErrorCode::E0026,
638            ErrorKind::FileNotFound { .. } => ErrorCode::E0027,
639            ErrorKind::InvalidFormat { .. } => ErrorCode::E0028,
640            ErrorKind::InvalidState { .. } => ErrorCode::E0029,
641            ErrorKind::UnsupportedOperation { .. } => ErrorCode::E0030,
642            ErrorKind::IoError { .. } => ErrorCode::E0031,
643            ErrorKind::NotImplemented { .. } => ErrorCode::E0032,
644            _ => ErrorCode::E9999,
645        }
646    }
647}
648
649impl fmt::Display for ErrorCode {
650    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
651        write!(f, "{:?}", self)
652    }
653}
654
655/// Result type alias
656pub type Result<T> = std::result::Result<T, TrustformersError>;
657
/// Build a `TrustformersError` from an `ErrorKind`, optionally attaching the
/// failing operation and/or component.
///
/// Accepted forms (the order of the named arguments is fixed):
/// `tf_error!(kind)`, `tf_error!(kind, operation = op)`,
/// `tf_error!(kind, component = comp)` and
/// `tf_error!(kind, operation = op, component = comp)`.
#[macro_export]
macro_rules! tf_error {
    ($kind:expr $(, operation = $op:expr)? $(, component = $comp:expr)?) => {
        $crate::errors::TrustformersError::new($kind)
            $(.with_operation($op))?
            $(.with_component($comp))?
    };
}
679
/// Add one or more `key => value` context pairs to an existing error.
///
/// Each value is converted with `.to_string()` and passed to
/// `with_context(key, value)`, left to right. A trailing comma is accepted.
///
/// The recursive call goes through `$crate::` so the macro also works in
/// downstream crates that invoke it by path (e.g.
/// `trustformers_core::tf_context!`) without importing it by name.
#[macro_export]
macro_rules! tf_context {
    ($err:expr, $key:expr => $value:expr $(,)?) => {
        $err.with_context($key, $value.to_string())
    };

    ($err:expr, $key:expr => $value:expr, $($rest_key:expr => $rest_value:expr),+ $(,)?) => {
        $crate::tf_context!(
            $err.with_context($key, $value.to_string()),
            $($rest_key => $rest_value),+
        )
    };
}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendered report includes the code, operation, component and context pairs.
    #[test]
    fn test_error_display() {
        let error = TrustformersError::new(ErrorKind::DimensionMismatch {
            expected: "[batch_size, 512, 768]".to_string(),
            actual: "[batch_size, 256, 768]".to_string(),
        })
        .with_operation("MultiHeadAttention.forward")
        .with_component("BERT")
        .with_context("layer", "12".to_string())
        .with_context("head_count", "12".to_string());

        let display = format!("{}", error);
        assert!(display.contains("Error [E0001]"));
        assert!(display.contains("MultiHeadAttention.forward"));
        assert!(display.contains("BERT"));
        assert!(display.contains("layer: 12"));
    }

    /// OutOfMemory errors come with memory-reduction suggestions.
    #[test]
    fn test_error_suggestions() {
        let error = TrustformersError::new(ErrorKind::OutOfMemory {
            required: 8_000_000_000,
            available: 4_000_000_000,
        });

        assert!(!error.suggestions.is_empty());
        assert!(error.suggestions.iter().any(|s| s.contains("batch size")));
        assert!(error.suggestions.iter().any(|s| s.contains("mixed precision")));
    }

    /// `tf_error!` records the operation and component in the context.
    #[test]
    fn test_error_macros() {
        let error = tf_error!(
            ErrorKind::ModelNotFound {
                name: "gpt-5".to_string()
            },
            operation = "Model::from_pretrained",
            component = "ModelLoader"
        );

        assert_eq!(
            error.context.operation,
            Some("Model::from_pretrained".to_string())
        );
        assert_eq!(error.context.component, Some("ModelLoader".to_string()));
    }

    /// `tf_context!` appends every pair, in order, with values stringified.
    #[test]
    fn test_context_macro() {
        let error = tf_context!(
            TrustformersError::new(ErrorKind::InvalidInput {
                reason: "empty batch".to_string()
            }),
            "batch_size" => 0,
            "split" => "train"
        );

        assert_eq!(
            error.context.info,
            vec![
                ("batch_size".to_string(), "0".to_string()),
                ("split".to_string(), "train".to_string()),
            ]
        );
        assert!(error.context.has_info());
    }

    /// Explicit suggestions are appended after the kind's default suggestions.
    #[test]
    fn test_with_suggestion_appends() {
        let base = TrustformersError::new(ErrorKind::RuntimeError {
            reason: "boom".to_string(),
        });
        let default_count = base.suggestions.len();

        let error = base.with_suggestion("Retry with a smaller input");
        assert_eq!(error.suggestions.len(), default_count + 1);
        assert_eq!(
            error.suggestions.last().map(String::as_str),
            Some("Retry with a smaller input")
        );
    }

    /// Error codes are derived from the kind and rendered by variant name.
    #[test]
    fn test_error_code_mapping() {
        let error = TrustformersError::new(ErrorKind::ModelNotFound {
            name: "gpt-5".to_string(),
        });
        assert_eq!(error.code.to_string(), "E0005");

        let other = TrustformersError::new(ErrorKind::Other("misc".to_string()));
        assert_eq!(other.code.to_string(), "E9999");
    }
}