// scirs2_neural/attention/types.rs

//! Common types for attention mechanisms.
//!
//! This module provides the shared configuration, masking, output, and
//! position-encoding types used across Flash Attention, RoPE, ALiBi,
//! sliding-window, and other attention variants in this crate.

/// Configuration for scaled dot-product attention layers.
///
/// # Defaults
///
/// ```rust
/// use scirs2_neural::attention::types::AttentionConfig;
/// let cfg = AttentionConfig::default();
/// assert_eq!(cfg.num_heads, 8);
/// assert_eq!(cfg.head_dim, 64);
/// assert_eq!(cfg.dropout_prob, 0.0);
/// assert!(!cfg.causal);
/// assert!(cfg.use_flash);
/// assert!(cfg.scale.is_none());
/// ```
///
/// # Construction outside this crate
///
/// The struct is `#[non_exhaustive]`, so downstream crates cannot build it
/// with a struct literal. This includes the `..Default::default()` update
/// syntax. Start from `AttentionConfig::default()` and assign the public
/// fields instead:
///
/// ```rust
/// use scirs2_neural::attention::types::AttentionConfig;
/// let mut cfg = AttentionConfig::default();
/// cfg.causal = true;
/// cfg.scale = Some(0.125);
/// assert!(cfg.causal);
/// assert_eq!(cfg.effective_scale(), 0.125);
/// ```
#[non_exhaustive]
// `PartialEq` (not `Eq`) because `dropout_prob` and `scale` are `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,

    /// Dimensionality of each attention head.
    pub head_dim: usize,

    /// Dropout probability applied to attention weights (0.0 = disabled).
    pub dropout_prob: f64,

    /// Whether to use causal (autoregressive) masking.
    pub causal: bool,

    /// Whether to use Flash Attention 2 for memory-efficient computation.
    pub use_flash: bool,

    /// Optional custom attention scale factor. When `None`, the standard
    /// `1 / sqrt(head_dim)` is used.
    pub scale: Option<f64>,
}

44impl Default for AttentionConfig {
45    fn default() -> Self {
46        Self {
47            num_heads: 8,
48            head_dim: 64,
49            dropout_prob: 0.0,
50            causal: false,
51            use_flash: true,
52            scale: None,
53        }
54    }
55}
56
57impl AttentionConfig {
58    /// Compute the effective scale factor.
59    ///
60    /// Returns `scale` if explicitly set, otherwise `1 / sqrt(head_dim)`.
61    pub fn effective_scale(&self) -> f64 {
62        self.scale
63            .unwrap_or_else(|| 1.0 / (self.head_dim as f64).sqrt())
64    }
65}

// ---------------------------------------------------------------------------
// AttentionMask
// ---------------------------------------------------------------------------

/// Attention mask variants controlling which query–key pairs are visible.
///
/// The `#[non_exhaustive]` attribute makes adding new variants
/// backwards-compatible. As a consequence, code outside this crate that
/// matches on the enum must include a wildcard (`_`) arm.
///
/// The `None` variant shares its name with `Option::None`. Refer to it as
/// `AttentionMask::None` rather than glob-importing the variants.
#[non_exhaustive]
// Every payload (`Vec<Vec<bool>>`, `Vec<usize>`) supports `Eq`, so masks
// can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AttentionMask {
    /// No masking: every position can attend to every other position.
    None,

    /// Causal (autoregressive) mask: position `i` may only attend to
    /// positions `j ≤ i`.
    Causal,

    /// Custom boolean mask. `mask[i][j] == true` means query `i` may attend
    /// to key `j`.
    Custom(Vec<Vec<bool>>),

    /// Padding mask expressed as the *valid* sequence length per batch item.
    /// Positions `>= lengths[b]` for batch item `b` are masked.
    PaddingMask(Vec<usize>),
}

// ---------------------------------------------------------------------------
// AttentionOutput
// ---------------------------------------------------------------------------

