1use std::fmt;
7use thiserror::Error;
8
9mod conversions;
10mod standardization;
11
12pub use conversions::{
13 acceleration_error, checkpoint_error, compute_error, dimension_mismatch, file_not_found,
14 hardware_error, invalid_config, invalid_format, invalid_input, memory_error,
15 model_compatibility_error, model_not_found, not_implemented, out_of_memory, performance_error,
16 quantization_error, resource_exhausted, runtime_error, shape_mismatch, tensor_op_error,
17 timed_error, timeout_error, unsupported_operation, ResultExt, TimedResultExt,
18};
19
20pub use standardization::{ErrorMigrationHelper, ResultStandardization, StandardError};
21
/// Rich error value: the underlying [`ErrorKind`] plus context, suggestions and a code.
///
/// `Display` is implemented by hand further down in this file; the `Error`
/// derive only wires `kind` up as the error source.
#[derive(Debug, Error)]
pub struct TrustformersError {
    /// The underlying failure; exposed as the `source()` of this error.
    #[source]
    pub kind: ErrorKind,

    /// Operation, component and free-form key/value details attached to the error.
    pub context: ErrorContext,

    /// Human-readable hints; seeded from the kind in `new` and extendable
    /// through `with_suggestion`.
    pub suggestions: Vec<String>,

