scirs2_neural/attention/types.rs
//! Common types for attention mechanisms.
//!
//! This module provides shared configuration and output types used across
//! Flash Attention, RoPE, ALiBi, sliding-window, and other attention
//! variants in this crate.
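//!
//! # Example
//!
//! A minimal sketch of how these types fit together (mutating fields rather
//! than using literal syntax, since `AttentionConfig` is `#[non_exhaustive]`):
//!
//! ```rust
//! use scirs2_neural::attention::types::{AttentionConfig, AttentionMask};
//!
//! let mut cfg = AttentionConfig::default();
//! cfg.causal = true;
//! let _mask = AttentionMask::Causal;
//! // Default head_dim is 64, so the scale is 1 / sqrt(64) = 0.125.
//! assert!((cfg.effective_scale() - 0.125).abs() < 1e-12);
//! ```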

/// Configuration for scaled dot-product attention layers.
///
/// # Defaults
///
/// ```rust
/// use scirs2_neural::attention::types::AttentionConfig;
/// let cfg = AttentionConfig::default();
/// assert_eq!(cfg.num_heads, 8);
/// assert_eq!(cfg.head_dim, 64);
/// assert_eq!(cfg.dropout_prob, 0.0);
/// assert!(!cfg.causal);
/// assert!(cfg.use_flash);
/// assert!(cfg.scale.is_none());
/// ```
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,

    /// Dimensionality of each attention head.
    pub head_dim: usize,

    /// Dropout probability applied to attention weights (0.0 = disabled).
    pub dropout_prob: f64,

    /// Whether to use causal (autoregressive) masking.
    pub causal: bool,

    /// Whether to use Flash Attention 2 for memory-efficient computation.
    pub use_flash: bool,

    /// Optional custom attention scale factor. When `None` the standard
    /// `1 / sqrt(head_dim)` is used.
    pub scale: Option<f64>,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            head_dim: 64,
            dropout_prob: 0.0,
            causal: false,
            use_flash: true,
            scale: None,
        }
    }
}

impl AttentionConfig {
    /// Compute the effective scale factor.
    ///
    /// Returns `scale` if explicitly set, otherwise `1 / sqrt(head_dim)`.
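    ///
    /// # Examples
    ///
    /// ```rust
    /// use scirs2_neural::attention::types::AttentionConfig;
    ///
    /// // Default head_dim = 64 gives 1 / sqrt(64) = 0.125.
    /// let cfg = AttentionConfig::default();
    /// assert!((cfg.effective_scale() - 0.125).abs() < 1e-12);
    ///
    /// // An explicit scale takes precedence over the computed one.
    /// let mut custom = AttentionConfig::default();
    /// custom.scale = Some(0.5);
    /// assert!((custom.effective_scale() - 0.5).abs() < 1e-12);
    /// ```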
    pub fn effective_scale(&self) -> f64 {
        self.scale
            .unwrap_or_else(|| 1.0 / (self.head_dim as f64).sqrt())
    }
}

// ---------------------------------------------------------------------------
// AttentionMask
// ---------------------------------------------------------------------------

/// Attention mask variants controlling which query–key pairs are visible.
///
/// The `#[non_exhaustive]` attribute ensures that adding new variants is
/// backwards-compatible.
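///
/// # Examples
///
/// A `Custom` mask spelling out the same visibility pattern as `Causal` for
/// a sequence of length 3 (a sketch; real masks are typically produced by
/// the attention layer itself):
///
/// ```rust
/// use scirs2_neural::attention::types::AttentionMask;
///
/// let causal_as_custom = AttentionMask::Custom(vec![
///     vec![true, false, false], // query 0 sees key 0 only
///     vec![true, true, false],  // query 1 sees keys 0..=1
///     vec![true, true, true],   // query 2 sees keys 0..=2
/// ]);
/// assert!(matches!(causal_as_custom, AttentionMask::Custom(_)));
/// ```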
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AttentionMask {
    /// No masking — every position can attend to every other position.
    None,

    /// Causal (autoregressive) mask — position `i` may only attend to
    /// positions `j ≤ i`.
    Causal,

    /// Custom boolean mask. `mask[i][j] == true` means query `i` may attend
    /// to key `j`.
    Custom(Vec<Vec<bool>>),

    /// Padding mask expressed as the *valid* sequence length per batch item.
    /// Positions `>= lengths[b]` for batch item `b` are masked.
    PaddingMask(Vec<usize>),
}

// ---------------------------------------------------------------------------
// AttentionOutput
// ---------------------------------------------------------------------------

/// Output of an attention operation.
///
/// `output` always contains the context vectors. `attention_weights` is
/// populated only when the caller requests it (e.g. for visualisation), since
/// materialising the full weight tensor is expensive for long sequences.
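///
/// # Examples
///
/// ```rust
/// use scirs2_neural::attention::types::AttentionOutput;
///
/// // Two positions with embed_dim = 2; no weights are stored.
/// let out = AttentionOutput::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
/// assert_eq!(out.output.len(), 2);
/// assert!(out.attention_weights.is_none());
/// ```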
#[derive(Debug, Clone)]
pub struct AttentionOutput {
    /// Context vectors with logical shape `[seq_len, embed_dim]`, stored as
    /// a 2-D `Vec<Vec<f64>>`.
    pub output: Vec<Vec<f64>>,

    /// Optional attention weights with shape `[num_heads, seq_len, seq_len]`.
    /// Each element `weights[h][i][j]` is the softmax weight that query `i`
    /// assigns to key `j` under head `h`.
    pub attention_weights: Option<Vec<Vec<Vec<f64>>>>,
}

impl AttentionOutput {
    /// Construct an output with no stored weights.
    pub fn new(output: Vec<Vec<f64>>) -> Self {
        Self {
            output,
            attention_weights: None,
        }
    }

    /// Construct an output that includes the full weight tensor.
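    ///
    /// A sketch with one head over a length-2 sequence; each query row of
    /// the weights sums to 1, as softmax output would:
    ///
    /// ```rust
    /// use scirs2_neural::attention::types::AttentionOutput;
    ///
    /// let output = vec![vec![0.0; 4]; 2];             // [seq_len, embed_dim]
    /// let weights = vec![vec![vec![0.5, 0.5]; 2]; 1]; // [num_heads, seq_len, seq_len]
    /// let out = AttentionOutput::with_weights(output, weights);
    /// assert_eq!(out.attention_weights.as_ref().map(|w| w.len()), Some(1));
    /// ```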
    pub fn with_weights(output: Vec<Vec<f64>>, weights: Vec<Vec<Vec<f64>>>) -> Self {
        Self {
            output,
            attention_weights: Some(weights),
        }
    }
}

// ---------------------------------------------------------------------------
// PositionEncoding
// ---------------------------------------------------------------------------

/// Position-encoding strategies supported by attention layers in this crate.
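///
/// ```rust
/// use scirs2_neural::attention::types::PositionEncoding;
///
/// // The derives make variants cheap to copy and compare.
/// let pe = PositionEncoding::RoPE;
/// assert_eq!(pe, PositionEncoding::RoPE);
/// assert_ne!(pe, PositionEncoding::ALiBi);
/// ```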
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Classic sinusoidal position embeddings (Vaswani et al., 2017).
    Sinusoidal,

    /// Learned absolute position embeddings (BERT, GPT-2 style).
    Learned,

    /// Rotary Position Embedding — RoPE (Su et al., 2021).
    RoPE,

    /// Attention with Linear Biases — ALiBi (Press et al., 2021).
    ALiBi,

    /// No position encoding (NoPE) — positional information is either
    /// implicit in the input or supplied externally.
    NoPE,
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_default() {
        let cfg = AttentionConfig::default();
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.head_dim, 64);
        assert_eq!(cfg.dropout_prob, 0.0);
        assert!(!cfg.causal);
        assert!(cfg.use_flash);
        assert!(cfg.scale.is_none());
    }

    #[test]
    fn test_attention_config_effective_scale_default() {
        let cfg = AttentionConfig::default();
        let expected = 1.0 / (64.0_f64).sqrt();
        assert!((cfg.effective_scale() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_attention_config_effective_scale_custom() {
        let cfg = AttentionConfig {
            scale: Some(0.5),
            ..Default::default()
        };
        assert!((cfg.effective_scale() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_attention_output_struct() {
        let output = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let ao = AttentionOutput::new(output.clone());
        assert_eq!(ao.output, output);
        assert!(ao.attention_weights.is_none());
    }

    #[test]
    fn test_attention_output_with_weights() {
        // Shapes: output is [seq_len = 4, embed_dim = 4]; weights are
        // [num_heads = 2, seq_len = 4, seq_len = 4], so both agree on seq_len.
        let output = vec![vec![0.5; 4]; 4];
        let w = vec![vec![vec![0.25; 4]; 4]; 2];
        let ao = AttentionOutput::with_weights(output, w);
        assert!(ao.attention_weights.is_some());
        assert_eq!(ao.attention_weights.as_ref().map(|x| x.len()), Some(2));
    }

    #[test]
    fn test_position_encoding_variants() {
        // Just ensure all variants can be constructed and compared.
        let variants = [
            PositionEncoding::Sinusoidal,
            PositionEncoding::Learned,
            PositionEncoding::RoPE,
            PositionEncoding::ALiBi,
            PositionEncoding::NoPE,
        ];
        for v in &variants {
            assert_eq!(v, v);
        }
    }
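
    #[test]
    fn test_attention_mask_padding_mask_variant() {
        // Sketch of the PaddingMask contract documented on the enum: each
        // entry is the valid length for one batch item, so positions >= 3
        // (batch 0) and >= 1 (batch 1) would be masked by the consumer.
        let mask = AttentionMask::PaddingMask(vec![3, 1]);
        match mask {
            AttentionMask::PaddingMask(lengths) => assert_eq!(lengths, vec![3, 1]),
            _ => panic!("expected PaddingMask"),
        }
    }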

    #[test]
    fn test_attention_mask_causal_variant() {
        let mask = AttentionMask::Causal;
        // Matching a #[non_exhaustive] enum from outside the defining crate
        // requires a `_` arm; inside the crate we keep one anyway.
        match mask {
            AttentionMask::Causal => {}
            _ => panic!("expected Causal"),
        }
    }
}
235}