tensorlogic_trustformers/sparse_attention/
error.rs1use thiserror::Error;
9
10#[derive(Debug, Clone, Error, PartialEq)]
12pub enum SparseAttentionError {
13 #[error("invalid window_size: must be > 0, got {0}")]
15 InvalidWindowSize(usize),
16
17 #[error("invalid sequence length: must be > 0, got {0}")]
19 InvalidSequenceLength(usize),
20
21 #[error("global token index {index} is out of bounds for sequence length {seq_len}")]
23 InvalidGlobalIndices { index: usize, seq_len: usize },
24
25 #[error("dimension mismatch: {context} — expected {expected}, got {got}")]
27 DimensionMismatch {
28 context: String,
29 expected: usize,
30 got: usize,
31 },
32
33 #[error("numerical instability: softmax denominator is zero at position {position}")]
35 NumericalInstability { position: usize },
36}
37
38pub type SparseAttentionResult<T> = Result<T, SparseAttentionError>;
40
41impl From<SparseAttentionError> for crate::error::TrustformerError {
42 fn from(err: SparseAttentionError) -> Self {
43 crate::error::TrustformerError::CompilationError(err.to_string())
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_contains_context() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        // Pin the whole message: a bare `contains("0")` would also be satisfied
        // by unrelated digits anywhere in the text.
        assert_eq!(err.to_string(), "invalid window_size: must be > 0, got 0");
    }

    #[test]
    fn sequence_length_error_message() {
        let err = SparseAttentionError::InvalidSequenceLength(0);
        assert_eq!(
            err.to_string(),
            "invalid sequence length: must be > 0, got 0"
        );
    }

    #[test]
    fn bridges_into_trustformer_error() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        let expected = err.to_string();
        let bridged: crate::error::TrustformerError = err.into();
        assert!(bridged.to_string().contains("window_size"));
        // The bridge must produce `CompilationError` carrying the rendered message.
        assert!(matches!(
            &bridged,
            crate::error::TrustformerError::CompilationError(msg) if *msg == expected
        ));
    }

    #[test]
    fn global_index_error_message() {
        let err = SparseAttentionError::InvalidGlobalIndices {
            index: 42,
            seq_len: 16,
        };
        assert_eq!(
            err.to_string(),
            "global token index 42 is out of bounds for sequence length 16"
        );
    }

    #[test]
    fn dimension_mismatch_message() {
        let err = SparseAttentionError::DimensionMismatch {
            context: "query rows".into(),
            expected: 32,
            got: 16,
        };
        assert_eq!(
            err.to_string(),
            "dimension mismatch: query rows — expected 32, got 16"
        );
    }

    #[test]
    fn numerical_instability_message() {
        let err = SparseAttentionError::NumericalInstability { position: 3 };
        assert_eq!(
            err.to_string(),
            "numerical instability: softmax denominator is zero at position 3"
        );
    }

    #[test]
    fn clone_preserves_equality() {
        let err = SparseAttentionError::NumericalInstability { position: 3 };
        assert_eq!(err.clone(), err);
        assert_ne!(err, SparseAttentionError::NumericalInstability { position: 4 });
    }
}