// ruvector_attention/attention/mla.rs
//! Multi-Head Latent Attention (MLA) from DeepSeek-V2/V3.
//!
//! Achieves ~93% KV-cache reduction by compressing key-value pairs into a
//! low-dimensional latent space. Instead of caching full K,V per head per
//! position (`2 * num_heads * head_dim` floats), MLA caches only the latent
//! vector `c_kv` (`latent_dim` floats) and decompresses K,V on-the-fly:
//!
//! 1. Down-project: `c_kv = x @ W_dkv` (d_model -> latent_dim)
//! 2. Up-project:   `K = c_kv @ W_uk`, `V = c_kv @ W_uv`
//! 3. Query path:   `c_q = x @ W_dq`, `Q = c_q @ W_uq` (same low-rank trick)
//! 4. RoPE bypass:  A `rope_dim`-sized portion of each key skips compression
//!    and receives Rotary Position Embeddings directly.

use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
16
/// Configuration for Multi-Head Latent Attention.
#[derive(Clone, Debug)]
pub struct MLAConfig {
    /// Model embedding dimension (width of layer inputs and outputs).
    pub d_model: usize,
    /// Dimension of the compressed KV latent vector `c_kv`.
    pub latent_dim: usize,
    /// Optional separate latent dimension for the query path;
    /// `None` falls back to `latent_dim`.
    pub latent_dim_q: Option<usize>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head dimension of keys/queries/values.
    pub head_dim: usize,
    /// Must be even and <= head_dim. Set to 0 to disable RoPE decoupling.
    pub rope_dim: usize,
}
28
29impl MLAConfig {
30    pub fn validate(&self) -> AttentionResult<()> {
31        let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
32        if self.d_model == 0 { return err("d_model must be > 0"); }
33        if self.num_heads == 0 { return err("num_heads must be > 0"); }
34        if self.head_dim == 0 { return err("head_dim must be > 0"); }
35        if self.latent_dim == 0 { return err("latent_dim must be > 0"); }
36        if self.latent_dim >= self.full_kv_dim() {
37            return err("latent_dim must be < num_heads * head_dim");
38        }
39        if self.rope_dim > self.head_dim {
40            return err("rope_dim must be <= head_dim");
41        }
42        if self.rope_dim > 0 && self.rope_dim % 2 != 0 {
43            return err("rope_dim must be even (RoPE operates on pairs)");
44        }
45        Ok(())
46    }
47
48    pub fn effective_latent_dim_q(&self) -> usize {
49        self.latent_dim_q.unwrap_or(self.latent_dim)
50    }
51
52    pub fn full_kv_dim(&self) -> usize {
53        self.num_heads * self.head_dim
54    }
55}
56
/// KV cache storing only latent vectors instead of full K,V per head.
#[derive(Clone, Debug)]
pub struct MLACache {
    /// One compressed `c_kv` latent vector per cached position.
    pub latent_vectors: Vec<Vec<f32>>,
    /// One decoupled RoPE key (pre-rotation) per cached position.
    pub rope_keys: Vec<Vec<f32>>,
    // Dimensions copied from the layer config at construction time,
    // used only for the size-accounting methods below.
    latent_dim: usize,
    rope_dim: usize,
    num_heads: usize,
    head_dim: usize,
}
67
68impl MLACache {
69    pub fn new(config: &MLAConfig) -> Self {
70        Self {
71            latent_vectors: Vec::new(), rope_keys: Vec::new(),
72            latent_dim: config.latent_dim, rope_dim: config.rope_dim,
73            num_heads: config.num_heads, head_dim: config.head_dim,
74        }
75    }
76
77    pub fn push(&mut self, latent: Vec<f32>, rope_key: Vec<f32>) {
78        self.latent_vectors.push(latent);
79        self.rope_keys.push(rope_key);
80    }
81
82    pub fn len(&self) -> usize { self.latent_vectors.len() }
83    pub fn is_empty(&self) -> bool { self.latent_vectors.is_empty() }
84
85    /// Total floats stored in this MLA cache.
86    pub fn cache_size(&self) -> usize {
87        self.len() * (self.latent_dim + self.rope_dim)
88    }
89
90    /// Total floats standard MHA would store for the same positions.
91    pub fn mha_equivalent_size(&self) -> usize {
92        self.len() * 2 * self.num_heads * self.head_dim
93    }
94
95    /// KV-cache reduction ratio (e.g. 0.9375 = 93.75% reduction vs MHA).
96    pub fn reduction_ratio(&self) -> f32 {
97        if self.len() == 0 { return 0.0; }
98        1.0 - (self.cache_size() as f32 / self.mha_equivalent_size() as f32)
99    }
100}
101
/// Multi-Head Latent Attention layer with projection weights (row-major).
pub struct MLALayer {
    config: MLAConfig,
    w_dkv: Vec<f32>,  // d_model -> latent_dim (KV down-projection)
    w_uk: Vec<f32>,   // latent_dim -> full_kv_dim (keys up-projection)
    w_uv: Vec<f32>,   // latent_dim -> full_kv_dim (values up-projection)
    w_dq: Vec<f32>,   // d_model -> latent_dim_q (query down-projection)
    w_uq: Vec<f32>,   // latent_dim_q -> full_kv_dim (query up-projection)
    w_rope: Vec<f32>, // d_model -> rope_dim (decoupled RoPE key projection)
    w_out: Vec<f32>,  // full_kv_dim -> d_model (output projection)
}
113
114impl MLALayer {
115    /// Creates a new MLA layer with deterministic Xavier-style initialization.
116    pub fn new(config: MLAConfig) -> AttentionResult<Self> {
117        config.validate()?;
118        let fd = config.full_kv_dim();
119        let lq = config.effective_latent_dim_q();
120        Ok(Self {
121            w_dkv: init_weight(config.d_model, config.latent_dim),
122            w_uk: init_weight(config.latent_dim, fd),
123            w_uv: init_weight(config.latent_dim, fd),
124            w_dq: init_weight(config.d_model, lq),
125            w_uq: init_weight(lq, fd),
126            w_rope: init_weight(config.d_model, config.rope_dim),
127            w_out: init_weight(fd, config.d_model),
128            config,
129        })
130    }
131
132    pub fn config(&self) -> &MLAConfig { &self.config }
133
134    /// Compress input to KV latent: `c_kv = x @ W_dkv`.
135    pub fn compress_kv(&self, x: &[f32]) -> Vec<f32> {
136        matvec(&self.w_dkv, x, self.config.d_model, self.config.latent_dim)
137    }
138
139    /// Decompress latent to keys: `K = c_kv @ W_uk`.
140    pub fn decompress_keys(&self, c: &[f32]) -> Vec<f32> {
141        matvec(&self.w_uk, c, self.config.latent_dim, self.config.full_kv_dim())
142    }
143
144    /// Decompress latent to values: `V = c_kv @ W_uv`.
145    pub fn decompress_values(&self, c: &[f32]) -> Vec<f32> {
146        matvec(&self.w_uv, c, self.config.latent_dim, self.config.full_kv_dim())
147    }
148
149    fn compute_rope_keys(&self, x: &[f32]) -> Vec<f32> {
150        if self.config.rope_dim == 0 { return Vec::new(); }
151        matvec(&self.w_rope, x, self.config.d_model, self.config.rope_dim)
152    }
153
154    fn compute_query(&self, x: &[f32]) -> Vec<f32> {
155        let lq = self.config.effective_latent_dim_q();
156        let c_q = matvec(&self.w_dq, x, self.config.d_model, lq);
157        matvec(&self.w_uq, &c_q, lq, self.config.full_kv_dim())
158    }
159
160    /// Applies RoPE rotation to pairs of dimensions based on position.
161    fn apply_rope(v: &mut [f32], position: usize) {
162        let dim = v.len();
163        for i in (0..dim).step_by(2) {
164            if i + 1 >= dim { break; }
165            let freq = 1.0 / (10000.0_f32).powf(i as f32 / dim as f32);
166            let theta = position as f32 * freq;
167            let (cos_t, sin_t) = (theta.cos(), theta.sin());
168            let (x0, x1) = (v[i], v[i + 1]);
169            v[i] = x0 * cos_t - x1 * sin_t;
170            v[i + 1] = x0 * sin_t + x1 * cos_t;
171        }
172    }
173
174    /// Core attention computation shared by `forward` and `forward_cached`.
175    fn attend(
176        &self, q_full: &[f32], all_keys: &[Vec<f32>], all_values: &[Vec<f32>],
177    ) -> Vec<f32> {
178        let (nh, hd) = (self.config.num_heads, self.config.head_dim);
179        let scale = (hd as f32).sqrt();
180        let mut out = vec![0.0_f32; nh * hd];
181        for h in 0..nh {
182            let off = h * hd;
183            let qh = &q_full[off..off + hd];
184            let mut scores: Vec<f32> = all_keys
185                .iter()
186                .map(|k| dot(&k[off..off + hd], qh) / scale)
187                .collect();
188            softmax_inplace(&mut scores);
189            for (si, &w) in scores.iter().enumerate() {
190                let vh = &all_values[si][off..off + hd];
191                for d in 0..hd { out[off + d] += w * vh[d]; }
192            }
193        }
194        matvec(&self.w_out, &out, self.config.full_kv_dim(), self.config.d_model)
195    }
196
197    /// Prepares query with RoPE applied to the decoupled portion of each head.
198    fn prepare_query(&self, input: &[f32], pos: usize) -> Vec<f32> {
199        let mut q = self.compute_query(input);
200        let (nh, hd, rd) = (self.config.num_heads, self.config.head_dim, self.config.rope_dim);
201        if rd > 0 {
202            for h in 0..nh { Self::apply_rope(&mut q[h * hd..h * hd + rd], pos); }
203        }
204        q
205    }
206
207    /// Decompresses a latent+rope pair into full keys/values for one position.
208    fn decompress_position(
209        &self, latent: &[f32], rope: &[f32], pos: usize,
210    ) -> (Vec<f32>, Vec<f32>) {
211        let mut keys = self.decompress_keys(latent);
212        let values = self.decompress_values(latent);
213        let (nh, hd, rd) = (self.config.num_heads, self.config.head_dim, self.config.rope_dim);
214        if rd > 0 {
215            let mut rp = rope.to_vec();
216            Self::apply_rope(&mut rp, pos);
217            for h in 0..nh { keys[h * hd..h * hd + rd].copy_from_slice(&rp); }
218        }
219        (keys, values)
220    }
221
222    /// Full MLA forward pass for a single query position.
223    pub fn forward(
224        &self, query_input: &[f32], kv_inputs: &[&[f32]],
225        query_pos: usize, kv_positions: &[usize],
226    ) -> AttentionResult<Vec<f32>> {
227        if query_input.len() != self.config.d_model {
228            return Err(AttentionError::DimensionMismatch {
229                expected: self.config.d_model, actual: query_input.len(),
230            });
231        }
232        if kv_inputs.is_empty() {
233            return Err(AttentionError::EmptyInput("kv_inputs".into()));
234        }
235        if kv_inputs.len() != kv_positions.len() {
236            return Err(AttentionError::DimensionMismatch {
237                expected: kv_inputs.len(), actual: kv_positions.len(),
238            });
239        }
240        let q_full = self.prepare_query(query_input, query_pos);
241        let mut all_k = Vec::with_capacity(kv_inputs.len());
242        let mut all_v = Vec::with_capacity(kv_inputs.len());
243        for (i, &kv) in kv_inputs.iter().enumerate() {
244            if kv.len() != self.config.d_model {
245                return Err(AttentionError::DimensionMismatch {
246                    expected: self.config.d_model, actual: kv.len(),
247                });
248            }
249            let c = self.compress_kv(kv);
250            let rope = self.compute_rope_keys(kv);
251            let (k, v) = self.decompress_position(&c, &rope, kv_positions[i]);
252            all_k.push(k);
253            all_v.push(v);
254        }
255        Ok(self.attend(&q_full, &all_k, &all_v))
256    }
257
258    /// Forward pass using incremental MLA cache (for autoregressive decoding).
259    pub fn forward_cached(
260        &self, query_input: &[f32], new_kv_input: &[f32],
261        query_pos: usize, cache: &mut MLACache,
262    ) -> AttentionResult<Vec<f32>> {
263        if new_kv_input.len() != self.config.d_model {
264            return Err(AttentionError::DimensionMismatch {
265                expected: self.config.d_model, actual: new_kv_input.len(),
266            });
267        }
268        cache.push(self.compress_kv(new_kv_input), self.compute_rope_keys(new_kv_input));
269        let q_full = self.prepare_query(query_input, query_pos);
270        let mut all_k = Vec::with_capacity(cache.len());
271        let mut all_v = Vec::with_capacity(cache.len());
272        for pos in 0..cache.len() {
273            let (k, v) = self.decompress_position(
274                &cache.latent_vectors[pos], &cache.rope_keys[pos], pos,
275            );
276            all_k.push(k);
277            all_v.push(v);
278        }
279        Ok(self.attend(&q_full, &all_k, &all_v))
280    }
281
282    /// Memory comparison report: MLA vs standard MHA caching.
283    pub fn memory_comparison(&self, seq_len: usize) -> MemoryComparison {
284        let mha = seq_len * 2 * self.config.num_heads * self.config.head_dim;
285        let mla = seq_len * (self.config.latent_dim + self.config.rope_dim);
286        MemoryComparison {
287            seq_len, mha_cache_floats: mha, mla_cache_floats: mla,
288            mha_cache_bytes: mha * 4, mla_cache_bytes: mla * 4,
289            reduction_ratio: 1.0 - (mla as f32 / mha as f32),
290        }
291    }
292}
293
/// Report comparing MLA vs MHA cache memory usage.
#[derive(Clone, Debug)]
pub struct MemoryComparison {
    /// Number of positions the report covers.
    pub seq_len: usize,
    /// Floats a standard MHA cache would hold (2 * heads * head_dim per position).
    pub mha_cache_floats: usize,
    /// Floats the MLA cache holds (latent_dim + rope_dim per position).
    pub mla_cache_floats: usize,
    /// Byte counts assuming f32 storage (4 bytes per float).
    pub mha_cache_bytes: usize,
    pub mla_cache_bytes: usize,
    /// Fraction saved vs MHA (e.g. 0.9375 = 93.75% reduction).
    pub reduction_ratio: f32,
}
304
impl Attention for MLALayer {
    /// Trait-compatible attention. `keys` are treated as the raw token
    /// inputs from which both K and V latents are derived, so the
    /// `values` argument is intentionally unused.
    fn compute(
        &self, query: &[f32], keys: &[&[f32]], values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        let _ = values; // MLA derives V from the same inputs as K
        // Positions default to 0..n with the query at position 0.
        // NOTE(review): a decoding-style caller might expect the query at
        // position n-1 — confirm against the `Attention` trait contract.
        let positions: Vec<usize> = (0..keys.len()).collect();
        self.forward(query, keys, 0, &positions)
    }

