Skip to main content

tensorlogic_trustformers/sparse_attention/
error.rs

//! Error taxonomy for the Longformer-style sparse attention module.
//!
//! These errors are kept local (not merged into [`crate::error::TrustformerError`])
//! so that attention-internal diagnostics do not pollute the public transformer
//! error enum. A [`From`] bridge converts them into the crate-wide type (keeping
//! only the rendered message), so the `?` operator propagates them from
//! higher-level code.
7
8use thiserror::Error;
9
10/// Errors raised during sparse attention mask generation or forward pass.
11#[derive(Debug, Clone, Error, PartialEq)]
12pub enum SparseAttentionError {
13    /// Window size must be strictly positive.
14    #[error("invalid window_size: must be > 0, got {0}")]
15    InvalidWindowSize(usize),
16
17    /// Sequence length must be strictly positive.
18    #[error("invalid sequence length: must be > 0, got {0}")]
19    InvalidSequenceLength(usize),
20
21    /// One or more global token indices exceed the sequence length.
22    #[error("global token index {index} is out of bounds for sequence length {seq_len}")]
23    InvalidGlobalIndices { index: usize, seq_len: usize },
24
25    /// Query, key, or value tensors have incompatible shapes.
26    #[error("dimension mismatch: {context} — expected {expected}, got {got}")]
27    DimensionMismatch {
28        context: String,
29        expected: usize,
30        got: usize,
31    },
32
33    /// A softmax row collapsed to zero mass after masking.
34    #[error("numerical instability: softmax denominator is zero at position {position}")]
35    NumericalInstability { position: usize },
36}
37
38/// Result alias used across the sparse-attention module.
39pub type SparseAttentionResult<T> = Result<T, SparseAttentionError>;
40
41impl From<SparseAttentionError> for crate::error::TrustformerError {
42    fn from(err: SparseAttentionError) -> Self {
43        crate::error::TrustformerError::CompilationError(err.to_string())
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display_contains_context() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        let msg = err.to_string();
        assert!(msg.contains("window_size"));
        // The template already contains "> 0", so `contains("0")` could not
        // detect a missing interpolation; check the rendered value instead.
        assert!(msg.ends_with("got 0"), "unexpected message: {}", msg);
    }

    #[test]
    fn sequence_length_error_message() {
        let msg = SparseAttentionError::InvalidSequenceLength(0).to_string();
        assert!(msg.contains("sequence length"));
        assert!(msg.ends_with("got 0"), "unexpected message: {}", msg);
    }

    #[test]
    fn bridges_into_trustformer_error() {
        let err = SparseAttentionError::InvalidWindowSize(0);
        let bridged: crate::error::TrustformerError = err.into();
        assert!(bridged.to_string().contains("window_size"));
    }

    #[test]
    fn global_index_error_message() {
        let err = SparseAttentionError::InvalidGlobalIndices {
            index: 42,
            seq_len: 16,
        };
        // Pin the full text so the two numbers cannot be swapped unnoticed.
        assert_eq!(
            err.to_string(),
            "global token index 42 is out of bounds for sequence length 16"
        );
    }

    #[test]
    fn dimension_mismatch_message() {
        let err = SparseAttentionError::DimensionMismatch {
            context: "query rows".into(),
            expected: 32,
            got: 16,
        };
        let msg = err.to_string();
        assert!(msg.contains("query rows"));
        assert!(msg.contains("expected 32"));
        assert!(msg.contains("got 16"));
    }

    #[test]
    fn numerical_instability_message() {
        let msg = SparseAttentionError::NumericalInstability { position: 7 }.to_string();
        assert!(msg.contains("softmax denominator"));
        assert!(msg.ends_with("position 7"), "unexpected message: {}", msg);
    }

    #[test]
    fn errors_compare_by_payload() {
        let original = SparseAttentionError::DimensionMismatch {
            context: "keys".into(),
            expected: 4,
            got: 2,
        };
        let copy = original.clone();
        assert_eq!(original, copy);
        assert_ne!(
            original,
            SparseAttentionError::DimensionMismatch {
                context: "keys".into(),
                expected: 4,
                got: 3,
            }
        );
    }
}