Skip to main content

ruvector_attention/attention/
mla.rs

1//! Multi-Head Latent Attention (MLA) from DeepSeek-V2/V3.
2//!
//! Reduces the KV cache by compressing key-value pairs into a low-dimensional
//! latent space (e.g. ~93% for 16 heads of width 128 with a 256-float latent).
//! Instead of caching full K,V per head per position
//! (`2 * num_heads * head_dim` floats), MLA caches only the latent vector
//! `c_kv` (`latent_dim` floats) plus one `rope_dim`-float RoPE key shared by
//! all heads, and decompresses K,V on-the-fly:
7//!
8//! 1. Down-project: `c_kv = x @ W_dkv` (d_model -> latent_dim)
9//! 2. Up-project:   `K = c_kv @ W_uk`, `V = c_kv @ W_uv`
10//! 3. Query path:   `c_q = x @ W_dq`, `Q = c_q @ W_uq` (same low-rank trick)
11//! 4. RoPE bypass:  A `rope_dim`-sized portion of each key skips compression
12//!    and receives Rotary Position Embeddings directly.
13
14use crate::error::{AttentionError, AttentionResult};
15use crate::traits::Attention;
16
/// Hyper-parameters of a Multi-Head Latent Attention layer.
///
/// Call [`MLAConfig::validate`] before use; [`MLALayer::new`] does this for you.
#[derive(Clone, Debug)]
pub struct MLAConfig {
    /// Width of the model hidden state, i.e. the length of every input vector.
    pub d_model: usize,
    /// Width of the compressed key/value latent `c_kv` cached per position.
    pub latent_dim: usize,
    /// Width of the compressed query latent `c_q`; `None` reuses `latent_dim`.
    pub latent_dim_q: Option<usize>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head width of queries, keys and values.
    pub head_dim: usize,
    /// Must be even and <= head_dim. Set to 0 to disable RoPE decoupling.
    pub rope_dim: usize,
}
28
29impl MLAConfig {
30    pub fn validate(&self) -> AttentionResult<()> {
31        let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
32        if self.d_model == 0 {
33            return err("d_model must be > 0");
34        }
35        if self.num_heads == 0 {
36            return err("num_heads must be > 0");
37        }
38        if self.head_dim == 0 {
39            return err("head_dim must be > 0");
40        }
41        if self.latent_dim == 0 {
42            return err("latent_dim must be > 0");
43        }
44        if self.latent_dim >= self.full_kv_dim() {
45            return err("latent_dim must be < num_heads * head_dim");
46        }
47        if self.rope_dim > self.head_dim {
48            return err("rope_dim must be <= head_dim");
49        }
50        if self.rope_dim > 0 && self.rope_dim % 2 != 0 {
51            return err("rope_dim must be even (RoPE operates on pairs)");
52        }
53        Ok(())
54    }
55
56    pub fn effective_latent_dim_q(&self) -> usize {
57        self.latent_dim_q.unwrap_or(self.latent_dim)
58    }
59
60    pub fn full_kv_dim(&self) -> usize {
61        self.num_heads * self.head_dim
62    }
63}
64
/// KV cache storing only latent vectors instead of full K,V per head.
///
/// Entry `i` of `latent_vectors` and entry `i` of `rope_keys` together
/// describe cached position `i`. The private fields record the shape of the
/// configuration the cache was created for.
#[derive(Clone, Debug)]
pub struct MLACache {
    /// One compressed `c_kv` vector per cached position.
    pub latent_vectors: Vec<Vec<f32>>,
    /// One un-rotated decoupled RoPE key per cached position.
    pub rope_keys: Vec<Vec<f32>>,
    latent_dim: usize,
    rope_dim: usize,
    num_heads: usize,
    head_dim: usize,
}
75
76impl MLACache {
77    pub fn new(config: &MLAConfig) -> Self {
78        Self {
79            latent_vectors: Vec::new(),
80            rope_keys: Vec::new(),
81            latent_dim: config.latent_dim,
82            rope_dim: config.rope_dim,
83            num_heads: config.num_heads,
84            head_dim: config.head_dim,
85        }
86    }
87
88    pub fn push(&mut self, latent: Vec<f32>, rope_key: Vec<f32>) {
89        self.latent_vectors.push(latent);
90        self.rope_keys.push(rope_key);
91    }
92
93    pub fn len(&self) -> usize {
94        self.latent_vectors.len()
95    }
96    pub fn is_empty(&self) -> bool {
97        self.latent_vectors.is_empty()
98    }
99
100    /// Total floats stored in this MLA cache.
101    pub fn cache_size(&self) -> usize {
102        self.len() * (self.latent_dim + self.rope_dim)
103    }
104
105    /// Total floats standard MHA would store for the same positions.
106    pub fn mha_equivalent_size(&self) -> usize {
107        self.len() * 2 * self.num_heads * self.head_dim
108    }
109
110    /// KV-cache reduction ratio (e.g. 0.9375 = 93.75% reduction vs MHA).
111    pub fn reduction_ratio(&self) -> f32 {
112        if self.len() == 0 {
113            return 0.0;
114        }
115        1.0 - (self.cache_size() as f32 / self.mha_equivalent_size() as f32)
116    }
117}
118
119/// Multi-Head Latent Attention layer with projection weights (row-major).
120pub struct MLALayer {
121    config: MLAConfig,
122    w_dkv: Vec<f32>,  // d_model -> latent_dim
123    w_uk: Vec<f32>,   // latent_dim -> full_kv_dim (keys)
124    w_uv: Vec<f32>,   // latent_dim -> full_kv_dim (values)
125    w_dq: Vec<f32>,   // d_model -> latent_dim_q
126    w_uq: Vec<f32>,   // latent_dim_q -> full_kv_dim
127    w_rope: Vec<f32>, // d_model -> rope_dim
128    w_out: Vec<f32>,  // full_kv_dim -> d_model
129}
130
131impl MLALayer {
132    /// Creates a new MLA layer with deterministic Xavier-style initialization.
133    pub fn new(config: MLAConfig) -> AttentionResult<Self> {
134        config.validate()?;
135        let fd = config.full_kv_dim();
136        let lq = config.effective_latent_dim_q();
137        Ok(Self {
138            w_dkv: init_weight(config.d_model, config.latent_dim),
139            w_uk: init_weight(config.latent_dim, fd),
140            w_uv: init_weight(config.latent_dim, fd),
141            w_dq: init_weight(config.d_model, lq),
142            w_uq: init_weight(lq, fd),
143            w_rope: init_weight(config.d_model, config.rope_dim),
144            w_out: init_weight(fd, config.d_model),
145            config,
146        })
147    }
148
149    pub fn config(&self) -> &MLAConfig {
150        &self.config
151    }
152
153    /// Compress input to KV latent: `c_kv = x @ W_dkv`.
154    pub fn compress_kv(&self, x: &[f32]) -> Vec<f32> {
155        matvec(&self.w_dkv, x, self.config.d_model, self.config.latent_dim)
156    }
157
158    /// Decompress latent to keys: `K = c_kv @ W_uk`.
159    pub fn decompress_keys(&self, c: &[f32]) -> Vec<f32> {
160        matvec(
161            &self.w_uk,
162            c,
163            self.config.latent_dim,
164            self.config.full_kv_dim(),
165        )
166    }
167
168    /// Decompress latent to values: `V = c_kv @ W_uv`.
169    pub fn decompress_values(&self, c: &[f32]) -> Vec<f32> {
170        matvec(
171            &self.w_uv,
172            c,
173            self.config.latent_dim,
174            self.config.full_kv_dim(),
175        )
176    }
177
178    fn compute_rope_keys(&self, x: &[f32]) -> Vec<f32> {
179        if self.config.rope_dim == 0 {
180            return Vec::new();
181        }
182        matvec(&self.w_rope, x, self.config.d_model, self.config.rope_dim)
183    }
184
185    fn compute_query(&self, x: &[f32]) -> Vec<f32> {
186        let lq = self.config.effective_latent_dim_q();
187        let c_q = matvec(&self.w_dq, x, self.config.d_model, lq);
188        matvec(&self.w_uq, &c_q, lq, self.config.full_kv_dim())
189    }
190
191    /// Applies RoPE rotation to pairs of dimensions based on position.
192    fn apply_rope(v: &mut [f32], position: usize) {
193        let dim = v.len();
194        for i in (0..dim).step_by(2) {
195            if i + 1 >= dim {
196                break;
197            }
198            let freq = 1.0 / (10000.0_f32).powf(i as f32 / dim as f32);
199            let theta = position as f32 * freq;
200            let (cos_t, sin_t) = (theta.cos(), theta.sin());
201            let (x0, x1) = (v[i], v[i + 1]);
202            v[i] = x0 * cos_t - x1 * sin_t;
203            v[i + 1] = x0 * sin_t + x1 * cos_t;
204        }
205    }
206
207    /// Core attention computation shared by `forward` and `forward_cached`.
208    fn attend(&self, q_full: &[f32], all_keys: &[Vec<f32>], all_values: &[Vec<f32>]) -> Vec<f32> {
209        let (nh, hd) = (self.config.num_heads, self.config.head_dim);
210        let scale = (hd as f32).sqrt();
211        let mut out = vec![0.0_f32; nh * hd];
212        for h in 0..nh {
213            let off = h * hd;
214            let qh = &q_full[off..off + hd];
215            let mut scores: Vec<f32> = all_keys
216                .iter()
217                .map(|k| dot(&k[off..off + hd], qh) / scale)
218                .collect();
219            softmax_inplace(&mut scores);
220            for (si, &w) in scores.iter().enumerate() {
221                let vh = &all_values[si][off..off + hd];
222                for d in 0..hd {
223                    out[off + d] += w * vh[d];
224                }
225            }
226        }
227        matvec(
228            &self.w_out,
229            &out,
230            self.config.full_kv_dim(),
231            self.config.d_model,
232        )
233    }
234
235    /// Prepares query with RoPE applied to the decoupled portion of each head.
236    fn prepare_query(&self, input: &[f32], pos: usize) -> Vec<f32> {
237        let mut q = self.compute_query(input);
238        let (nh, hd, rd) = (
239            self.config.num_heads,
240            self.config.head_dim,
241            self.config.rope_dim,
242        );
243        if rd > 0 {
244            for h in 0..nh {
245                Self::apply_rope(&mut q[h * hd..h * hd + rd], pos);
246            }
247        }
248        q
249    }
250
251    /// Decompresses a latent+rope pair into full keys/values for one position.
252    fn decompress_position(
253        &self,
254        latent: &[f32],
255        rope: &[f32],
256        pos: usize,
257    ) -> (Vec<f32>, Vec<f32>) {
258        let mut keys = self.decompress_keys(latent);
259        let values = self.decompress_values(latent);
260        let (nh, hd, rd) = (
261            self.config.num_heads,
262            self.config.head_dim,
263            self.config.rope_dim,
264        );
265        if rd > 0 {
266            let mut rp = rope.to_vec();
267            Self::apply_rope(&mut rp, pos);
268            for h in 0..nh {
269                keys[h * hd..h * hd + rd].copy_from_slice(&rp);
270            }
271        }
272        (keys, values)
273    }
274
275    /// Full MLA forward pass for a single query position.
276    pub fn forward(
277        &self,
278        query_input: &[f32],
279        kv_inputs: &[&[f32]],
280        query_pos: usize,
281        kv_positions: &[usize],
282    ) -> AttentionResult<Vec<f32>> {
283        if query_input.len() != self.config.d_model {
284            return Err(AttentionError::DimensionMismatch {
285                expected: self.config.d_model,
286                actual: query_input.len(),
287            });
288        }
289        if kv_inputs.is_empty() {
290            return Err(AttentionError::EmptyInput("kv_inputs".into()));
291        }
292        if kv_inputs.len() != kv_positions.len() {
293            return Err(AttentionError::DimensionMismatch {
294                expected: kv_inputs.len(),
295                actual: kv_positions.len(),
296            });
297        }
298        let q_full = self.prepare_query(query_input, query_pos);
299        let mut all_k = Vec::with_capacity(kv_inputs.len());
300        let mut all_v = Vec::with_capacity(kv_inputs.len());
301        for (i, &kv) in kv_inputs.iter().enumerate() {
302            if kv.len() != self.config.d_model {
303                return Err(AttentionError::DimensionMismatch {
304                    expected: self.config.d_model,
305                    actual: kv.len(),
306                });
307            }
308            let c = self.compress_kv(kv);
309            let rope = self.compute_rope_keys(kv);
310            let (k, v) = self.decompress_position(&c, &rope, kv_positions[i]);
311            all_k.push(k);
312            all_v.push(v);
313        }
314        Ok(self.attend(&q_full, &all_k, &all_v))
315    }
316
317    /// Forward pass using incremental MLA cache (for autoregressive decoding).
318    pub fn forward_cached(
319        &self,
320        query_input: &[f32],
321        new_kv_input: &[f32],
322        query_pos: usize,
323        cache: &mut MLACache,
324    ) -> AttentionResult<Vec<f32>> {
325        if new_kv_input.len() != self.config.d_model {
326            return Err(AttentionError::DimensionMismatch {
327                expected: self.config.d_model,
328                actual: new_kv_input.len(),
329            });
330        }
331        cache.push(
332            self.compress_kv(new_kv_input),
333            self.compute_rope_keys(new_kv_input),
334        );
335        let q_full = self.prepare_query(query_input, query_pos);
336        let mut all_k = Vec::with_capacity(cache.len());
337        let mut all_v = Vec::with_capacity(cache.len());
338        for pos in 0..cache.len() {
339            let (k, v) =
340                self.decompress_position(&cache.latent_vectors[pos], &cache.rope_keys[pos], pos);
341            all_k.push(k);
342            all_v.push(v);
343        }
344        Ok(self.attend(&q_full, &all_k, &all_v))
345    }
346
347    /// Memory comparison report: MLA vs standard MHA caching.
348    pub fn memory_comparison(&self, seq_len: usize) -> MemoryComparison {
349        let mha = seq_len * 2 * self.config.num_heads * self.config.head_dim;
350        let mla = seq_len * (self.config.latent_dim + self.config.rope_dim);
351        MemoryComparison {
352            seq_len,
353            mha_cache_floats: mha,
354            mla_cache_floats: mla,
355            mha_cache_bytes: mha * 4,
356            mla_cache_bytes: mla * 4,
357            reduction_ratio: 1.0 - (mla as f32 / mha as f32),
358        }
359    }
360}
361
/// Report comparing MLA vs MHA cache memory usage for one sequence length.
///
/// Byte counts are float counts times the size of an `f32`.
#[derive(Clone, Debug)]
pub struct MemoryComparison {
    /// Number of cached positions the report describes.
    pub seq_len: usize,
    /// Floats standard MHA would cache (full K and V for every head).
    pub mha_cache_floats: usize,
    /// Floats MLA caches (latent plus RoPE key per position).
    pub mla_cache_floats: usize,
    /// `mha_cache_floats` expressed in bytes.
    pub mha_cache_bytes: usize,
    /// `mla_cache_floats` expressed in bytes.
    pub mla_cache_bytes: usize,
    /// `1 - mla / mha`; e.g. `0.9375` means MLA stores 93.75% fewer floats.
    pub reduction_ratio: f32,
}
372
373impl Attention for MLALayer {
374    fn compute(
375        &self,
376        query: &[f32],
377        keys: &[&[f32]],
378        values: &[&[f32]],
379    ) -> AttentionResult<Vec<f32>> {
380        let _ = values; // MLA derives V from the same inputs as K
381        let positions: Vec<usize> = (0..keys.len()).collect();
382        self.forward(query, keys, 0, &positions)
383    }
384
385    fn compute_with_mask(
386        &self,
387        query: &[f32],
388        keys: &[&[f32]],
389        values: &[&[f32]],
390        _mask: Option<&[bool]>,
391    ) -> AttentionResult<Vec<f32>> {
392        self.compute(query, keys, values)
393    }
394
395    fn dim(&self) -> usize {
396        self.config.d_model
397    }
398    fn num_heads(&self) -> usize {
399        self.config.num_heads
400    }
401}
402
403// -- Utility functions --------------------------------------------------------
404
/// Multiplies a flat `out_d x in_d` row-major matrix `w` by the vector `x`.
///
/// Row `r` occupies `w[r * in_d..(r + 1) * in_d]`. Output element `r` is the
/// dot product of that row with the first `in_d` entries of `x`; any extra
/// entries of `x` are ignored.
///
/// # Panics
///
/// Panics if `w` holds fewer than `out_d * in_d` values, or if `out_d > 0` and
/// `x` holds fewer than `in_d` values.
fn matvec(w: &[f32], x: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut result = Vec::with_capacity(out_d);
    for row_idx in 0..out_d {
        let start = row_idx * in_d;
        let row = &w[start..start + in_d];
        let acc: f32 = row
            .iter()
            .zip(&x[..in_d])
            .map(|(&wv, &xv)| wv * xv)
            .sum();
        result.push(acc);
    }
    result
}
413
/// Inner product of `a` and `b`; the longer slice's extra elements are ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
417
/// Replaces `s` with its softmax, `exp(s_i - max) / sum_j exp(s_j - max)`.
///
/// Subtracting the maximum first keeps `exp` from overflowing. An empty slice
/// is left untouched.
fn softmax_inplace(s: &mut [f32]) {
    let peak = s.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0_f32;
    s.iter_mut().for_each(|v| {
        *v = (*v - peak).exp();
        total += *v;
    });
    s.iter_mut().for_each(|v| *v /= total);
}
429
/// Deterministic pseudo-Xavier initialiser for an `in_d x out_d` weight.
///
/// Produces `in_d * out_d` values forming a repeating ramp of period
/// `in_d + out_d` (minimum 1). The ramp runs from `-gain / 2` toward
/// `+gain / 2`, where `gain = sqrt(2 / (in_d + out_d))`. No randomness is
/// involved: the same arguments always yield the same weights.
fn init_weight(in_d: usize, out_d: usize) -> Vec<f32> {
    let fan_sum = in_d + out_d;
    let gain = (2.0 / fan_sum as f32).sqrt();
    let period = fan_sum.max(1);
    let len = in_d * out_d;
    let mut weights = Vec::with_capacity(len);
    for idx in 0..len {
        let phase = (idx % period) as f32 / period as f32;
        weights.push(gain * (phase - 0.5));
    }
    weights
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Small valid configuration (RoPE enabled) shared by most tests.
    fn cfg() -> MLAConfig {
        MLAConfig {
            d_model: 32,
            latent_dim: 8,
            latent_dim_q: None,
            num_heads: 4,
            head_dim: 8,
            rope_dim: 4,
        }
    }

    #[test]
    fn test_config_valid() {
        assert!(cfg().validate().is_ok());
    }

    #[test]
    fn test_config_latent_too_large() {
        let mut c = cfg();
        c.latent_dim = 999;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_rope_dim_odd() {
        let mut c = cfg();
        c.rope_dim = 3;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_rope_dim_exceeds_head_dim() {
        let mut c = cfg();
        // Even, so only the size check can reject it.
        c.rope_dim = c.head_dim + 2;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_zero_heads() {
        let mut c = cfg();
        c.num_heads = 0;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_zero_dimensions() {
        let mut c = cfg();
        c.d_model = 0;
        assert!(c.validate().is_err());

        let mut c = cfg();
        c.head_dim = 0;
        assert!(c.validate().is_err());

        let mut c = cfg();
        c.latent_dim = 0;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_forward_output_shape() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer.forward(&q, &[&kv1, &kv2], 0, &[0, 1]).unwrap();
        assert_eq!(out.len(), c.d_model);
    }

    #[test]
    fn test_forward_dimension_mismatch() {
        let layer = MLALayer::new(cfg()).unwrap();
        let bad_q = vec![0.1_f32; 5];
        let kv = vec![0.2_f32; 32];
        assert!(layer.forward(&bad_q, &[&kv[..]], 0, &[0]).is_err());
    }

    #[test]
    fn test_forward_bad_kv_length() {
        let layer = MLALayer::new(cfg()).unwrap();
        let q = vec![0.1_f32; 32];
        let bad_kv = vec![0.2_f32; 7];
        assert!(layer.forward(&q, &[&bad_kv[..]], 0, &[0]).is_err());
    }

    #[test]
    fn test_forward_empty_kv_inputs() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        assert!(layer.forward(&q, &[], 0, &[]).is_err());
    }

    #[test]
    fn test_forward_positions_length_mismatch() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        let kv = vec![0.2_f32; c.d_model];
        assert!(layer.forward(&q, &[&kv[..]], 0, &[0, 1]).is_err());
    }

    #[test]
    fn test_cache_size_reduction() {
        let c = cfg();
        let mut cache = MLACache::new(&c);
        for _ in 0..10 {
            cache.push(vec![0.0; c.latent_dim], vec![0.0; c.rope_dim]);
        }
        assert_eq!(cache.len(), 10);
        assert_eq!(cache.cache_size(), 120); // 10 * (8+4)
        assert_eq!(cache.mha_equivalent_size(), 640); // 10 * 2*4*8
        assert!((cache.reduction_ratio() - 0.8125).abs() < 1e-4);
    }

    #[test]
    fn test_memory_comparison_report() {
        let c = MLAConfig {
            d_model: 2048,
            latent_dim: 256,
            latent_dim_q: None,
            num_heads: 16,
            head_dim: 128,
            rope_dim: 0,
        };
        let layer = MLALayer::new(c).unwrap();
        let r = layer.memory_comparison(1024);
        assert_eq!(r.mha_cache_floats, 4_194_304);
        assert_eq!(r.mla_cache_floats, 262_144);
        assert!((r.reduction_ratio - 0.9375).abs() < 1e-4);
    }

    #[test]
    fn test_cached_forward_multi_position() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let mut cache = MLACache::new(&c);
        let q = vec![0.1_f32; c.d_model];
        for pos in 0..3 {
            let kv = vec![(pos as f32 + 1.0) * 0.1; c.d_model];
            let out = layer.forward_cached(&q, &kv, pos, &mut cache).unwrap();
            assert_eq!(out.len(), c.d_model);
        }
        assert_eq!(cache.len(), 3);
        let kv_last = vec![0.4_f32; c.d_model];
        let out = layer.forward_cached(&q, &kv_last, 3, &mut cache).unwrap();
        assert!(out.iter().all(|v| v.is_finite()));
        assert_eq!(cache.len(), 4);
    }

    #[test]
    fn test_cached_matches_uncached_forward() {
        // The cached path must reproduce the uncached pass over the same inputs
        // when keys sit at positions 0..n and the query at the last position.
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        let kvs: Vec<Vec<f32>> = (0..3)
            .map(|p| {
                (0..c.d_model)
                    .map(|i| (p * c.d_model + i) as f32 * 0.01)
                    .collect()
            })
            .collect();

        let mut cache = MLACache::new(&c);
        let mut last = None;
        for (pos, kv) in kvs.iter().enumerate() {
            last = Some(layer.forward_cached(&q, kv, pos, &mut cache).unwrap());
        }
        let cached = last.unwrap();

        let kv_refs: Vec<&[f32]> = kvs.iter().map(Vec::as_slice).collect();
        let full = layer.forward(&q, &kv_refs, 2, &[0, 1, 2]).unwrap();
        assert_eq!(cached, full);
    }

    #[test]
    fn test_rope_identity_at_zero() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let orig = v.clone();
        MLALayer::apply_rope(&mut v, 0);
        for (a, b) in v.iter().zip(&orig) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn test_rope_preserves_norm() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        MLALayer::apply_rope(&mut v, 42);
        let norm_after: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm_before - norm_after).abs() < 1e-5);
    }

    #[test]
    fn test_compress_decompress_dimensions() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let x = vec![0.5_f32; c.d_model];
        let ckv = layer.compress_kv(&x);
        assert_eq!(ckv.len(), c.latent_dim);
        assert_eq!(layer.decompress_keys(&ckv).len(), c.full_kv_dim());
        assert_eq!(layer.decompress_values(&ckv).len(), c.full_kv_dim());
    }

    #[test]
    fn test_attention_trait() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        assert_eq!(layer.dim(), c.d_model);
        assert_eq!(layer.num_heads(), c.num_heads);
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer
            .compute(&q, &[&kv1[..], &kv2[..]], &[&kv1[..], &kv2[..]])
            .unwrap();
        assert_eq!(out.len(), c.d_model);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_empty_cache_ratio() {
        let cache = MLACache::new(&cfg());
        assert_eq!(cache.reduction_ratio(), 0.0);
        assert!(cache.is_empty());
    }
}
598}