    /// Code derived from `kind` via `ErrorCode::from_kind` in `new`.
    pub code: ErrorCode,
}
38
39impl TrustformersError {
40 pub fn new(kind: ErrorKind) -> Self {
42 let code = ErrorCode::from_kind(&kind);
43 let suggestions = Self::default_suggestions(&kind);
44
45 Self {
46 kind,
47 context: ErrorContext::default(),
48 suggestions,
49 code,
50 }
51 }
52
53 pub fn with_context(mut self, key: &str, value: String) -> Self {
55 self.context.add(key, value);
56 self
57 }
58
59 pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
61 self.suggestions.push(suggestion.into());
62 self
63 }
64
65 pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
67 self.context.operation = Some(operation.into());
68 self
69 }
70
71 pub fn with_component(mut self, component: impl Into<String>) -> Self {
73 self.context.component = Some(component.into());
74 self
75 }
76
77 fn default_suggestions(kind: &ErrorKind) -> Vec<String> {
79 match kind {
80 ErrorKind::DimensionMismatch { expected, actual } => vec![
81 format!(
82 "Check that input tensors have shape {}, not {}",
83 expected, actual
84 ),
85 "Verify the model configuration matches your input dimensions".to_string(),
86 "Use .view() or .reshape() to adjust tensor dimensions".to_string(),
87 ],
88
89 ErrorKind::ShapeMismatch { expected, actual } => vec![
90 format!(
91 "Check that tensor shapes match: expected {:?}, got {:?}",
92 expected, actual
93 ),
94 "Use .reshape() to adjust tensor dimensions".to_string(),
95 "Verify input data shapes match expected model dimensions".to_string(),
96 "Consider using broadcasting operations if appropriate".to_string(),
97 ],
98
99 ErrorKind::OutOfMemory {
100 required,
101 available: _available,
102 } => vec![
103 "Try reducing batch size".to_string(),
104 "Enable gradient checkpointing to trade compute for memory".to_string(),
105 "Use mixed precision training (fp16/bf16) to reduce memory usage".to_string(),
106 format!(
107 "Consider using model parallelism if model requires >{}GB",
108 required / 1_000_000_000
109 ),
110 ],
111
112 ErrorKind::InvalidConfiguration { field, reason } => vec![
113 format!("Check the '{}' field in your configuration", field),
114 format!("Reason: {}", reason),
115 "Refer to the model's configuration documentation".to_string(),
116 "Use Model::from_pretrained() for validated configurations".to_string(),
117 ],
118
119 ErrorKind::ModelNotFound { name } => vec![
120 format!("Verify the model name '{}' is correct", name),
121 "Check available models with Model::list_available()".to_string(),
122 "Ensure you have internet connectivity for downloading".to_string(),
123 "Try specifying a revision if the model was recently updated".to_string(),
124 ],
125
126 ErrorKind::QuantizationError { reason } => vec![
127 "Ensure the model supports the requested quantization type".to_string(),
128 format!("Issue: {}", reason),
129 "Try a different quantization method (int8, int4, gptq, awq)".to_string(),
130 "Check if calibration data is required for this quantization".to_string(),
131 ],
132
133 ErrorKind::DeviceError { device, reason } => vec![
134 format!("Check that {} is available and properly configured", device),
135 format!("Error: {}", reason),
136 "Try running on CPU as a fallback".to_string(),
137 "Verify driver installation and versions".to_string(),
138 ],
139
140 ErrorKind::SerializationError { format, reason } => vec![
141 format!("Check the {} file format", format),
142 format!("Issue: {}", reason),
143 "Ensure the file is not corrupted".to_string(),
144 "Try converting to a different format".to_string(),
145 ],
146
147 ErrorKind::ComputeError { operation, reason } => vec![
148 format!("The {} operation failed: {}", operation, reason),
149 "Check for numerical instability (NaN/Inf values)".to_string(),
150 "Try using different precision (fp32 instead of fp16)".to_string(),
151 "Enable debug mode for detailed tensor information".to_string(),
152 ],
153
154 ErrorKind::TensorOpError { operation, reason } => vec![
155 format!("Tensor operation '{}' failed: {}", operation, reason),
156 "Check tensor dimensions and compatibility".to_string(),
157 "Verify data types are compatible".to_string(),
158 "Enable tensor debugging to see intermediate values".to_string(),
159 ],
160
161 ErrorKind::MemoryError { reason } => vec![
162 format!("Memory operation failed: {}", reason),
163 "Try reducing memory usage by clearing unused tensors".to_string(),
164 "Enable memory optimization settings".to_string(),
165 "Consider using CPU offloading for large tensors".to_string(),
166 ],
167
168 ErrorKind::HardwareError { device, reason } => vec![
169 format!("Hardware error on {}: {}", device, reason),
170 "Check device drivers and installation".to_string(),
171 "Verify hardware is properly connected".to_string(),
172 "Try falling back to CPU execution".to_string(),
173 ],
174
175 ErrorKind::PerformanceError { reason } => vec![
176 format!("Performance issue: {}", reason),
177 "Try optimizing batch size or model parameters".to_string(),
178 "Enable performance profiling to identify bottlenecks".to_string(),
179 "Consider using more efficient operations".to_string(),
180 ],
181
182 ErrorKind::InvalidInput { reason } => vec![
183 format!("Invalid input: {}", reason),
184 "Check input data format and types".to_string(),
185 "Verify input shapes match model expectations".to_string(),
186 "Ensure input data is properly preprocessed".to_string(),
187 ],
188
189 ErrorKind::RuntimeError { reason } => vec![
190 format!("Runtime error: {}", reason),
191 "Check system resources and dependencies".to_string(),
192 "Verify configuration settings".to_string(),
193 "Try restarting the operation".to_string(),
194 ],
195
196 ErrorKind::ResourceExhausted { resource, reason } => vec![
197 format!("Resource '{}' exhausted: {}", resource, reason),
198 "Reduce resource usage by optimizing operations".to_string(),
199 "Consider using resource pooling or management".to_string(),
200 "Check system resource limits".to_string(),
201 ],
202
203 ErrorKind::TimeoutError {
204 operation,
205 timeout_ms,
206 } => vec![
207 format!("Operation '{}' timed out after {}ms", operation, timeout_ms),
208 "Increase timeout duration if operation is expected to take longer".to_string(),
209 "Optimize the operation for better performance".to_string(),
210 "Check for deadlocks or infinite loops".to_string(),
211 ],
212
213 ErrorKind::FileNotFound { path } => vec![
214 format!("File not found: {}", path),
215 "Check that the file path is correct".to_string(),
216 "Verify file permissions".to_string(),
217 "Ensure the file exists in the expected location".to_string(),
218 ],
219
220 ErrorKind::InvalidFormat { expected, actual } => vec![
221 format!("Invalid format: expected {}, got {}", expected, actual),
222 "Check the file format and conversion requirements".to_string(),
223 "Verify the data is in the expected format".to_string(),
224 "Try using format conversion utilities".to_string(),
225 ],
226
227 ErrorKind::UnsupportedOperation { operation, target } => vec![
228 format!("Operation '{}' not supported on {}", operation, target),
229 "Check if the operation is available for this target".to_string(),
230 "Try using an alternative operation or target".to_string(),
231 "Verify feature compatibility".to_string(),
232 ],
233
234 ErrorKind::NotImplemented { feature } => vec![
235 format!("Feature '{}' is not yet implemented", feature),
236 "Check the roadmap for planned features".to_string(),
237 "Consider using alternative approaches".to_string(),
238 "Submit a feature request if needed".to_string(),
239 ],
240
241 ErrorKind::AutodiffError { reason } => vec![
242 format!("Automatic differentiation failed: {}", reason),
243 "Check that all operations support gradient computation".to_string(),
244 "Verify the computational graph is correctly built".to_string(),
245 "Enable gradient checking to validate gradients".to_string(),
246 ],
247
248 _ => vec!["Check the error details and context for more information".to_string()],
249 }
250 }
251
252 pub fn hardware_error(message: &str, operation: &str) -> Self {
254 TrustformersError::new(ErrorKind::HardwareError {
255 device: "unknown".to_string(),
256 reason: message.to_string(),
257 })
258 .with_operation(operation)
259 }
260
261 pub fn tensor_op_error(message: &str, operation: &str) -> Self {
262 TrustformersError::new(ErrorKind::TensorOpError {
263 operation: operation.to_string(),
264 reason: message.to_string(),
265 })
266 .with_operation(operation)
267 }
268
269 pub fn autodiff_error(message: String) -> Self {
270 TrustformersError::new(ErrorKind::AutodiffError { reason: message })
271 }
272
273 pub fn invalid_input(message: String) -> Self {
274 TrustformersError::new(ErrorKind::InvalidInput { reason: message })
275 }
276
277 pub fn config_error(message: &str, field: &str) -> Self {
278 TrustformersError::new(ErrorKind::InvalidConfiguration {
279 field: field.to_string(),
280 reason: message.to_string(),
281 })
282 }
283
284 pub fn invalid_config(message: String) -> Self {
285 TrustformersError::new(ErrorKind::InvalidConfiguration {
286 field: "config".to_string(),
287 reason: message,
288 })
289 }
290
291 pub fn model_error(message: String) -> Self {
292 TrustformersError::new(ErrorKind::ModelNotFound { name: message })
293 }
294
295 pub fn weight_load_error(message: String) -> Self {
296 TrustformersError::new(ErrorKind::WeightLoadingError { reason: message })
297 }
298
299 pub fn runtime_error(message: String) -> Self {
300 TrustformersError::new(ErrorKind::RuntimeError { reason: message })
301 }
302
303 pub fn io_error(message: String) -> Self {
304 TrustformersError::new(ErrorKind::IoError(std::io::Error::other(message)))
305 }
306
307 pub fn shape_error(message: String) -> Self {
308 TrustformersError::new(ErrorKind::ShapeError { reason: message })
309 }
310
311 pub fn safe_tensors_error(message: String) -> Self {
312 TrustformersError::new(ErrorKind::SafeTensorsError { reason: message })
313 }
314
315 pub fn dimension_mismatch(expected: String, actual: String) -> Self {
316 TrustformersError::new(ErrorKind::DimensionMismatch { expected, actual })
317 }
318
319 pub fn invalid_format(expected: String, actual: String) -> Self {
320 TrustformersError::new(ErrorKind::InvalidFormat { expected, actual })
321 }
322
323 pub fn invalid_format_simple(message: String) -> Self {
324 TrustformersError::new(ErrorKind::InvalidFormat {
325 expected: "valid format".to_string(),
326 actual: message,
327 })
328 }
329
330 pub fn not_implemented(feature: String) -> Self {
331 TrustformersError::new(ErrorKind::NotImplemented { feature })
332 }
333
334 pub fn invalid_input_simple(reason: String) -> Self {
335 TrustformersError::new(ErrorKind::InvalidInput { reason })
336 }
337
338 pub fn invalid_state(reason: String) -> Self {
339 TrustformersError::new(ErrorKind::InvalidState { reason })
340 }
341
342 pub fn invalid_operation(message: String) -> Self {
343 TrustformersError::new(ErrorKind::InvalidInput { reason: message })
344 }
345
346 pub fn other(message: String) -> Self {
347 TrustformersError::new(ErrorKind::Other(message))
348 }
349
350 pub fn resource_exhausted(message: String) -> Self {
351 TrustformersError::new(ErrorKind::ResourceExhausted {
352 resource: "memory".to_string(),
353 reason: message,
354 })
355 }
356
357 pub fn lock_error(message: String) -> Self {
358 TrustformersError::new(ErrorKind::Other(format!("Lock error: {}", message)))
359 }
360
361 pub fn serialization_error(message: String) -> Self {
362 TrustformersError::new(ErrorKind::SerializationError {
363 format: "unknown".to_string(),
364 reason: message,
365 })
366 }
367
368 pub fn plugin_error(message: String) -> Self {
369 TrustformersError::new(ErrorKind::Other(format!("Plugin error: {}", message)))
370 }
371
372 pub fn quantization_error(message: String) -> Self {
373 TrustformersError::new(ErrorKind::Other(format!("Quantization error: {}", message)))
374 }
375
376 pub fn invalid_argument(message: String) -> Self {
377 TrustformersError::new(ErrorKind::InvalidInput { reason: message })
378 }
379
380 pub fn file_not_found(message: String) -> Self {
381 TrustformersError::new(ErrorKind::FileNotFound { path: message })
382 }
383}
384
385impl fmt::Display for TrustformersError {
386 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 writeln!(f, "\nā Error [{}]", self.code)?;
389 writeln!(f, "{}", "ā".repeat(60))?;
390
391 writeln!(f, "š {}", self.kind)?;
393
394 if self.context.has_info() {
396 writeln!(f, "\nš Context:")?;
397 if let Some(op) = &self.context.operation {
398 writeln!(f, " Operation: {}", op)?;
399 }
400 if let Some(comp) = &self.context.component {
401 writeln!(f, " Component: {}", comp)?;
402 }
403 for (key, value) in &self.context.info {
404 writeln!(f, " {}: {}", key, value)?;
405 }
406 }
407
408 if !self.suggestions.is_empty() {
410 writeln!(f, "\nš” Suggestions:")?;
411 for (i, suggestion) in self.suggestions.iter().enumerate() {
412 writeln!(f, " {}. {}", i + 1, suggestion)?;
413 }
414 }
415
416 writeln!(
418 f,
419 "\nš For more information, see: https://docs.trustformers.ai/errors/{}",
420 self.code
421 )?;
422 writeln!(f, "{}", "ā".repeat(60))?;
423
424 Ok(())
425 }
426}
427
/// The category of a failure together with its category-specific details.
///
/// Each variant's `#[error(...)]` attribute defines its `Display` message.
#[derive(Debug, Error)]
pub enum ErrorKind {
    /// Two dimension descriptions (free-form text) disagree.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: String, actual: String },

