// scirs2_neural/layers/multi_query_attention.rs

//! Multi-Query Attention (MQA) implementation
//!
//! This module implements Multi-Query Attention as described in:
//! "Fast Transformer Decoding: One Write-Head is All You Need"
//! by Noam Shazeer (2019).
//!
//! In MQA, all query heads share a single set of key and value projections.
//! This drastically reduces the KV cache size during autoregressive generation
//! (by a factor of `num_heads`), while maintaining most of the quality of
//! standard multi-head attention.
//!
//! When `num_kv_heads == num_heads`, MQA degenerates to standard MHA.
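//!
//! Each head still computes standard scaled dot-product attention,
//! `softmax(Q K^T / sqrt(head_dim)) V`; only the K/V projections are shared.
//! As a rough illustration of the savings: with `num_heads = 32` and
//! `head_dim = 128`, MHA caches 32 * 128 = 4096 key values per token per
//! layer, while pure MQA caches 1 * 128 = 128, a 32x reduction (and the
//! same again for values). Intermediate settings
//! (`1 < num_kv_heads < num_heads`) give grouped-query attention, where
//! each KV head is shared by a group of query heads.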

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, Array4, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt};
use std::fmt::Debug;

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Xavier-uniform weight initialisation returning an IxDyn array.
fn mk_weight<F: Float, R: Rng>(rows: usize, cols: usize, rng: &mut R) -> Result<Array<F, IxDyn>> {
    let scale = (6.0_f64 / (rows + cols) as f64).sqrt();
    let mut data = Vec::with_capacity(rows * cols);
    for _ in 0..(rows * cols) {
        let x: f64 = rng.random_range(-scale..scale);
        let f = F::from(x)
            .ok_or_else(|| NeuralError::InvalidArchitecture("xavier cast failed".into()))?;
        data.push(f);
    }
    Array::from_shape_vec(IxDyn(&[rows, cols]), data)
        .map_err(|e| NeuralError::InvalidArchitecture(format!("mk_weight: {e}")))
}

/// Softmax over a mutable slice (in-place, numerically stable).
fn softmax_inplace<F: Float + NumAssign>(s: &mut [F]) {
    let max_v = s
        .iter()
        .fold(F::neg_infinity(), |a, &b| if b > a { b } else { a });
    let mut sum = F::zero();
    for v in s.iter_mut() {
        *v = (*v - max_v).exp();
        sum += *v;
    }
    let eps = F::from(1e-12).unwrap_or(F::zero());
    let norm = if sum < eps { eps } else { sum };
    for v in s.iter_mut() {
        *v /= norm;
    }
}

// ---------------------------------------------------------------------------
// KV Cache
// ---------------------------------------------------------------------------

/// Key-Value cache for autoregressive generation
///
/// Stores past key and value tensors so they do not need to be recomputed
/// during incremental decoding.
#[derive(Debug, Clone)]
pub struct KvCache<F: Float> {
    /// Cached keys: [batch, past_len, num_kv_heads, head_dim]
    pub keys: Array<F, IxDyn>,
    /// Cached values: [batch, past_len, num_kv_heads, head_dim]
    pub values: Array<F, IxDyn>,
}

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for Multi-Query Attention
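///
/// # Examples
///
/// A minimal builder sketch, using only the constructors defined below:
///
/// ```rust
/// use scirs2_neural::layers::MultiQueryAttentionConfig;
///
/// // 8 query heads sharing 2 KV heads (grouped-query style), causal decoding
/// let cfg = MultiQueryAttentionConfig::new(8, 64)
///     .with_num_kv_heads(2)
///     .with_causal(true);
/// assert_eq!(cfg.num_kv_heads, 2);
/// assert!(cfg.causal);
/// ```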
#[derive(Debug, Clone)]
pub struct MultiQueryAttentionConfig {
    /// Number of query heads
    pub num_heads: usize,
    /// Number of KV heads (1 = pure MQA, num_heads = standard MHA)
    pub num_kv_heads: usize,
    /// Per-head dimension
    pub head_dim: usize,
    /// Dropout probability (not currently applied in the forward pass)
    pub dropout_prob: f64,
    /// Whether to apply causal masking
    pub causal: bool,
}

impl Default for MultiQueryAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            num_kv_heads: 1,
            head_dim: 64,
            dropout_prob: 0.0,
            causal: false,
        }
    }
}

impl MultiQueryAttentionConfig {
    /// Create a pure MQA config (1 KV head)
    pub fn new(num_heads: usize, head_dim: usize) -> Self {
        Self {
            num_heads,
            num_kv_heads: 1,
            head_dim,
            ..Default::default()
        }
    }

    /// Set number of KV heads (1 = MQA, num_heads = MHA)
    pub fn with_num_kv_heads(mut self, n: usize) -> Self {
        self.num_kv_heads = n;
        self
    }

    /// Enable or disable causal masking
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Set dropout probability
    pub fn with_dropout(mut self, prob: f64) -> Self {
        self.dropout_prob = prob;
        self
    }
}

// ---------------------------------------------------------------------------
// Layer
// ---------------------------------------------------------------------------

/// Multi-Query Attention layer
///
/// Projects queries with `num_heads` independent heads but uses only
/// `num_kv_heads` (default 1) shared key/value heads.
///
/// # Input
/// 3D tensor `[batch, seq_len, d_model]`
///
/// # Output
/// 3D tensor `[batch, seq_len, d_model]`
///
/// # Examples
///
/// ```rust
/// use scirs2_neural::layers::{MultiQueryAttention, MultiQueryAttentionConfig, Layer};
/// use scirs2_core::ndarray::Array3;
/// use scirs2_core::random::rng;
///
/// let mut rng = rng();
/// let config = MultiQueryAttentionConfig::new(4, 16); // 4 Q heads, 1 KV head
/// let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("failed");
///
/// let input = Array3::<f64>::from_elem((2, 8, 64), 0.1).into_dyn();
/// let output = mqa.forward(&input).expect("failed");
/// assert_eq!(output.shape(), &[2, 8, 64]);
/// ```
#[derive(Debug)]
pub struct MultiQueryAttention<F: Float + Debug + Send + Sync + NumAssign> {
    d_model: usize,
    config: MultiQueryAttentionConfig,
    /// [d_model, num_heads * head_dim]
    w_q: Array<F, IxDyn>,
    /// [d_model, num_kv_heads * head_dim]
    w_k: Array<F, IxDyn>,
    /// [d_model, num_kv_heads * head_dim]
    w_v: Array<F, IxDyn>,
    /// [num_heads * head_dim, d_model]
    w_o: Array<F, IxDyn>,
    scale: F,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> MultiQueryAttention<F> {
    /// Create a new Multi-Query Attention layer
    pub fn new<R: Rng>(
        d_model: usize,
        config: MultiQueryAttentionConfig,
        rng: &mut R,
    ) -> Result<Self> {
        if config.num_heads == 0 || config.num_kv_heads == 0 || config.head_dim == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "num_heads, num_kv_heads, head_dim must be > 0".into(),
            ));
        }

        if !config.num_heads.is_multiple_of(config.num_kv_heads) {
            return Err(NeuralError::InvalidArchitecture(format!(
                "num_heads ({}) must be divisible by num_kv_heads ({})",
                config.num_heads, config.num_kv_heads
            )));
        }

        let q_dim = config.num_heads * config.head_dim;
        let kv_dim = config.num_kv_heads * config.head_dim;

        if q_dim != d_model {
            return Err(NeuralError::InvalidArchitecture(format!(
                "num_heads * head_dim ({q_dim}) must equal d_model ({d_model})"
            )));
        }

        let w_q = mk_weight(d_model, q_dim, rng)?;
        let w_k = mk_weight(d_model, kv_dim, rng)?;
        let w_v = mk_weight(d_model, kv_dim, rng)?;
        let w_o = mk_weight(q_dim, d_model, rng)?;

        let scale = F::one()
            / F::from(config.head_dim)
                .ok_or_else(|| NeuralError::InvalidArchitecture("scale cast".into()))?
                .sqrt();

        Ok(Self {
            d_model,
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            scale,
        })
    }

    /// Forward pass with optional KV cache for autoregressive generation
    ///
    /// # Arguments
    /// * `input` - [batch, seq_len, d_model]
    /// * `past_kv` - Optional past KV cache
    ///
    /// # Returns
    /// (output [batch, seq_len, d_model], updated KV cache)
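    ///
    /// # Examples
    ///
    /// A prefill-then-decode sketch (mirroring the unit tests below):
    ///
    /// ```rust
    /// use scirs2_neural::layers::{MultiQueryAttention, MultiQueryAttentionConfig};
    /// use scirs2_core::ndarray::Array3;
    /// use scirs2_core::random::rng;
    ///
    /// let mut rng = rng();
    /// let config = MultiQueryAttentionConfig::new(4, 16).with_causal(true);
    /// let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("failed");
    ///
    /// // Prefill: no cache yet, so pass `None`
    /// let prefix = Array3::<f64>::from_elem((1, 4, 64), 0.1).into_dyn();
    /// let (_, cache) = mqa.forward_with_cache(&prefix, None).expect("failed");
    ///
    /// // Decode one token, feeding the returned cache back in
    /// let token = Array3::<f64>::from_elem((1, 1, 64), 0.2).into_dyn();
    /// let (out, cache) = mqa.forward_with_cache(&token, Some(&cache)).expect("failed");
    /// assert_eq!(out.shape(), &[1, 1, 64]);
    /// assert_eq!(cache.keys.shape()[1], 5); // 4 cached + 1 new
    /// ```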
    pub fn forward_with_cache(
        &self,
        input: &Array<F, IxDyn>,
        past_kv: Option<&KvCache<F>>,
    ) -> Result<(Array<F, IxDyn>, KvCache<F>)> {
        if input.ndim() != 3 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "MQA expects 3D input, got {}D",
                input.ndim()
            )));
        }

        let shape = input.shape();
        let (batch, seq_len, d_model) = (shape[0], shape[1], shape[2]);

        if d_model != self.d_model {
            return Err(NeuralError::InvalidArchitecture(format!(
                "input dim {d_model} != d_model {}",
                self.d_model
            )));
        }

        let num_heads = self.config.num_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim;
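        // Each KV head is shared by a contiguous group of `group_size` query
        // heads (grouped-query layout; group_size == num_heads for pure MQA).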
        let group_size = num_heads / num_kv_heads;

        // Project Q, K, V
        let q_4d =
            self.project_and_reshape(input, &self.w_q, batch, seq_len, num_heads, head_dim)?;
        let k_new =
            self.project_and_reshape(input, &self.w_k, batch, seq_len, num_kv_heads, head_dim)?;
        let v_new =
            self.project_and_reshape(input, &self.w_v, batch, seq_len, num_kv_heads, head_dim)?;

        // Concatenate with past cache if provided
        let (k_4d, v_4d, total_kv_len) = if let Some(cache) = past_kv {
            let past_len = cache.keys.shape()[1];
            let total = past_len + seq_len;
            let k_full =
                self.concat_cache(&cache.keys, &k_new, batch, total, num_kv_heads, head_dim)?;
            let v_full =
                self.concat_cache(&cache.values, &v_new, batch, total, num_kv_heads, head_dim)?;
            (k_full, v_full, total)
        } else {
            (k_new, v_new, seq_len)
        };

        // Build updated cache
        let new_cache = KvCache {
            keys: k_4d.clone().into_dyn(),
            values: v_4d.clone().into_dyn(),
        };

        // Compute attention
        // Q: [batch, seq_len, num_heads, head_dim]
        // K, V: [batch, total_kv_len, num_kv_heads, head_dim]
        let mut output_4d = Array4::<F>::zeros((batch, seq_len, num_heads, head_dim));

        for b in 0..batch {
            for kv_h in 0..num_kv_heads {
                let q_h_start = kv_h * group_size;
                let q_h_end = q_h_start + group_size;

                for q_h in q_h_start..q_h_end {
                    for t in 0..seq_len {
                        // Absolute position of this query token in the full
                        // (past + current) sequence, used for causal masking
                        let global_t = past_kv.map_or(t, |c| c.keys.shape()[1] + t);

                        // Compute attention scores against every visible key
                        let mut scores = Vec::with_capacity(total_kv_len);
                        for s_idx in 0..total_kv_len {
                            if self.config.causal && s_idx > global_t {
                                scores.push(F::neg_infinity());
                            } else {
                                let mut dot = F::zero();
                                for d in 0..head_dim {
                                    dot += q_4d[[b, t, q_h, d]] * k_4d[[b, s_idx, kv_h, d]];
                                }
                                scores.push(dot * self.scale);
                            }
                        }

                        softmax_inplace(&mut scores);

                        // Weighted sum of values
                        for d in 0..head_dim {
                            let mut acc = F::zero();
                            for s_idx in 0..total_kv_len {
                                acc += scores[s_idx] * v_4d[[b, s_idx, kv_h, d]];
                            }
                            output_4d[[b, t, q_h, d]] = acc;
                        }
                    }
                }
            }
        }

        // Reshape to [batch, seq_len, d_model] and project output
        let output_3d = output_4d
            .into_shape_with_order((batch, seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape output: {e}")))?;

        let output_2d = output_3d
            .into_shape_with_order((batch * seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape for O proj: {e}")))?;

        let w_o_2d = self
            .w_o
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("O weights 2D".into()))?;

        let final_out = output_2d.dot(&w_o_2d);

        let result = final_out
            .into_shape_with_order((batch, seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape final: {e}")))?;

        Ok((result.into_dyn(), new_cache))
    }

    /// Project input and reshape to [batch, seq, heads, head_dim]
    fn project_and_reshape(
        &self,
        input: &Array<F, IxDyn>,
        weight: &Array<F, IxDyn>,
        batch: usize,
        seq: usize,
        heads: usize,
        head_dim: usize,
    ) -> Result<Array4<F>> {
        let d_model = input.shape()[2];

        // [batch * seq, d_model] @ [d_model, heads * head_dim]
        let input_2d = input
            .clone()
            .into_shape_with_order(IxDyn(&[batch * seq, d_model]))
            .map_err(|e| NeuralError::InferenceError(format!("reshape: {e}")))?;

        let input_2d_view = input_2d
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("to Ix2".into()))?;

        let w_2d = weight
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("weight to Ix2".into()))?;

        let projected = input_2d_view.dot(&w_2d);

        projected
            .into_shape_with_order((batch, seq, heads, head_dim))
            .map_err(|e| NeuralError::InferenceError(format!("reshape projected: {e}")))
    }

    /// Concatenate past cache with new KV along the seq dimension
    fn concat_cache(
        &self,
        past: &Array<F, IxDyn>,
        new: &Array4<F>,
        batch: usize,
        total_len: usize,
        heads: usize,
        head_dim: usize,
    ) -> Result<Array4<F>> {
        let past_len = past.shape()[1];
        let new_len = new.shape()[1];

        if past_len + new_len != total_len {
            return Err(NeuralError::InferenceError(
                "cache concat length mismatch".into(),
            ));
        }

        let mut result = Array4::<F>::zeros((batch, total_len, heads, head_dim));

        // Copy past
        for b in 0..batch {
            for t in 0..past_len {
                for h in 0..heads {
                    for d in 0..head_dim {
                        result[[b, t, h, d]] = past[[b, t, h, d]];
                    }
                }
            }
            // Copy new
            for t in 0..new_len {
                for h in 0..heads {
                    for d in 0..head_dim {
                        result[[b, past_len + t, h, d]] = new[[b, t, h, d]];
                    }
                }
            }
        }

        Ok(result)
    }

    /// Get configuration
    pub fn config(&self) -> &MultiQueryAttentionConfig {
        &self.config
    }

    /// Get model dimension
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl<F> Layer<F> for MultiQueryAttention<F>
where
    F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign,
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let (output, _cache) = self.forward_with_cache(input, None)?;
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        Err(NeuralError::NotImplementedError(
            "MQA backward not yet implemented".into(),
        ))
    }

    fn update(&mut self, _learning_rate: F) -> Result<()> {
        Ok(())
    }

    fn layer_type(&self) -> &str {
        "MultiQueryAttention"
    }

    fn parameter_count(&self) -> usize {
        let q_dim = self.config.num_heads * self.config.head_dim;
        let kv_dim = self.config.num_kv_heads * self.config.head_dim;
        let dm = self.d_model;
        dm * q_dim + 2 * dm * kv_dim + q_dim * dm
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{s, Array3};

    #[test]
    fn test_mqa_creation() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16); // 4 Q heads, 1 KV head
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng);
        assert!(mqa.is_ok());
    }

    #[test]
    fn test_mqa_forward_shape() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        let input = Array3::<f64>::from_elem((2, 8, 64), 0.1).into_dyn();
        let output = mqa.forward(&input).expect("forward failed");
        assert_eq!(output.shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_mqa_kv_cache() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16).with_causal(true);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        // First step: process prefix
        let prefix = Array3::<f64>::from_elem((1, 4, 64), 0.1).into_dyn();
        let (out1, cache1) = mqa
            .forward_with_cache(&prefix, None)
            .expect("step 1 failed");
        assert_eq!(out1.shape(), &[1, 4, 64]);
        assert_eq!(cache1.keys.shape()[1], 4);
        assert_eq!(cache1.values.shape()[1], 4);

        // Second step: process one new token with cache
        let new_token = Array3::<f64>::from_elem((1, 1, 64), 0.2).into_dyn();
        let (out2, cache2) = mqa
            .forward_with_cache(&new_token, Some(&cache1))
            .expect("step 2 failed");
        assert_eq!(out2.shape(), &[1, 1, 64]);
        assert_eq!(cache2.keys.shape()[1], 5); // 4 + 1
        assert_eq!(cache2.values.shape()[1], 5);
    }
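
    #[test]
    fn test_mqa_incremental_matches_full_forward() {
        // Consistency sketch (added check): with causal masking, decoding the
        // last token through the KV cache should reproduce the last row of a
        // full forward pass, since K/V projections are position-wise and the
        // cache stores them verbatim. Compared with a small float tolerance.
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 4).with_causal(true);
        let mqa = MultiQueryAttention::<f64>::new(16, config, &mut rng).expect("creation failed");

        let mut full = Array3::<f64>::zeros((1, 5, 16));
        for t in 0..5 {
            for d in 0..16 {
                full[[0, t, d]] = ((t * 16 + d) as f64) * 0.01 - 0.4;
            }
        }

        let full_out = mqa
            .forward(&full.clone().into_dyn())
            .expect("full forward failed");

        let prefix = full.slice(s![.., 0..4, ..]).to_owned().into_dyn();
        let last = full.slice(s![.., 4..5, ..]).to_owned().into_dyn();
        let (_, cache) = mqa.forward_with_cache(&prefix, None).expect("prefill failed");
        let (step_out, _) = mqa
            .forward_with_cache(&last, Some(&cache))
            .expect("decode failed");

        for d in 0..16 {
            let diff = (step_out[[0, 0, d]] - full_out[[0, 4, d]]).abs();
            assert!(diff < 1e-8, "mismatch at dim {d}: {diff}");
        }
    }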

    #[test]
    fn test_mqa_with_num_heads_equals_mha() {
        // When num_kv_heads == num_heads, MQA should behave like MHA
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16).with_num_kv_heads(4); // same as num_heads = MHA
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        let input = Array3::<f64>::from_elem((1, 6, 64), 0.15).into_dyn();
        let output = mqa.forward(&input).expect("forward failed");
        assert_eq!(output.shape(), &[1, 6, 64]);

        // Output should be finite
        for val in output.iter() {
            assert!(val.is_finite(), "MHA-mode output has non-finite value");
        }
    }

    #[test]
    fn test_mqa_causal_masking() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(2, 8).with_causal(true);
        let mqa = MultiQueryAttention::<f64>::new(16, config, &mut rng).expect("creation failed");

        let mut input = Array3::<f64>::zeros((1, 6, 16));
        for t in 0..6 {
            for d in 0..16 {
                input[[0, t, d]] = (t as f64 + 1.0) * 0.1 + d as f64 * 0.01;
            }
        }

        let output = mqa.forward(&input.into_dyn()).expect("forward failed");
        assert_eq!(output.shape(), &[1, 6, 16]);

        for val in output.iter() {
            assert!(val.is_finite(), "causal output non-finite");
        }
    }

    #[test]
    fn test_mqa_invalid_config() {
        let mut rng = scirs2_core::random::rng();

        // num_heads not divisible by num_kv_heads
        let config = MultiQueryAttentionConfig::new(5, 16).with_num_kv_heads(3);
        let result = MultiQueryAttention::<f64>::new(80, config, &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_mqa_parameter_count() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16); // 1 KV head
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        // Q: 64 * 64 = 4096
        // K: 64 * 16 = 1024
        // V: 64 * 16 = 1024
        // O: 64 * 64 = 4096
        assert_eq!(mqa.parameter_count(), 4096 + 1024 + 1024 + 4096);
    }
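
    #[test]
    fn test_softmax_inplace_normalizes() {
        // Added check on the private helper: probabilities sum to 1 and stay
        // finite even for large logits, thanks to the max-subtraction trick.
        let mut v = vec![1000.0_f64, 1001.0, 999.0];
        softmax_inplace(&mut v);
        let sum: f64 = v.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
        assert!(v.iter().all(|x| x.is_finite()));
        assert!(v[1] > v[0] && v[0] > v[2]);
    }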
}