ruvector_attention/
error.rs

//! Error types for the ruvector-attention crate.
//!
//! This module defines all error types that can occur during attention computation,
//! configuration, and training operations.

6use thiserror::Error;
7
8/// Errors that can occur during attention operations.
9#[derive(Error, Debug, Clone)]
10pub enum AttentionError {
11    /// Dimension mismatch between query, key, or value tensors.
12    #[error("Dimension mismatch: expected {expected}, got {actual}")]
13    DimensionMismatch {
14        /// Expected dimension size
15        expected: usize,
16        /// Actual dimension size
17        actual: usize,
18    },
19
20    /// Invalid configuration parameter.
21    #[error("Invalid configuration: {0}")]
22    InvalidConfig(String),
23
24    /// Error during attention computation.
25    #[error("Computation error: {0}")]
26    ComputationError(String),
27
28    /// Memory allocation failure.
29    #[error("Memory allocation failed: {0}")]
30    MemoryError(String),
31
32    /// Invalid head configuration for multi-head attention.
33    #[error("Invalid head count: dimension {dim} not divisible by {num_heads} heads")]
34    InvalidHeadCount {
35        /// Model dimension
36        dim: usize,
37        /// Number of attention heads
38        num_heads: usize,
39    },
40
41    /// Empty input provided.
42    #[error("Empty input: {0}")]
43    EmptyInput(String),
44
45    /// Invalid edge configuration for graph attention.
46    #[error("Invalid edge configuration: {0}")]
47    InvalidEdges(String),
48
49    /// Numerical instability detected.
50    #[error("Numerical instability: {0}")]
51    NumericalInstability(String),
52
53    /// Invalid mask dimensions.
54    #[error("Invalid mask dimensions: expected {expected}, got {actual}")]
55    InvalidMask {
56        /// Expected mask dimensions
57        expected: String,
58        /// Actual mask dimensions
59        actual: String,
60    },
61}
62
63/// Result type for attention operations.
64pub type AttentionResult<T> = Result<T, AttentionError>;
65
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = AttentionError::DimensionMismatch {
            expected: 512,
            actual: 256,
        };
        assert_eq!(err.to_string(), "Dimension mismatch: expected 512, got 256");

        let err = AttentionError::InvalidConfig("dropout must be in [0, 1]".to_string());
        assert_eq!(
            err.to_string(),
            "Invalid configuration: dropout must be in [0, 1]"
        );
    }

    #[test]
    fn test_error_clone() {
        let err = AttentionError::ComputationError("test".to_string());
        let cloned = err.clone();
        assert_eq!(err.to_string(), cloned.to_string());
    }

    /// Pins the messages of the remaining struct-like variants, whose format
    /// strings interpolate named fields and are the easiest to break silently.
    #[test]
    fn test_struct_variant_display() {
        let err = AttentionError::InvalidHeadCount {
            dim: 10,
            num_heads: 3,
        };
        assert_eq!(
            err.to_string(),
            "Invalid head count: dimension 10 not divisible by 3 heads"
        );

        let err = AttentionError::InvalidMask {
            expected: "[4, 4]".to_string(),
            actual: "[4, 3]".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid mask dimensions: expected [4, 4], got [4, 3]"
        );
    }

    /// Pins the prefix each single-message variant prepends to its payload.
    #[test]
    fn test_message_variant_display() {
        let cases = [
            (
                AttentionError::ComputationError("nan in logits".to_string()),
                "Computation error: nan in logits",
            ),
            (
                AttentionError::MemoryError("out of memory".to_string()),
                "Memory allocation failed: out of memory",
            ),
            (
                AttentionError::EmptyInput("keys".to_string()),
                "Empty input: keys",
            ),
            (
                AttentionError::InvalidEdges("edge (3, 7) out of range".to_string()),
                "Invalid edge configuration: edge (3, 7) out of range",
            ),
            (
                AttentionError::NumericalInstability("softmax overflow".to_string()),
                "Numerical instability: softmax overflow",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}