use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, Array4, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt};
use std::fmt::Debug;

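/// Creates a `rows x cols` weight matrix with Xavier/Glorot-uniform initialization,
/// drawing each entry from `[-sqrt(6 / (rows + cols)), sqrt(6 / (rows + cols))]`.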
fn mk_weight<F: Float, R: Rng>(rows: usize, cols: usize, rng: &mut R) -> Result<Array<F, IxDyn>> {
    let scale = (6.0_f64 / (rows + cols) as f64).sqrt();
    let mut data = Vec::with_capacity(rows * cols);
    for _ in 0..(rows * cols) {
        let x: f64 = rng.random_range(-scale..scale);
        let f = F::from(x)
            .ok_or_else(|| NeuralError::InvalidArchitecture("xavier cast failed".into()))?;
        data.push(f);
    }
    Array::from_shape_vec(IxDyn(&[rows, cols]), data)
        .map_err(|e| NeuralError::InvalidArchitecture(format!("mk_weight: {e}")))
}

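/// Numerically stable in-place softmax: subtracts the running maximum before
/// exponentiating and clamps the normalizer away from zero.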
fn softmax_inplace<F: Float + NumAssign>(s: &mut [F]) {
    let max_v = s
        .iter()
        .fold(F::neg_infinity(), |a, &b| if b > a { b } else { a });
    let mut sum = F::zero();
    for v in s.iter_mut() {
        *v = (*v - max_v).exp();
        sum += *v;
    }
    let eps = F::from(1e-12).unwrap_or(F::zero());
    let norm = if sum < eps { eps } else { sum };
    for v in s.iter_mut() {
        *v /= norm;
    }
}

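/// Cached key and value projections from previous decoding steps, stored with
/// shape `(batch, seq_len_so_far, num_kv_heads, head_dim)` so that incremental
/// decoding only needs to project the new tokens.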
#[derive(Debug, Clone)]
pub struct KvCache<F: Float> {
    pub keys: Array<F, IxDyn>,
    pub values: Array<F, IxDyn>,
}

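/// Configuration for [`MultiQueryAttention`]: number of query heads, number of
/// shared key/value heads, per-head dimension, dropout probability, and whether
/// to apply a causal (autoregressive) mask.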
#[derive(Debug, Clone)]
pub struct MultiQueryAttentionConfig {
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub dropout_prob: f64,
    pub causal: bool,
}

impl Default for MultiQueryAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            num_kv_heads: 1,
            head_dim: 64,
            dropout_prob: 0.0,
            causal: false,
        }
    }
}

impl MultiQueryAttentionConfig {
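    /// Creates a config with `num_heads` query heads, a single shared KV head,
    /// the given `head_dim`, no dropout, and no causal mask.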
    pub fn new(num_heads: usize, head_dim: usize) -> Self {
        Self {
            num_heads,
            num_kv_heads: 1,
            head_dim,
            ..Default::default()
        }
    }

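    /// Sets the number of key/value heads: 1 is classic multi-query attention,
    /// `num_heads` recovers standard multi-head attention, and values in
    /// between give grouped-query attention.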
    pub fn with_num_kv_heads(mut self, n: usize) -> Self {
        self.num_kv_heads = n;
        self
    }

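    /// Enables or disables the causal mask.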
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

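    /// Sets the dropout probability (stored in the config; the current forward
    /// pass does not apply dropout).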
    pub fn with_dropout(mut self, prob: f64) -> Self {
        self.dropout_prob = prob;
        self
    }
}

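/// Multi-query / grouped-query attention layer.
///
/// Queries use `num_heads` heads while keys and values use only `num_kv_heads`
/// heads shared across groups of query heads, shrinking the KV cache by a
/// factor of `num_heads / num_kv_heads` relative to standard multi-head
/// attention. Weight shapes: `w_q` is `(d_model, num_heads * head_dim)`,
/// `w_k`/`w_v` are `(d_model, num_kv_heads * head_dim)`, and `w_o` is
/// `(num_heads * head_dim, d_model)`.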
#[derive(Debug)]
pub struct MultiQueryAttention<F: Float + Debug + Send + Sync + NumAssign> {
    d_model: usize,
    config: MultiQueryAttentionConfig,
    w_q: Array<F, IxDyn>,
    w_k: Array<F, IxDyn>,
    w_v: Array<F, IxDyn>,
    w_o: Array<F, IxDyn>,
    scale: F,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> MultiQueryAttention<F> {
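    /// Builds the layer, validating that `num_heads` is divisible by
    /// `num_kv_heads` and that `num_heads * head_dim == d_model`, then
    /// initializes the Q/K/V/O projection weights and the `1 / sqrt(head_dim)`
    /// attention scale.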
    pub fn new<R: Rng>(
        d_model: usize,
        config: MultiQueryAttentionConfig,
        rng: &mut R,
    ) -> Result<Self> {
        if config.num_heads == 0 || config.num_kv_heads == 0 || config.head_dim == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "num_heads, num_kv_heads, head_dim must be > 0".into(),
            ));
        }

        if !config.num_heads.is_multiple_of(config.num_kv_heads) {
            return Err(NeuralError::InvalidArchitecture(format!(
                "num_heads ({}) must be divisible by num_kv_heads ({})",
                config.num_heads, config.num_kv_heads
            )));
        }

        let q_dim = config.num_heads * config.head_dim;
        let kv_dim = config.num_kv_heads * config.head_dim;

        if q_dim != d_model {
            return Err(NeuralError::InvalidArchitecture(format!(
                "num_heads * head_dim ({q_dim}) must equal d_model ({d_model})"
            )));
        }

        let w_q = mk_weight(d_model, q_dim, rng)?;
        let w_k = mk_weight(d_model, kv_dim, rng)?;
        let w_v = mk_weight(d_model, kv_dim, rng)?;
        let w_o = mk_weight(q_dim, d_model, rng)?;

        let scale = F::one()
            / F::from(config.head_dim)
                .ok_or_else(|| NeuralError::InvalidArchitecture("scale cast".into()))?
                .sqrt();

        Ok(Self {
            d_model,
            config,
            w_q,
            w_k,
            w_v,
            w_o,
            scale,
        })
    }

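    /// Runs attention over a `(batch, seq_len, d_model)` input, optionally
    /// reusing cached keys/values from earlier steps, and returns the
    /// `(batch, seq_len, d_model)` output together with the updated cache
    /// covering all positions seen so far.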
    pub fn forward_with_cache(
        &self,
        input: &Array<F, IxDyn>,
        past_kv: Option<&KvCache<F>>,
    ) -> Result<(Array<F, IxDyn>, KvCache<F>)> {
        if input.ndim() != 3 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "MQA expects 3D input, got {}D",
                input.ndim()
            )));
        }

        let shape = input.shape();
        let (batch, seq_len, d_model) = (shape[0], shape[1], shape[2]);

        if d_model != self.d_model {
            return Err(NeuralError::InvalidArchitecture(format!(
                "input dim {d_model} != d_model {}",
                self.d_model
            )));
        }

        let num_heads = self.config.num_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim;
        let group_size = num_heads / num_kv_heads;

        let q_4d =
            self.project_and_reshape(input, &self.w_q, batch, seq_len, num_heads, head_dim)?;
        let k_new =
            self.project_and_reshape(input, &self.w_k, batch, seq_len, num_kv_heads, head_dim)?;
        let v_new =
            self.project_and_reshape(input, &self.w_v, batch, seq_len, num_kv_heads, head_dim)?;

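        // Append the new K/V projections to the cached ones along the sequence
        // axis; without a cache, the new projections are the full K/V.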
        let (k_4d, v_4d, total_kv_len) = if let Some(cache) = past_kv {
            let past_len = cache.keys.shape()[1];
            let total = past_len + seq_len;
            let k_full =
                self.concat_cache(&cache.keys, &k_new, batch, total, num_kv_heads, head_dim)?;
            let v_full =
                self.concat_cache(&cache.values, &v_new, batch, total, num_kv_heads, head_dim)?;
            (k_full, v_full, total)
        } else {
            (k_new.clone(), v_new.clone(), seq_len)
        };

        let new_cache = KvCache {
            keys: k_4d.clone().into_dyn(),
            values: v_4d.clone().into_dyn(),
        };

        let mut output_4d = Array4::<F>::zeros((batch, seq_len, num_heads, head_dim));

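        // Grouped attention: each KV head `kv_h` serves the `group_size` query
        // heads in `[kv_h * group_size, (kv_h + 1) * group_size)`. With a
        // cache, query position `t` corresponds to absolute position
        // `past_len + t` for the causal mask.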
        for b in 0..batch {
            for kv_h in 0..num_kv_heads {
                let q_h_start = kv_h * group_size;
                let q_h_end = q_h_start + group_size;

                for q_h in q_h_start..q_h_end {
                    for t in 0..seq_len {
                        let global_t = past_kv.map_or(0, |c| c.keys.shape()[1]) + t;

                        let mut scores = Vec::with_capacity(total_kv_len);
                        for s_idx in 0..total_kv_len {
                            if self.config.causal && s_idx > global_t {
                                scores.push(F::neg_infinity());
                            } else {
                                let mut dot = F::zero();
                                for d in 0..head_dim {
                                    dot += q_4d[[b, t, q_h, d]] * k_4d[[b, s_idx, kv_h, d]];
                                }
                                scores.push(dot * self.scale);
                            }
                        }

                        softmax_inplace(&mut scores);

                        for d in 0..head_dim {
                            let mut acc = F::zero();
                            for s_idx in 0..total_kv_len {
                                acc += scores[s_idx] * v_4d[[b, s_idx, kv_h, d]];
                            }
                            output_4d[[b, t, q_h, d]] = acc;
                        }
                    }
                }
            }
        }

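        // Merge the heads back into (batch, seq_len, d_model) and apply the
        // output projection w_o.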
        let output_3d = output_4d
            .into_shape_with_order((batch, seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape output: {e}")))?;

        let output_2d = output_3d
            .into_shape_with_order((batch * seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape for O proj: {e}")))?;

        let w_o_2d = self
            .w_o
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("O weights 2D".into()))?;

        let final_out = output_2d.dot(&w_o_2d);

        let result = final_out
            .into_shape_with_order((batch, seq_len, d_model))
            .map_err(|e| NeuralError::InferenceError(format!("reshape final: {e}")))?;

        Ok((result.into_dyn(), new_cache))
    }

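    /// Applies a linear projection to the `(batch, seq, d_model)` input and
    /// reshapes the result to `(batch, seq, heads, head_dim)`.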
    fn project_and_reshape(
        &self,
        input: &Array<F, IxDyn>,
        weight: &Array<F, IxDyn>,
        batch: usize,
        seq: usize,
        heads: usize,
        head_dim: usize,
    ) -> Result<Array4<F>> {
        let d_model = input.shape()[2];

        let input_2d = input
            .clone()
            .into_shape_with_order(IxDyn(&[batch * seq, d_model]))
            .map_err(|e| NeuralError::InferenceError(format!("reshape: {e}")))?;

        let input_2d_view = input_2d
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("to Ix2".into()))?;

        let w_2d = weight
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| NeuralError::InferenceError("weight to Ix2".into()))?;

        let projected = input_2d_view.dot(&w_2d);

        projected
            .into_shape_with_order((batch, seq, heads, head_dim))
            .map_err(|e| NeuralError::InferenceError(format!("reshape projected: {e}")))
    }

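    /// Concatenates cached and newly projected K or V tensors along the
    /// sequence axis into a `(batch, total_len, heads, head_dim)` array.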
    fn concat_cache(
        &self,
        past: &Array<F, IxDyn>,
        new: &Array4<F>,
        batch: usize,
        total_len: usize,
        heads: usize,
        head_dim: usize,
    ) -> Result<Array4<F>> {
        let past_len = past.shape()[1];
        let new_len = new.shape()[1];

        if past_len + new_len != total_len {
            return Err(NeuralError::InferenceError(
                "cache concat length mismatch".into(),
            ));
        }

        let mut result = Array4::<F>::zeros((batch, total_len, heads, head_dim));

        for b in 0..batch {
            for t in 0..past_len {
                for h in 0..heads {
                    for d in 0..head_dim {
                        result[[b, t, h, d]] = past[[b, t, h, d]];
                    }
                }
            }
            for t in 0..new_len {
                for h in 0..heads {
                    for d in 0..head_dim {
                        result[[b, past_len + t, h, d]] = new[[b, t, h, d]];
                    }
                }
            }
        }

        Ok(result)
    }

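    /// Returns the attention configuration.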
    pub fn config(&self) -> &MultiQueryAttentionConfig {
        &self.config
    }

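    /// Returns the model (embedding) dimension.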
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl<F> Layer<F> for MultiQueryAttention<F>
where
    F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign,
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let (output, _cache) = self.forward_with_cache(input, None)?;
        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        Err(NeuralError::NotImplementedError(
            "MQA backward not yet implemented".into(),
        ))
    }

    fn update(&mut self, _learning_rate: F) -> Result<()> {
        Ok(())
    }

    fn layer_type(&self) -> &str {
        "MultiQueryAttention"
    }

    fn parameter_count(&self) -> usize {
        let q_dim = self.config.num_heads * self.config.head_dim;
        let kv_dim = self.config.num_kv_heads * self.config.head_dim;
        let dm = self.d_model;
        dm * q_dim + 2 * dm * kv_dim + q_dim * dm
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    #[test]
    fn test_mqa_creation() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng);
        assert!(mqa.is_ok());
    }

    #[test]
    fn test_mqa_forward_shape() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        let input = Array3::<f64>::from_elem((2, 8, 64), 0.1).into_dyn();
        let output = mqa.forward(&input).expect("forward failed");
        assert_eq!(output.shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_mqa_kv_cache() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16).with_causal(true);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        let prefix = Array3::<f64>::from_elem((1, 4, 64), 0.1).into_dyn();
        let (out1, cache1) = mqa
            .forward_with_cache(&prefix, None)
            .expect("step 1 failed");
        assert_eq!(out1.shape(), &[1, 4, 64]);
        assert_eq!(cache1.keys.shape()[1], 4);
        assert_eq!(cache1.values.shape()[1], 4);

        let new_token = Array3::<f64>::from_elem((1, 1, 64), 0.2).into_dyn();
        let (out2, cache2) = mqa
            .forward_with_cache(&new_token, Some(&cache1))
            .expect("step 2 failed");
        assert_eq!(out2.shape(), &[1, 1, 64]);
        assert_eq!(cache2.keys.shape()[1], 5);
        assert_eq!(cache2.values.shape()[1], 5);
    }

    #[test]
    fn test_mqa_with_num_heads_equals_mha() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16).with_num_kv_heads(4);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        let input = Array3::<f64>::from_elem((1, 6, 64), 0.15).into_dyn();
        let output = mqa.forward(&input).expect("forward failed");
        assert_eq!(output.shape(), &[1, 6, 64]);

        for val in output.iter() {
            assert!(val.is_finite(), "MHA-mode output has non-finite value");
        }
    }

    #[test]
    fn test_mqa_causal_masking() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(2, 8).with_causal(true);
        let mqa = MultiQueryAttention::<f64>::new(16, config, &mut rng).expect("creation failed");

        let mut input = Array3::<f64>::zeros((1, 6, 16));
        for t in 0..6 {
            for d in 0..16 {
                input[[0, t, d]] = (t as f64 + 1.0) * 0.1 + d as f64 * 0.01;
            }
        }

        let output = mqa.forward(&input.into_dyn()).expect("forward failed");
        assert_eq!(output.shape(), &[1, 6, 16]);

        for val in output.iter() {
            assert!(val.is_finite(), "causal output non-finite");
        }
    }

    #[test]
    fn test_mqa_invalid_config() {
        let mut rng = scirs2_core::random::rng();

        let config = MultiQueryAttentionConfig::new(5, 16).with_num_kv_heads(3);
        let result = MultiQueryAttention::<f64>::new(80, config, &mut rng);
        assert!(result.is_err());
    }

    #[test]
    fn test_mqa_parameter_count() {
        let mut rng = scirs2_core::random::rng();
        let config = MultiQueryAttentionConfig::new(4, 16);
        let mqa = MultiQueryAttention::<f64>::new(64, config, &mut rng).expect("creation failed");

        assert_eq!(mqa.parameter_count(), 4096 + 1024 + 1024 + 4096);
    }
}
608}