// tensorlogic_trustformers/error.rs
use std::fmt;
4
/// Errors reported by this module.
///
/// Each variant is rendered as a one-line, human-readable message by the type's
/// `Display` implementation. `Eq` is derived alongside `PartialEq` because every
/// field (`usize`, `String`, `Vec<usize>`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustformerError {
    /// A dimension differed from the value that was required.
    InvalidDimension {
        /// The dimension that was required.
        expected: usize,
        /// The dimension that was actually supplied.
        got: usize,
        /// Where the mismatch occurred; shown as "in {context}" in the message.
        context: String,
    },
    /// `d_model` is not divisible by `n_heads`.
    InvalidHeadCount {
        /// The model dimension reported in the message.
        d_model: usize,
        /// The head count reported in the message.
        n_heads: usize,
    },
    /// A mask had a different shape from the one that was required.
    InvalidMaskShape {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },
    /// A required parameter was not provided; the payload is the parameter name.
    MissingParameter(String),
    /// Compilation failed; the payload is the failure message.
    CompilationError(String),
    /// Loading a checkpoint failed; the payload is the failure message.
    CheckpointLoadError(String),
}
28
29impl fmt::Display for TrustformerError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::InvalidDimension {
33 expected,
34 got,
35 context,
36 } => write!(
37 f,
38 "Invalid dimension in {}: expected {}, got {}",
39 context, expected, got
40 ),
41 Self::InvalidHeadCount { d_model, n_heads } => write!(
42 f,
43 "d_model ({}) must be divisible by n_heads ({})",
44 d_model, n_heads
45 ),
46 Self::InvalidMaskShape { expected, got } => write!(
47 f,
48 "Invalid mask shape: expected {:?}, got {:?}",
49 expected, got
50 ),
51 Self::MissingParameter(param) => write!(f, "Missing required parameter: {}", param),
52 Self::CompilationError(msg) => write!(f, "Compilation error: {}", msg),
53 Self::CheckpointLoadError(msg) => write!(f, "Checkpoint load error: {}", msg),
54 }
55 }
56}
57
58impl std::error::Error for TrustformerError {}
59
60impl From<tensorlogic_ir::IrError> for TrustformerError {
62 fn from(err: tensorlogic_ir::IrError) -> Self {
63 TrustformerError::CompilationError(err.to_string())
64 }
65}
66
67pub type Result<T> = std::result::Result<T, TrustformerError>;
69
70#[cfg(test)]
71mod tests {
72 use super::*;
73
74 #[test]
75 fn test_error_display() {
76 let err = TrustformerError::InvalidDimension {
77 expected: 512,
78 got: 256,
79 context: "attention".to_string(),
80 };
81 assert!(err.to_string().contains("512"));
82 assert!(err.to_string().contains("256"));
83 }
84
85 #[test]
86 fn test_invalid_head_count() {
87 let err = TrustformerError::InvalidHeadCount {
88 d_model: 512,
89 n_heads: 7,
90 };
91 assert!(err.to_string().contains("512"));
92 assert!(err.to_string().contains("7"));
93 }
94}