/// Output of an attention operation.
///
/// `output` always contains the context vectors. `attention_weights` is
/// populated only when the caller requests it (e.g. for visualisation),
/// because materialising the full weight tensor is expensive for long
/// sequences.
// `PartialEq` (not `Eq`) because the payloads are `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct AttentionOutput {
    /// Context vectors with logical shape `[seq_len, embed_dim]`, stored as
    /// a 2-D `Vec<Vec<f64>>`.
    pub output: Vec<Vec<f64>>,

    /// Optional attention weights with shape `[num_heads, seq_len, seq_len]`.
    /// Each element `weights[h][i][j]` is the softmax weight that query `i`
    /// assigns to key `j` under head `h`.
    pub attention_weights: Option<Vec<Vec<Vec<f64>>>>,
}

115impl AttentionOutput {
116    /// Construct an output with no stored weights.
117    pub fn new(output: Vec<Vec<f64>>) -> Self {
118        Self {
119            output,
120            attention_weights: None,
121        }
122    }
123
124    /// Construct an output that includes the full weight tensor.
125    pub fn with_weights(output: Vec<Vec<f64>>, weights: Vec<Vec<Vec<f64>>>) -> Self {
126        Self {
127            output,
128            attention_weights: Some(weights),
129        }
130    }
131}

// ---------------------------------------------------------------------------
// PositionEncoding
// ---------------------------------------------------------------------------

/// Position-encoding strategies supported by attention layers in this crate.
///
/// The enum is `Copy + Eq + Hash`, so it can be used directly as a
/// `HashMap`/`HashSet` key (e.g. to cache per-strategy tables).
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PositionEncoding {
    /// Classic sinusoidal position embeddings (Vaswani et al., 2017).
    Sinusoidal,

    /// Learned absolute position embeddings (BERT, GPT-2 style).
    Learned,

    /// Rotary Position Embedding — RoPE (Su et al., 2021).
    RoPE,

    /// Attention with Linear Biases — ALiBi (Press et al., 2021).
    ALiBi,

    /// No position encoding (NoPE). The layer adds no positional signal;
    /// any position information must come from the input itself or be
    /// supplied externally.
    NoPE,
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_default() {
        let cfg = AttentionConfig::default();
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.head_dim, 64);
        assert_eq!(cfg.dropout_prob, 0.0);
        assert!(!cfg.causal);
        assert!(cfg.use_flash);
        assert!(cfg.scale.is_none());
    }

    #[test]
    fn test_attention_config_effective_scale_default() {
        let cfg = AttentionConfig::default();
        let expected = 1.0 / (64.0_f64).sqrt();
        assert!((cfg.effective_scale() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_attention_config_effective_scale_custom() {
        let cfg = AttentionConfig {
            scale: Some(0.5),
            ..Default::default()
        };
        assert!((cfg.effective_scale() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_attention_config_effective_scale_tracks_head_dim() {
        // Without an explicit scale, the factor must follow `head_dim`:
        // 1 / sqrt(16) = 0.25.
        let cfg = AttentionConfig {
            head_dim: 16,
            ..Default::default()
        };
        assert!((cfg.effective_scale() - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_attention_output_struct() {
        let output = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let ao = AttentionOutput::new(output.clone());
        assert_eq!(ao.output, output);
        assert!(ao.attention_weights.is_none());
    }

    #[test]
    fn test_attention_output_with_weights() {
        let output = vec![vec![0.5; 4]];
        let w = vec![vec![vec![0.25; 4]; 4]; 2];
        let ao = AttentionOutput::with_weights(output.clone(), w.clone());
        // Both parts must be stored unchanged.
        assert_eq!(ao.output, output);
        assert_eq!(ao.attention_weights, Some(w));
    }

    #[test]
    fn test_position_encoding_variants() {
        // Each variant must equal itself and differ from every other variant.
        // This replaces the former tautological `assert_eq!(v, v)`, which
        // tested nothing and triggers clippy's `eq_op` lint.
        let variants = [
            PositionEncoding::Sinusoidal,
            PositionEncoding::Learned,
            PositionEncoding::RoPE,
            PositionEncoding::ALiBi,
            PositionEncoding::NoPE,
        ];
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_attention_mask_causal_variant() {
        let mask = AttentionMask::Causal;
        assert!(matches!(mask, AttentionMask::Causal));
        assert!(!matches!(AttentionMask::None, AttentionMask::Causal));
    }
}