    /// Two shapes, given as dimension lists, disagree.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
    },

    /// Not enough memory; both sizes are reported in bytes.
    #[error("Out of memory: required {required} bytes, available {available} bytes")]
    OutOfMemory { required: usize, available: usize },

    /// A configuration field is invalid.
    #[error("Invalid configuration: field '{field}' {reason}")]
    InvalidConfiguration { field: String, reason: String },

    /// No model with the given name was found.
    #[error("Model not found: '{name}'")]
    ModelNotFound { name: String },

    /// Loading weights failed.
    #[error("Weight loading failed: {reason}")]
    WeightLoadingError { reason: String },

    /// Tokenization failed.
    #[error("Tokenization error: {reason}")]
    TokenizationError { reason: String },

    /// Quantization failed.
    #[error("Quantization error: {reason}")]
    QuantizationError { reason: String },

    /// A problem with a named device.
    #[error("Device error on {device}: {reason}")]
    DeviceError { device: String, reason: String },

    /// (De)serialization failed for the named format.
    #[error("Serialization error for {format}: {reason}")]
    SerializationError { format: String, reason: String },

    /// A compute operation failed.
    #[error("Compute error in {operation}: {reason}")]
    ComputeError { operation: String, reason: String },

    /// Training failed.
    #[error("Training error: {reason}")]
    TrainingError { reason: String },

    /// A pipeline failed.
    #[error("Pipeline error: {reason}")]
    PipelineError { reason: String },

    /// An attention computation failed.
    #[error("Attention error: {reason}")]
    AttentionError { reason: String },

    /// Optimization failed.
    #[error("Optimization error: {reason}")]
    OptimizationError { reason: String },

    /// Automatic differentiation failed.
    #[error("Autodiff error: {reason}")]
    AutodiffError { reason: String },

    /// A named tensor operation failed.
    #[error("Tensor operation error: {operation} failed with {reason}")]
    TensorOpError { operation: String, reason: String },

    /// A memory operation failed.
    #[error("Memory allocation error: {reason}")]
    MemoryError { reason: String },

    /// A hardware fault on the named device.
    #[error("Hardware error: {device} - {reason}")]
    HardwareError { device: String, reason: String },

    /// A performance problem was detected.
    #[error("Performance error: {reason}")]
    PerformanceError { reason: String },

    /// The caller supplied invalid input.
    #[error("Invalid input: {reason}")]
    InvalidInput { reason: String },

    /// Image processing failed.
    #[error("Image processing error: {reason}")]
    ImageProcessingError { reason: String },

    /// A generic runtime failure.
    #[error("Runtime error: {reason}")]
    RuntimeError { reason: String },

    /// A named resource ran out.
    #[error("Resource exhausted: {resource} - {reason}")]
    ResourceExhausted { resource: String, reason: String },

    /// A named plugin failed.
    #[error("Plugin error: {plugin} - {reason}")]
    PluginError { plugin: String, reason: String },

    /// An operation exceeded its time limit, given in milliseconds.
    #[error("Timeout error: operation '{operation}' exceeded {timeout_ms}ms")]
    TimeoutError { operation: String, timeout_ms: u64 },

    /// A network failure.
    #[error("Network error: {reason}")]
    NetworkError { reason: String },

    /// The file at `path` does not exist.
    #[error("File not found: {path}")]
    FileNotFound { path: String },

    /// Data was not in the expected format.
    #[error("Invalid format: expected {expected}, got {actual}")]
    InvalidFormat { expected: String, actual: String },

    /// Something was in an invalid state.
    #[error("Invalid state: {reason}")]
    InvalidState { reason: String },

    /// The operation is not supported on the given target.
    #[error("Unsupported operation: {operation} on {target}")]
    UnsupportedOperation { operation: String, target: String },

    /// Wraps a `std::io::Error`; `#[from]` also provides the `From` conversion.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// The named feature is not implemented.
    #[error("Not implemented: {feature}")]
    NotImplemented { feature: String },

    /// A shape problem described in free-form text.
    #[error("Shape error: {reason}")]
    ShapeError { reason: String },

    /// A SafeTensors-related failure.
    #[error("SafeTensors error: {reason}")]
    SafeTensorsError { reason: String },

    /// Any failure without a dedicated variant.
    #[error("Other error: {0}")]
    Other(String),
}
542
/// Supplementary details describing where and how an error occurred.
#[derive(Debug, Default)]
pub struct ErrorContext {
    /// The operation that was running, if recorded.
    pub operation: Option<String>,

