Skip to main content

tensorlogic_trustformers/
error.rs

1//! Error types for tensorlogic-trustformers.
2
3use std::fmt;
4
/// Errors that can occur in transformer operations.
///
/// Every payload is a `usize`, a `String`, or a `Vec<usize>`, so equality is total and the
/// enum derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustformerError {
    /// Invalid dimension configuration
    InvalidDimension {
        /// Size that was required.
        expected: usize,
        /// Size that was actually supplied.
        got: usize,
        /// Free-form description of where the mismatch was detected.
        context: String,
    },
    /// Head count doesn't divide model dimension evenly
    InvalidHeadCount {
        /// Model dimension that was requested.
        d_model: usize,
        /// Number of attention heads that was requested.
        n_heads: usize,
    },
    /// Invalid attention mask shape
    InvalidMaskShape {
        /// Shape that was required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Missing required parameter (payload identifies the parameter)
    MissingParameter(String),
    /// Compilation error when building einsum graph (payload is the error message)
    CompilationError(String),
    /// Error loading checkpoint file (payload is the error message)
    CheckpointLoadError(String),
}
28
29impl fmt::Display for TrustformerError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            Self::InvalidDimension {
33                expected,
34                got,
35                context,
36            } => write!(
37                f,
38                "Invalid dimension in {}: expected {}, got {}",
39                context, expected, got
40            ),
41            Self::InvalidHeadCount { d_model, n_heads } => write!(
42                f,
43                "d_model ({}) must be divisible by n_heads ({})",
44                d_model, n_heads
45            ),
46            Self::InvalidMaskShape { expected, got } => write!(
47                f,
48                "Invalid mask shape: expected {:?}, got {:?}",
49                expected, got
50            ),
51            Self::MissingParameter(param) => write!(f, "Missing required parameter: {}", param),
52            Self::CompilationError(msg) => write!(f, "Compilation error: {}", msg),
53            Self::CheckpointLoadError(msg) => write!(f, "Checkpoint load error: {}", msg),
54        }
55    }
56}
57
58impl std::error::Error for TrustformerError {}
59
60/// Convert IrError to TrustformerError (for ? operator)
61impl From<tensorlogic_ir::IrError> for TrustformerError {
62    fn from(err: tensorlogic_ir::IrError) -> Self {
63        TrustformerError::CompilationError(err.to_string())
64    }
65}
66
67/// Result type for transformer operations
68pub type Result<T> = std::result::Result<T, TrustformerError>;
69
#[cfg(test)]
mod tests {
    use super::*;

    /// The dimension error reports the context and both sizes, in that order.
    #[test]
    fn test_error_display() {
        let err = TrustformerError::InvalidDimension {
            expected: 512,
            got: 256,
            context: "attention".to_string(),
        };
        // Exact match: substring checks on "512"/"256" alone would not notice a swapped
        // expected/got pair or a dropped context.
        assert_eq!(
            err.to_string(),
            "Invalid dimension in attention: expected 512, got 256"
        );
    }

    /// The head-count error names both `d_model` and `n_heads` with their values.
    #[test]
    fn test_invalid_head_count() {
        let err = TrustformerError::InvalidHeadCount {
            d_model: 512,
            n_heads: 7,
        };
        assert_eq!(
            err.to_string(),
            "d_model (512) must be divisible by n_heads (7)"
        );
    }

    /// Mask shapes are rendered as bracketed lists.
    #[test]
    fn test_invalid_mask_shape_display() {
        let err = TrustformerError::InvalidMaskShape {
            expected: vec![2, 8, 8],
            got: vec![2, 8],
        };
        assert_eq!(
            err.to_string(),
            "Invalid mask shape: expected [2, 8, 8], got [2, 8]"
        );
    }

    /// Each single-string variant prints its own label followed by the payload.
    #[test]
    fn test_message_variants_display() {
        assert_eq!(
            TrustformerError::MissingParameter("w_q".to_string()).to_string(),
            "Missing required parameter: w_q"
        );
        assert_eq!(
            TrustformerError::CompilationError("bad spec".to_string()).to_string(),
            "Compilation error: bad spec"
        );
        assert_eq!(
            TrustformerError::CheckpointLoadError("no file".to_string()).to_string(),
            "Checkpoint load error: no file"
        );
    }

    /// Clones compare equal; different variants with the same payload do not.
    #[test]
    fn test_clone_and_equality() {
        let err = TrustformerError::MissingParameter("w_q".to_string());
        assert_eq!(err.clone(), err);
        assert_ne!(err, TrustformerError::CompilationError("w_q".to_string()));
    }
}