    /// Mask-accepting variant; the mask is currently IGNORED — attention
    /// always covers all provided keys.
    fn compute_with_mask(
        &self, query: &[f32], keys: &[&[f32]], values: &[&[f32]],
        _mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        self.compute(query, keys, values)
    }

    /// Model dimension of inputs/outputs.
    fn dim(&self) -> usize { self.config.d_model }
    /// Number of attention heads.
    fn num_heads(&self) -> usize { self.config.num_heads }
}
324
325// -- Utility functions --------------------------------------------------------
326
/// Row-major matrix–vector product: `w` holds `out_d` rows of `in_d`
/// columns; returns `w · x` as a fresh vector of length `out_d`.
fn matvec(w: &[f32], x: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut y = Vec::with_capacity(out_d);
    for row in 0..out_d {
        let weights = &w[row * in_d..(row + 1) * in_d];
        let acc: f32 = weights.iter().zip(&x[..in_d]).map(|(wi, xi)| wi * xi).sum();
        y.push(acc);
    }
    y
}
335
/// Inner product of two slices, accumulated in order over the shorter length.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
339
/// Numerically stable in-place softmax: subtracts the running maximum
/// before exponentiating, then normalizes to sum 1. No-op on empty input.
fn softmax_inplace(s: &mut [f32]) {
    let mut max = f32::NEG_INFINITY;
    for &v in s.iter() {
        if v > max { max = v; }
    }
    let mut total = 0.0_f32;
    for v in s.iter_mut() {
        *v = (*v - max).exp();
        total += *v;
    }
    for v in s.iter_mut() {
        *v /= total;
    }
}
346
/// Deterministic Xavier-style initialization: values sweep a sawtooth in
/// [-scale/2, scale/2) with period `in_d + out_d`, where
/// `scale = sqrt(2 / (in_d + out_d))`.
fn init_weight(in_d: usize, out_d: usize) -> Vec<f32> {
    let fan_sum = in_d + out_d;
    let scale = (2.0 / fan_sum as f32).sqrt();
    let period = fan_sum.max(1);
    let mut w = Vec::with_capacity(in_d * out_d);
    for i in 0..in_d * out_d {
        let phase = (i % period) as f32 / period as f32;
        w.push(scale * (phase - 0.5));
    }
    w
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared small config: 4 heads x 8 dims (full_kv_dim = 32),
    /// latent 8, rope 4.
    fn cfg() -> MLAConfig {
        MLAConfig {
            d_model: 32, latent_dim: 8, latent_dim_q: None,
            num_heads: 4, head_dim: 8, rope_dim: 4,
        }
    }

    #[test]
    fn test_config_valid() { assert!(cfg().validate().is_ok()); }

    #[test]
    fn test_config_latent_too_large() {
        // latent_dim must stay strictly below num_heads * head_dim (32 here).
        let mut c = cfg(); c.latent_dim = 999;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_rope_dim_odd() {
        // RoPE rotates pairs, so rope_dim must be even.
        let mut c = cfg(); c.rope_dim = 3;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_zero_heads() {
        let mut c = cfg(); c.num_heads = 0;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_forward_output_shape() {
        // Output is projected back to d_model regardless of KV count.
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer.forward(&q, &[&kv1, &kv2], 0, &[0, 1]).unwrap();
        assert_eq!(out.len(), c.d_model);
    }

    #[test]
    fn test_forward_dimension_mismatch() {
        // A query shorter than d_model must be rejected, not panic.
        let layer = MLALayer::new(cfg()).unwrap();
        let bad_q = vec![0.1_f32; 5];
        let kv = vec![0.2_f32; 32];
        assert!(layer.forward(&bad_q, &[&kv[..]], 0, &[0]).is_err());
    }

    #[test]
    fn test_cache_size_reduction() {
        let c = cfg();
        let mut cache = MLACache::new(&c);
        for _ in 0..10 { cache.push(vec![0.0; c.latent_dim], vec![0.0; c.rope_dim]); }
        assert_eq!(cache.len(), 10);
        assert_eq!(cache.cache_size(), 120);        // 10 * (8+4)
        assert_eq!(cache.mha_equivalent_size(), 640); // 10 * 2*4*8
        assert!((cache.reduction_ratio() - 0.8125).abs() < 1e-4);
    }

    #[test]
    fn test_memory_comparison_report() {
        // DeepSeek-scale config reproduces the headline ~93.75% reduction.
        let c = MLAConfig {
            d_model: 2048, latent_dim: 256, latent_dim_q: None,
            num_heads: 16, head_dim: 128, rope_dim: 0,
        };
        let layer = MLALayer::new(c).unwrap();
        let r = layer.memory_comparison(1024);
        assert_eq!(r.mha_cache_floats, 4_194_304);
        assert_eq!(r.mla_cache_floats, 262_144);
        assert!((r.reduction_ratio - 0.9375).abs() < 1e-4);
    }

    #[test]
    fn test_cached_forward_multi_position() {
        // Autoregressive decode: cache grows by one per call.
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let mut cache = MLACache::new(&c);
        let q = vec![0.1_f32; c.d_model];
        for pos in 0..3 {
            let kv = vec![(pos as f32 + 1.0) * 0.1; c.d_model];
            let out = layer.forward_cached(&q, &kv, pos, &mut cache).unwrap();
            assert_eq!(out.len(), c.d_model);
        }
        assert_eq!(cache.len(), 3);
        let kv_last = vec![0.4_f32; c.d_model];
        let out = layer.forward_cached(&q, &kv_last, 3, &mut cache).unwrap();
        assert!(out.iter().all(|v| v.is_finite()));
        assert_eq!(cache.len(), 4);
    }

    #[test]
    fn test_rope_identity_at_zero() {
        // Position 0 means theta = 0 for every pair: identity rotation.
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let orig = v.clone();
        MLALayer::apply_rope(&mut v, 0);
        for (a, b) in v.iter().zip(&orig) { assert!((a - b).abs() < 1e-6); }
    }

    #[test]
    fn test_rope_preserves_norm() {
        // Rotations are orthogonal, so the vector norm must be unchanged.
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        MLALayer::apply_rope(&mut v, 42);
        let norm_after: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm_before - norm_after).abs() < 1e-5);
    }

    #[test]
    fn test_compress_decompress_dimensions() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let x = vec![0.5_f32; c.d_model];
        let ckv = layer.compress_kv(&x);
        assert_eq!(ckv.len(), c.latent_dim);
        assert_eq!(layer.decompress_keys(&ckv).len(), c.full_kv_dim());
        assert_eq!(layer.decompress_values(&ckv).len(), c.full_kv_dim());
    }

    #[test]
    fn test_attention_trait() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        assert_eq!(layer.dim(), c.d_model);
        assert_eq!(layer.num_heads(), c.num_heads);
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer.compute(&q, &[&kv1[..], &kv2[..]], &[&kv1[..], &kv2[..]]).unwrap();
        assert_eq!(out.len(), c.d_model);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_empty_cache_ratio() {
        // Empty cache reports 0.0 rather than NaN.
        let cache = MLACache::new(&cfg());
        assert_eq!(cache.reduction_ratio(), 0.0);
        assert!(cache.is_empty());
    }
}
496}