tiny_recursive_rs/layers/attention.rs

//! Multi-head attention implementation
//!
//! Based on the Python implementation in layers.py

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use super::activations::CastedLinear;
use super::positional::{apply_rotary_pos_emb, RotaryEmbedding};

/// Multi-head attention with optional grouped-query attention
///
/// Supports:
/// - Standard multi-head attention (num_heads == num_key_value_heads)
/// - Grouped-query attention (num_key_value_heads < num_heads)
/// - RoPE positional embeddings
/// - Optional causal masking
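///
/// # Example
///
/// A minimal usage sketch mirroring `test_attention_shape` below (paths and
/// error handling are elided, so the snippet is not compiled as a doctest):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use candle_nn::{VarBuilder, VarMap};
///
/// let device = Device::Cpu;
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
///
/// // hidden_size 256, head_dim 32, 8 query heads, 8 key/value heads, no causal mask.
/// let attn = Attention::new(256, 32, 8, 8, false, vb)?;
///
/// let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
/// let out = attn.forward(&x, None)?; // [2, 16, 256]
/// ```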
pub struct Attention {
    hidden_size: usize,
    head_dim: usize,
    output_size: usize,
    num_heads: usize,
    num_key_value_heads: usize,
    causal: bool,

    qkv_proj: CastedLinear,
    o_proj: CastedLinear,
}

impl Attention {
    /// Create new Attention layer
    ///
    /// # Arguments
    /// * `hidden_size` - Input/output dimension
    /// * `head_dim` - Dimension per attention head
    /// * `num_heads` - Number of query heads
    /// * `num_key_value_heads` - Number of key/value heads (for GQA)
    /// * `causal` - Whether to use causal masking
    /// * `vb` - VarBuilder for parameter initialization
    pub fn new(
        hidden_size: usize,
        head_dim: usize,
        num_heads: usize,
        num_key_value_heads: usize,
        causal: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let output_size = head_dim * num_heads;

        // QKV projection: projects to (num_heads + 2 * num_key_value_heads) * head_dim
        let qkv_size = (num_heads + 2 * num_key_value_heads) * head_dim;
        let qkv_proj = CastedLinear::new(
            hidden_size,
            qkv_size,
            false,
            vb.pp("qkv_proj"),
        )?;

        // Output projection
        let o_proj = CastedLinear::new(
            output_size,
            hidden_size,
            false,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            hidden_size,
            head_dim,
            output_size,
            num_heads,
            num_key_value_heads,
            causal,
            qkv_proj,
            o_proj,
        })
    }

    /// Forward pass
    ///
    /// # Arguments
    /// * `hidden_states` - Input tensor [batch, seq_len, hidden_size]
    /// * `cos_sin` - Optional RoPE embeddings (cos, sin) each [seq_len, head_dim]
    ///
    /// # Returns
    /// Output tensor [batch, seq_len, hidden_size]
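    ///
    /// # Example
    ///
    /// A sketch of a forward pass with RoPE, mirroring `test_attention_with_rope`
    /// below (`RotaryEmbedding` comes from `super::positional`; the fragment is
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let rope = RotaryEmbedding::new(32, 512, 10000.0, &device)?;
    /// let (cos, sin) = rope.forward_with_len(16)?;
    /// let out = attn.forward(&x, Some((&cos, &sin)))?; // same shape as `x`
    /// ```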
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        cos_sin: Option<(&Tensor, &Tensor)>,
    ) -> Result<Tensor> {
        let (batch_size, seq_len, _) = hidden_states.dims3()?;

        // Project to QKV
        let qkv = self.qkv_proj.forward(hidden_states)?;

        // Reshape and split into Q, K, V
        // qkv: [batch, seq_len, (num_heads + 2 * num_kv_heads) * head_dim]
        // -> [batch, seq_len, num_heads + 2 * num_kv_heads, head_dim]
        let qkv = qkv.reshape((
            batch_size,
            seq_len,
            self.num_heads + 2 * self.num_key_value_heads,
            self.head_dim,
        ))?;

        // Split Q, K, V
        let query = qkv.narrow(2, 0, self.num_heads)?; // [batch, seq_len, num_heads, head_dim]
        let key = qkv.narrow(2, self.num_heads, self.num_key_value_heads)?;
        let value = qkv.narrow(2, self.num_heads + self.num_key_value_heads, self.num_key_value_heads)?;

        // Apply RoPE if provided
        let (query, key) = if let Some((cos, sin)) = cos_sin {
            apply_rotary_pos_emb(&query, &key, cos, sin)?
        } else {
            (query, key)
        };

        // Transpose for attention: [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
        let query = query.transpose(1, 2)?.contiguous()?;
        let key = key.transpose(1, 2)?.contiguous()?;
        let value = value.transpose(1, 2)?.contiguous()?;

        // Handle grouped-query attention by repeating key/value heads if needed
        let (key, value) = if self.num_key_value_heads < self.num_heads {
            let repeat_factor = self.num_heads / self.num_key_value_heads;
            (
                repeat_kv(&key, repeat_factor)?,
                repeat_kv(&value, repeat_factor)?,
            )
        } else {
            (key, value)
        };

        // Scaled dot-product attention
        let attn_output = scaled_dot_product_attention(
            &query,
            &key,
            &value,
            self.causal,
        )?;

        // Transpose back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads, head_dim]
        let attn_output = attn_output.transpose(1, 2)?;

        // Concatenate heads: [batch, seq_len, num_heads * head_dim]
        let attn_output = attn_output.reshape((batch_size, seq_len, self.output_size))?;

        // Output projection
        self.o_proj.forward(&attn_output)
    }
}

/// Repeat key/value heads for grouped-query attention
///
/// Repeats each head `n` times along the head dimension.
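///
/// A shape-only sketch mirroring `test_repeat_kv` below (not compiled as a
/// doctest):
///
/// ```ignore
/// // 2 key/value heads repeated 4x -> 8 heads, matching 8 query heads.
/// let k = Tensor::randn(0f32, 1.0, (2, 2, 16, 32), &device)?;
/// let k = repeat_kv(&k, 4)?; // [2, 8, 16, 32]
/// ```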
fn repeat_kv(x: &Tensor, n: usize) -> Result<Tensor> {
    if n == 1 {
        return Ok(x.clone());
    }

    let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?;

    // Expand: [batch, num_kv_heads, seq_len, head_dim]
    // -> [batch, num_kv_heads, n, seq_len, head_dim]
    let x = x.unsqueeze(2)?;
    let x = x.broadcast_as((batch, num_kv_heads, n, seq_len, head_dim))?;

    // Reshape: [batch, num_kv_heads * n, seq_len, head_dim]
    x.reshape((batch, num_kv_heads * n, seq_len, head_dim))
}

/// Scaled dot-product attention
///
/// attention = softmax(Q @ K^T / sqrt(d_k)) @ V
///
/// # Arguments
/// * `query` - [batch, num_heads, seq_len, head_dim]
/// * `key` - [batch, num_heads, seq_len, head_dim]
/// * `value` - [batch, num_heads, seq_len, head_dim]
/// * `causal` - Whether to apply causal masking
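///
/// A shape-only sketch with hypothetical sizes (not compiled as a doctest):
///
/// ```ignore
/// // batch 1, 2 heads, seq_len 4, head_dim 8.
/// let q = Tensor::randn(0f32, 1.0, (1, 2, 4, 8), &device)?;
/// let k = Tensor::randn(0f32, 1.0, (1, 2, 4, 8), &device)?;
/// let v = Tensor::randn(0f32, 1.0, (1, 2, 4, 8), &device)?;
/// let out = scaled_dot_product_attention(&q, &k, &v, true)?; // [1, 2, 4, 8]
/// ```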
fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    causal: bool,
) -> Result<Tensor> {
    let (_batch, _num_heads, seq_len, head_dim) = query.dims4()?;
    let scale = 1.0 / (head_dim as f64).sqrt();

    // Q @ K^T: [batch, num_heads, seq_len, seq_len]
    let scores = query.matmul(&key.transpose(2, 3)?)?;
    let scores = (scores * scale)?;

    // Apply causal mask if needed
    let scores = if causal {
        let mask = create_causal_mask(seq_len, scores.device())?;
        scores.broadcast_add(&mask)?
    } else {
        scores
    };

    // Softmax over last dimension
    let attn_weights = candle_nn::ops::softmax_last_dim(&scores)?;

    // attn_weights @ V: [batch, num_heads, seq_len, head_dim]
    attn_weights.matmul(value)
}

/// Create causal attention mask
///
/// Returns a mask with 0s on/below diagonal and -inf above diagonal.
/// This masks out future positions in self-attention.
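///
/// For example, `seq_len = 4` yields:
///
/// ```text
/// [   0, -inf, -inf, -inf]
/// [   0,    0, -inf, -inf]
/// [   0,    0,    0, -inf]
/// [   0,    0,    0,    0]
/// ```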
fn create_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mut mask_data = vec![0.0f32; seq_len * seq_len];

    for i in 0..seq_len {
        for j in (i + 1)..seq_len {
            mask_data[i * seq_len + j] = f32::NEG_INFINITY;
        }
    }

    Tensor::from_vec(mask_data, (seq_len, seq_len), device)
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_nn::VarMap;

    #[test]
    fn test_attention_shape() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 8, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = attn.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_attention_with_rope() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let attn = Attention::new(256, 32, 8, 8, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;

        // Create RoPE embeddings
        let rope = RotaryEmbedding::new(32, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(16)?;

        let out = attn.forward(&x, Some((&cos, &sin)))?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_grouped_query_attention() -> Result<()> {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        // 8 query heads, 2 key/value heads
        let attn = Attention::new(256, 32, 8, 2, false, vb)?;

        let x = Tensor::randn(0f32, 1.0, (2, 16, 256), &device)?;
        let out = attn.forward(&x, None)?;

        assert_eq!(out.dims(), &[2, 16, 256]);

        Ok(())
    }

    #[test]
    fn test_causal_mask() -> Result<()> {
        let device = Device::Cpu;
        let mask = create_causal_mask(4, &device)?;

        // Check shape
        assert_eq!(mask.dims(), &[4, 4]);

        // Check that lower triangle is 0 and upper triangle is -inf
        let mask_vec = mask.flatten_all()?.to_vec1::<f32>()?;

        // First row: [0, -inf, -inf, -inf]
        assert_eq!(mask_vec[0], 0.0);
        assert!(mask_vec[1].is_infinite() && mask_vec[1].is_sign_negative());

        // Second row: [0, 0, -inf, -inf]
        assert_eq!(mask_vec[4], 0.0);
        assert_eq!(mask_vec[5], 0.0);
        assert!(mask_vec[6].is_infinite() && mask_vec[6].is_sign_negative());

        Ok(())
    }

    #[test]
    fn test_repeat_kv() -> Result<()> {
        let device = Device::Cpu;

        let x = Tensor::randn(0f32, 1.0, (2, 2, 16, 32), &device)?;
        let repeated = repeat_kv(&x, 4)?;

        assert_eq!(repeated.dims(), &[2, 8, 16, 32]);

        Ok(())
    }
}