    /// The component involved, if recorded.
    pub component: Option<String>,

    /// Extra key/value details, kept in insertion order.
    pub info: Vec<(String, String)>,
}

impl ErrorContext {
    /// Appends a `key`/`value` pair; existing entries with the same key are kept.
    pub fn add(&mut self, key: &str, value: String) {
        let entry = (key.to_owned(), value);
        self.info.push(entry);
    }

    /// Returns `true` when an operation, a component or at least one
    /// key/value pair has been recorded.
    pub fn has_info(&self) -> bool {
        let nothing_recorded =
            self.operation.is_none() && self.component.is_none() && self.info.is_empty();
        !nothing_recorded
    }
}
567
/// Stable identifier attached to every error; see `ErrorCode::from_kind` for
/// how kinds map to codes.
///
/// Derives `PartialEq`, `Eq` and `Hash` so codes can be compared and used as
/// map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorCode {
    /// `ErrorKind::DimensionMismatch`
    E0001,
    /// `ErrorKind::ShapeMismatch`
    E0002,
    /// `ErrorKind::OutOfMemory`
    E0003,
    /// `ErrorKind::InvalidConfiguration`
    E0004,
    /// `ErrorKind::ModelNotFound`
    E0005,
    /// `ErrorKind::WeightLoadingError`
    E0006,
    /// `ErrorKind::TokenizationError`
    E0007,
    /// `ErrorKind::QuantizationError`
    E0008,
    /// `ErrorKind::DeviceError`
    E0009,
    /// `ErrorKind::SerializationError`
    E0010,
    /// `ErrorKind::ComputeError`
    E0011,
    /// `ErrorKind::TrainingError`
    E0012,
    /// `ErrorKind::PipelineError`
    E0013,
    /// `ErrorKind::AttentionError`
    E0014,
    /// `ErrorKind::OptimizationError`
    E0015,
    /// `ErrorKind::TensorOpError`
    E0016,
    /// `ErrorKind::MemoryError`
    E0017,
    /// `ErrorKind::HardwareError`
    E0018,
    /// `ErrorKind::PerformanceError`
    E0019,
    /// `ErrorKind::InvalidInput`
    E0020,
    /// `ErrorKind::ImageProcessingError`
    E0021,
    /// `ErrorKind::RuntimeError`
    E0022,
    /// `ErrorKind::ResourceExhausted`
    E0023,
    /// `ErrorKind::PluginError`
    E0024,
    /// `ErrorKind::TimeoutError`
    E0025,
    /// `ErrorKind::NetworkError`
    E0026,
    /// `ErrorKind::FileNotFound`
    E0027,
    /// `ErrorKind::InvalidFormat`
    E0028,
    /// `ErrorKind::InvalidState`
    E0029,
    /// `ErrorKind::UnsupportedOperation`
    E0030,
    /// `ErrorKind::IoError`
    E0031,
    /// `ErrorKind::NotImplemented`
    E0032,
    /// `ErrorKind::AutodiffError`
    E0033,
    /// Fallback for kinds without a dedicated code.
    E9999,
}
606
607impl ErrorCode {
608 pub fn from_kind(kind: &ErrorKind) -> Self {
610 match kind {
611 ErrorKind::DimensionMismatch { .. } => ErrorCode::E0001,
612 ErrorKind::ShapeMismatch { .. } => ErrorCode::E0002,
613 ErrorKind::OutOfMemory { .. } => ErrorCode::E0003,
614 ErrorKind::InvalidConfiguration { .. } => ErrorCode::E0004,
615 ErrorKind::ModelNotFound { .. } => ErrorCode::E0005,
616 ErrorKind::WeightLoadingError { .. } => ErrorCode::E0006,
617 ErrorKind::TokenizationError { .. } => ErrorCode::E0007,
618 ErrorKind::QuantizationError { .. } => ErrorCode::E0008,
619 ErrorKind::DeviceError { .. } => ErrorCode::E0009,
620 ErrorKind::SerializationError { .. } => ErrorCode::E0010,
621 ErrorKind::ComputeError { .. } => ErrorCode::E0011,
622 ErrorKind::TrainingError { .. } => ErrorCode::E0012,
623 ErrorKind::PipelineError { .. } => ErrorCode::E0013,
624 ErrorKind::AttentionError { .. } => ErrorCode::E0014,
625 ErrorKind::OptimizationError { .. } => ErrorCode::E0015,
626 ErrorKind::AutodiffError { .. } => ErrorCode::E0033,
627 ErrorKind::TensorOpError { .. } => ErrorCode::E0016,
628 ErrorKind::MemoryError { .. } => ErrorCode::E0017,
629 ErrorKind::HardwareError { .. } => ErrorCode::E0018,
630 ErrorKind::PerformanceError { .. } => ErrorCode::E0019,
631 ErrorKind::InvalidInput { .. } => ErrorCode::E0020,
632 ErrorKind::ImageProcessingError { .. } => ErrorCode::E0021,
633 ErrorKind::RuntimeError { .. } => ErrorCode::E0022,
634 ErrorKind::ResourceExhausted { .. } => ErrorCode::E0023,
635 ErrorKind::PluginError { .. } => ErrorCode::E0024,
636 ErrorKind::TimeoutError { .. } => ErrorCode::E0025,
637 ErrorKind::NetworkError { .. } => ErrorCode::E0026,
638 ErrorKind::FileNotFound { .. } => ErrorCode::E0027,
639 ErrorKind::InvalidFormat { .. } => ErrorCode::E0028,
640 ErrorKind::InvalidState { .. } => ErrorCode::E0029,
641 ErrorKind::UnsupportedOperation { .. } => ErrorCode::E0030,
642 ErrorKind::IoError { .. } => ErrorCode::E0031,
643 ErrorKind::NotImplemented { .. } => ErrorCode::E0032,
644 _ => ErrorCode::E9999,
645 }
646 }
647}
648
649impl fmt::Display for ErrorCode {
650 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
651 write!(f, "{:?}", self)
652 }
653}
654
/// Shorthand for `std::result::Result` with [`TrustformersError`] as the error type.
pub type Result<T> = std::result::Result<T, TrustformersError>;
657
/// Builds a `TrustformersError` from a kind, optionally attaching an
/// operation and/or a component.
///
/// Accepted forms: `tf_error!(kind)`, `tf_error!(kind, operation = op)`,
/// `tf_error!(kind, component = comp)` and
/// `tf_error!(kind, operation = op, component = comp)`.
#[macro_export]
macro_rules! tf_error {
    // Base form: just the kind.
    ($kind:expr) => {
        $crate::errors::TrustformersError::new($kind)
    };

    // The remaining forms extend the base form with builder calls.
    ($kind:expr, operation = $op:expr) => {
        $crate::tf_error!($kind).with_operation($op)
    };

    ($kind:expr, component = $comp:expr) => {
        $crate::tf_error!($kind).with_component($comp)
    };

    ($kind:expr, operation = $op:expr, component = $comp:expr) => {
        $crate::tf_error!($kind)
            .with_operation($op)
            .with_component($comp)
    };
}
679
/// Attaches one or more `key => value` pairs to an error through
/// `with_context`, converting each value with `to_string()`.
///
/// Pairs are applied left to right; a trailing comma is accepted.
#[macro_export]
macro_rules! tf_context {
    // Last (or only) pair.
    ($err:expr, $key:expr => $value:expr $(,)?) => {
        $err.with_context($key, $value.to_string())
    };

    // Apply the first pair, then recurse on the rest. The recursive call goes
    // through `$crate` so it resolves even when the macro is invoked by path
    // from another crate without being imported there.
    ($err:expr, $key:expr => $value:expr, $($rest_key:expr => $rest_value:expr),+ $(,)?) => {
        $crate::tf_context!(
            $err.with_context($key, $value.to_string()),
            $($rest_key => $rest_value),+
        )
    };
}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// The rendered error contains the code, the operation, the component and
    /// the attached key/value context.
    #[test]
    fn test_error_display() {
        let kind = ErrorKind::DimensionMismatch {
            expected: "[batch_size, 512, 768]".to_string(),
            actual: "[batch_size, 256, 768]".to_string(),
        };
        let error = TrustformersError::new(kind)
            .with_operation("MultiHeadAttention.forward")
            .with_component("BERT")
            .with_context("layer", "12".to_string())
            .with_context("head_count", "12".to_string());

        let rendered = error.to_string();
        assert!(rendered.contains("Error [E0001]"));
        assert!(rendered.contains("MultiHeadAttention.forward"));
        assert!(rendered.contains("BERT"));
        assert!(rendered.contains("layer: 12"));
    }

    /// Out-of-memory errors come with the batch-size and mixed-precision hints.
    #[test]
    fn test_error_suggestions() {
        let kind = ErrorKind::OutOfMemory {
            required: 8_000_000_000,
            available: 4_000_000_000,
        };
        let error = TrustformersError::new(kind);
        let mentions = |needle: &str| error.suggestions.iter().any(|s| s.contains(needle));

        assert!(!error.suggestions.is_empty());
        assert!(mentions("batch size"));
        assert!(mentions("mixed precision"));
    }

    /// `tf_error!` with both options fills in the operation and the component.
    #[test]
    fn test_error_macros() {
        let error = tf_error!(
            ErrorKind::ModelNotFound {
                name: "gpt-5".to_string()
            },
            operation = "Model::from_pretrained",
            component = "ModelLoader"
        );

        let expected_operation = Some("Model::from_pretrained".to_string());
        let expected_component = Some("ModelLoader".to_string());
        assert_eq!(error.context.operation, expected_operation);
        assert_eq!(error.context.component, expected_component);
    }
}