1use crate::error::{AttentionError, AttentionResult};
15use crate::traits::Attention;
16
/// Hyper-parameters of a Multi-head Latent Attention (MLA) layer.
///
/// The invariants between fields are checked by [`MLAConfig::validate`].
#[derive(Clone, Debug)]
pub struct MLAConfig {
    /// Width of the vectors fed into and produced by the layer.
    pub d_model: usize,
    /// Size of the compressed latent cached per position for keys and values.
    pub latent_dim: usize,
    /// Size of the query latent; `None` falls back to `latent_dim`.
    pub latent_dim_q: Option<usize>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Width of each head; full keys/values have `num_heads * head_dim` entries.
    pub head_dim: usize,
    /// Leading entries of every head that carry rotary position encoding (0 disables RoPE).
    pub rope_dim: usize,
}
28
29impl MLAConfig {
30 pub fn validate(&self) -> AttentionResult<()> {
31 let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
32 if self.d_model == 0 {
33 return err("d_model must be > 0");
34 }
35 if self.num_heads == 0 {
36 return err("num_heads must be > 0");
37 }
38 if self.head_dim == 0 {
39 return err("head_dim must be > 0");
40 }
41 if self.latent_dim == 0 {
42 return err("latent_dim must be > 0");
43 }
44 if self.latent_dim >= self.full_kv_dim() {
45 return err("latent_dim must be < num_heads * head_dim");
46 }
47 if self.rope_dim > self.head_dim {
48 return err("rope_dim must be <= head_dim");
49 }
50 if self.rope_dim > 0 && self.rope_dim % 2 != 0 {
51 return err("rope_dim must be even (RoPE operates on pairs)");
52 }
53 Ok(())
54 }
55
56 pub fn effective_latent_dim_q(&self) -> usize {
57 self.latent_dim_q.unwrap_or(self.latent_dim)
58 }
59
60 pub fn full_kv_dim(&self) -> usize {
61 self.num_heads * self.head_dim
62 }
63}
64
/// Per-position KV cache for an MLA layer.
///
/// Instead of full per-head keys and values, each cached position holds only a
/// compressed latent plus a RoPE key. The two public vectors are parallel: entry
/// `i` of each belongs to the same position.
#[derive(Clone, Debug)]
pub struct MLACache {
    /// One compressed KV latent per cached position.
    pub latent_vectors: Vec<Vec<f32>>,
    /// One RoPE key per cached position, parallel to `latent_vectors`.
    pub rope_keys: Vec<Vec<f32>>,
    /// Dimensions copied from the config, used only for size accounting.
    latent_dim: usize,
    rope_dim: usize,
    num_heads: usize,
    head_dim: usize,
}
75
76impl MLACache {
77 pub fn new(config: &MLAConfig) -> Self {
78 Self {
79 latent_vectors: Vec::new(),
80 rope_keys: Vec::new(),
81 latent_dim: config.latent_dim,
82 rope_dim: config.rope_dim,
83 num_heads: config.num_heads,
84 head_dim: config.head_dim,
85 }
86 }
87
88 pub fn push(&mut self, latent: Vec<f32>, rope_key: Vec<f32>) {
89 self.latent_vectors.push(latent);
90 self.rope_keys.push(rope_key);
91 }
92
93 pub fn len(&self) -> usize {
94 self.latent_vectors.len()
95 }
96 pub fn is_empty(&self) -> bool {
97 self.latent_vectors.is_empty()
98 }
99
100 pub fn cache_size(&self) -> usize {
102 self.len() * (self.latent_dim + self.rope_dim)
103 }
104
105 pub fn mha_equivalent_size(&self) -> usize {
107 self.len() * 2 * self.num_heads * self.head_dim
108 }
109
110 pub fn reduction_ratio(&self) -> f32 {
112 if self.len() == 0 {
113 return 0.0;
114 }
115 1.0 - (self.cache_size() as f32 / self.mha_equivalent_size() as f32)
116 }
117}
118
119pub struct MLALayer {
121 config: MLAConfig,
122 w_dkv: Vec<f32>, w_uk: Vec<f32>, w_uv: Vec<f32>, w_dq: Vec<f32>, w_uq: Vec<f32>, w_rope: Vec<f32>, w_out: Vec<f32>, }
130
131impl MLALayer {
132 pub fn new(config: MLAConfig) -> AttentionResult<Self> {
134 config.validate()?;
135 let fd = config.full_kv_dim();
136 let lq = config.effective_latent_dim_q();
137 Ok(Self {
138 w_dkv: init_weight(config.d_model, config.latent_dim),
139 w_uk: init_weight(config.latent_dim, fd),
140 w_uv: init_weight(config.latent_dim, fd),
141 w_dq: init_weight(config.d_model, lq),
142 w_uq: init_weight(lq, fd),
143 w_rope: init_weight(config.d_model, config.rope_dim),
144 w_out: init_weight(fd, config.d_model),
145 config,
146 })
147 }
148
149 pub fn config(&self) -> &MLAConfig {
150 &self.config
151 }
152
153 pub fn compress_kv(&self, x: &[f32]) -> Vec<f32> {
155 matvec(&self.w_dkv, x, self.config.d_model, self.config.latent_dim)
156 }
157
158 pub fn decompress_keys(&self, c: &[f32]) -> Vec<f32> {
160 matvec(
161 &self.w_uk,
162 c,
163 self.config.latent_dim,
164 self.config.full_kv_dim(),
165 )
166 }
167
168 pub fn decompress_values(&self, c: &[f32]) -> Vec<f32> {
170 matvec(
171 &self.w_uv,
172 c,
173 self.config.latent_dim,
174 self.config.full_kv_dim(),
175 )
176 }
177
178 fn compute_rope_keys(&self, x: &[f32]) -> Vec<f32> {
179 if self.config.rope_dim == 0 {
180 return Vec::new();
181 }
182 matvec(&self.w_rope, x, self.config.d_model, self.config.rope_dim)
183 }
184
185 fn compute_query(&self, x: &[f32]) -> Vec<f32> {
186 let lq = self.config.effective_latent_dim_q();
187 let c_q = matvec(&self.w_dq, x, self.config.d_model, lq);
188 matvec(&self.w_uq, &c_q, lq, self.config.full_kv_dim())
189 }
190
191 fn apply_rope(v: &mut [f32], position: usize) {
193 let dim = v.len();
194 for i in (0..dim).step_by(2) {
195 if i + 1 >= dim {
196 break;
197 }
198 let freq = 1.0 / (10000.0_f32).powf(i as f32 / dim as f32);
199 let theta = position as f32 * freq;
200 let (cos_t, sin_t) = (theta.cos(), theta.sin());
201 let (x0, x1) = (v[i], v[i + 1]);
202 v[i] = x0 * cos_t - x1 * sin_t;
203 v[i + 1] = x0 * sin_t + x1 * cos_t;
204 }
205 }
206
207 fn attend(&self, q_full: &[f32], all_keys: &[Vec<f32>], all_values: &[Vec<f32>]) -> Vec<f32> {
209 let (nh, hd) = (self.config.num_heads, self.config.head_dim);
210 let scale = (hd as f32).sqrt();
211 let mut out = vec![0.0_f32; nh * hd];
212 for h in 0..nh {
213 let off = h * hd;
214 let qh = &q_full[off..off + hd];
215 let mut scores: Vec<f32> = all_keys
216 .iter()
217 .map(|k| dot(&k[off..off + hd], qh) / scale)
218 .collect();
219 softmax_inplace(&mut scores);
220 for (si, &w) in scores.iter().enumerate() {
221 let vh = &all_values[si][off..off + hd];
222 for d in 0..hd {
223 out[off + d] += w * vh[d];
224 }
225 }
226 }
227 matvec(
228 &self.w_out,
229 &out,
230 self.config.full_kv_dim(),
231 self.config.d_model,
232 )
233 }
234
235 fn prepare_query(&self, input: &[f32], pos: usize) -> Vec<f32> {
237 let mut q = self.compute_query(input);
238 let (nh, hd, rd) = (
239 self.config.num_heads,
240 self.config.head_dim,
241 self.config.rope_dim,
242 );
243 if rd > 0 {
244 for h in 0..nh {
245 Self::apply_rope(&mut q[h * hd..h * hd + rd], pos);
246 }
247 }
248 q
249 }
250
251 fn decompress_position(
253 &self,
254 latent: &[f32],
255 rope: &[f32],
256 pos: usize,
257 ) -> (Vec<f32>, Vec<f32>) {
258 let mut keys = self.decompress_keys(latent);
259 let values = self.decompress_values(latent);
260 let (nh, hd, rd) = (
261 self.config.num_heads,
262 self.config.head_dim,
263 self.config.rope_dim,
264 );
265 if rd > 0 {
266 let mut rp = rope.to_vec();
267 Self::apply_rope(&mut rp, pos);
268 for h in 0..nh {
269 keys[h * hd..h * hd + rd].copy_from_slice(&rp);
270 }
271 }
272 (keys, values)
273 }
274
275 pub fn forward(
277 &self,
278 query_input: &[f32],
279 kv_inputs: &[&[f32]],
280 query_pos: usize,
281 kv_positions: &[usize],
282 ) -> AttentionResult<Vec<f32>> {
283 if query_input.len() != self.config.d_model {
284 return Err(AttentionError::DimensionMismatch {
285 expected: self.config.d_model,
286 actual: query_input.len(),
287 });
288 }
289 if kv_inputs.is_empty() {
290 return Err(AttentionError::EmptyInput("kv_inputs".into()));
291 }
292 if kv_inputs.len() != kv_positions.len() {
293 return Err(AttentionError::DimensionMismatch {
294 expected: kv_inputs.len(),
295 actual: kv_positions.len(),
296 });
297 }
298 let q_full = self.prepare_query(query_input, query_pos);
299 let mut all_k = Vec::with_capacity(kv_inputs.len());
300 let mut all_v = Vec::with_capacity(kv_inputs.len());
301 for (i, &kv) in kv_inputs.iter().enumerate() {
302 if kv.len() != self.config.d_model {
303 return Err(AttentionError::DimensionMismatch {
304 expected: self.config.d_model,
305 actual: kv.len(),
306 });
307 }
308 let c = self.compress_kv(kv);
309 let rope = self.compute_rope_keys(kv);
310 let (k, v) = self.decompress_position(&c, &rope, kv_positions[i]);
311 all_k.push(k);
312 all_v.push(v);
313 }
314 Ok(self.attend(&q_full, &all_k, &all_v))
315 }
316
317 pub fn forward_cached(
319 &self,
320 query_input: &[f32],
321 new_kv_input: &[f32],
322 query_pos: usize,
323 cache: &mut MLACache,
324 ) -> AttentionResult<Vec<f32>> {
325 if new_kv_input.len() != self.config.d_model {
326 return Err(AttentionError::DimensionMismatch {
327 expected: self.config.d_model,
328 actual: new_kv_input.len(),
329 });
330 }
331 cache.push(
332 self.compress_kv(new_kv_input),
333 self.compute_rope_keys(new_kv_input),
334 );
335 let q_full = self.prepare_query(query_input, query_pos);
336 let mut all_k = Vec::with_capacity(cache.len());
337 let mut all_v = Vec::with_capacity(cache.len());
338 for pos in 0..cache.len() {
339 let (k, v) =
340 self.decompress_position(&cache.latent_vectors[pos], &cache.rope_keys[pos], pos);
341 all_k.push(k);
342 all_v.push(v);
343 }
344 Ok(self.attend(&q_full, &all_k, &all_v))
345 }
346
347 pub fn memory_comparison(&self, seq_len: usize) -> MemoryComparison {
349 let mha = seq_len * 2 * self.config.num_heads * self.config.head_dim;
350 let mla = seq_len * (self.config.latent_dim + self.config.rope_dim);
351 MemoryComparison {
352 seq_len,
353 mha_cache_floats: mha,
354 mla_cache_floats: mla,
355 mha_cache_bytes: mha * 4,
356 mla_cache_bytes: mla * 4,
357 reduction_ratio: 1.0 - (mla as f32 / mha as f32),
358 }
359 }
360}
361
/// Cache-size report produced by `MLALayer::memory_comparison`.
#[derive(Clone, Debug)]
pub struct MemoryComparison {
    /// Number of sequence positions the report covers.
    pub seq_len: usize,
    /// Floats needed to cache full per-head keys and values.
    pub mha_cache_floats: usize,
    /// Floats needed by the MLA cache (latent + RoPE key per position).
    pub mla_cache_floats: usize,
    /// `mha_cache_floats` expressed in bytes.
    pub mha_cache_bytes: usize,
    /// `mla_cache_floats` expressed in bytes.
    pub mla_cache_bytes: usize,
    /// Fraction of cache memory MLA saves relative to the full cache.
    pub reduction_ratio: f32,
}
372
373impl Attention for MLALayer {
374 fn compute(
375 &self,
376 query: &[f32],
377 keys: &[&[f32]],
378 values: &[&[f32]],
379 ) -> AttentionResult<Vec<f32>> {
380 let _ = values; let positions: Vec<usize> = (0..keys.len()).collect();
382 self.forward(query, keys, 0, &positions)
383 }
384
385 fn compute_with_mask(
386 &self,
387 query: &[f32],
388 keys: &[&[f32]],
389 values: &[&[f32]],
390 _mask: Option<&[bool]>,
391 ) -> AttentionResult<Vec<f32>> {
392 self.compute(query, keys, values)
393 }
394
395 fn dim(&self) -> usize {
396 self.config.d_model
397 }
398 fn num_heads(&self) -> usize {
399 self.config.num_heads
400 }
401}
402
/// Dense matrix-vector product `W · x`.
///
/// `w` is read as `out_d` consecutive rows of `in_d` weights each (row-major).
/// Only the first `in_d` entries of `x` are used.
///
/// # Panics
/// Panics (when `out_d > 0`) if `w` holds fewer than `in_d * out_d` values or
/// `x` holds fewer than `in_d`.
fn matvec(w: &[f32], x: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut result = Vec::with_capacity(out_d);
    for row in 0..out_d {
        let start = row * in_d;
        let weights = &w[start..start + in_d];
        let inputs = &x[..in_d];
        result.push(
            weights
                .iter()
                .zip(inputs)
                .map(|(wv, xv)| wv * xv)
                .sum::<f32>(),
        );
    }
    result
}
413
/// Inner product of `a` and `b`. If the lengths differ, the tail of the longer
/// slice is ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>()
}
417
/// Replaces every score with `exp(s_i - max) / Σ_j exp(s_j - max)`.
///
/// The maximum is subtracted first so `exp` cannot overflow for large scores.
/// An empty slice is left untouched.
fn softmax_inplace(s: &mut [f32]) {
    let peak = s.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0_f32;
    s.iter_mut().for_each(|v| {
        *v = (*v - peak).exp();
        total += *v;
    });
    s.iter_mut().for_each(|v| *v /= total);
}
429
/// Deterministic, seed-free weights for an `in_d × out_d` matrix.
///
/// Element `i` is `scale * ((i mod p) / p - 0.5)`, where `p = max(in_d + out_d, 1)`
/// and `scale = sqrt(2 / (in_d + out_d))`. The result is a repeating ramp in
/// `[-scale / 2, scale / 2)`.
fn init_weight(in_d: usize, out_d: usize) -> Vec<f32> {
    let fan_sum = in_d + out_d;
    let scale = (2.0 / fan_sum as f32).sqrt();
    let period = fan_sum.max(1);
    let count = in_d * out_d;
    let mut weights = Vec::with_capacity(count);
    for i in 0..count {
        let phase = (i % period) as f32 / period as f32;
        weights.push(scale * (phase - 0.5));
    }
    weights
}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Small shared config: 4 heads of width 8, an 8-float KV latent and a
    /// 4-float RoPE slice.
    fn cfg() -> MLAConfig {
        MLAConfig {
            d_model: 32,
            latent_dim: 8,
            latent_dim_q: None,
            num_heads: 4,
            head_dim: 8,
            rope_dim: 4,
        }
    }

    /// Builds a layer for `config`, failing the test if the config is rejected.
    fn layer_for(config: &MLAConfig) -> MLALayer {
        MLALayer::new(config.clone()).expect("test config must be valid")
    }

    #[test]
    fn test_config_valid() {
        assert!(cfg().validate().is_ok());
    }

    #[test]
    fn test_config_latent_too_large() {
        let config = MLAConfig { latent_dim: 999, ..cfg() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_rope_dim_odd() {
        let config = MLAConfig { rope_dim: 3, ..cfg() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_zero_heads() {
        let config = MLAConfig { num_heads: 0, ..cfg() };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_forward_output_shape() {
        let config = cfg();
        let layer = layer_for(&config);
        let query = vec![0.1_f32; config.d_model];
        let first = vec![0.2_f32; config.d_model];
        let second = vec![0.3_f32; config.d_model];
        let out = layer
            .forward(&query, &[&first, &second], 0, &[0, 1])
            .unwrap();
        assert_eq!(out.len(), config.d_model);
    }

    #[test]
    fn test_forward_dimension_mismatch() {
        let layer = layer_for(&cfg());
        let short_query = vec![0.1_f32; 5];
        let kv = vec![0.2_f32; 32];
        assert!(layer.forward(&short_query, &[&kv[..]], 0, &[0]).is_err());
    }

    #[test]
    fn test_cache_size_reduction() {
        let config = cfg();
        let mut cache = MLACache::new(&config);
        for _ in 0..10 {
            cache.push(vec![0.0; config.latent_dim], vec![0.0; config.rope_dim]);
        }
        assert_eq!(cache.len(), 10);
        // 10 positions * (8 latent + 4 rope)
        assert_eq!(cache.cache_size(), 120);
        // 10 positions * 2 (K and V) * 4 heads * 8 dims
        assert_eq!(cache.mha_equivalent_size(), 640);
        assert!((cache.reduction_ratio() - 0.8125).abs() < 1e-4);
    }

    #[test]
    fn test_memory_comparison_report() {
        let config = MLAConfig {
            d_model: 2048,
            latent_dim: 256,
            latent_dim_q: None,
            num_heads: 16,
            head_dim: 128,
            rope_dim: 0,
        };
        let report = MLALayer::new(config).unwrap().memory_comparison(1024);
        // 1024 * 2 * 16 * 128
        assert_eq!(report.mha_cache_floats, 4_194_304);
        // 1024 * 256
        assert_eq!(report.mla_cache_floats, 262_144);
        assert!((report.reduction_ratio - 0.9375).abs() < 1e-4);
    }

    #[test]
    fn test_cached_forward_multi_position() {
        let config = cfg();
        let layer = layer_for(&config);
        let mut cache = MLACache::new(&config);
        let query = vec![0.1_f32; config.d_model];
        for step in 0..3 {
            let kv = vec![(step as f32 + 1.0) * 0.1; config.d_model];
            let out = layer.forward_cached(&query, &kv, step, &mut cache).unwrap();
            assert_eq!(out.len(), config.d_model);
        }
        assert_eq!(cache.len(), 3);
        let last_kv = vec![0.4_f32; config.d_model];
        let out = layer
            .forward_cached(&query, &last_kv, 3, &mut cache)
            .unwrap();
        assert!(out.iter().all(|v| v.is_finite()));
        assert_eq!(cache.len(), 4);
    }

    #[test]
    fn test_rope_identity_at_zero() {
        let mut rotated = vec![1.0, 2.0, 3.0, 4.0];
        let original = rotated.clone();
        MLALayer::apply_rope(&mut rotated, 0);
        for (after, before) in rotated.iter().zip(&original) {
            assert!((after - before).abs() < 1e-6);
        }
    }

    #[test]
    fn test_rope_preserves_norm() {
        let norm = |xs: &[f32]| xs.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mut rotated = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before = norm(&rotated);
        MLALayer::apply_rope(&mut rotated, 42);
        let norm_after = norm(&rotated);
        assert!((norm_before - norm_after).abs() < 1e-5);
    }

    #[test]
    fn test_compress_decompress_dimensions() {
        let config = cfg();
        let layer = layer_for(&config);
        let input = vec![0.5_f32; config.d_model];
        let latent = layer.compress_kv(&input);
        assert_eq!(latent.len(), config.latent_dim);
        assert_eq!(layer.decompress_keys(&latent).len(), config.full_kv_dim());
        assert_eq!(layer.decompress_values(&latent).len(), config.full_kv_dim());
    }

    #[test]
    fn test_attention_trait() {
        let config = cfg();
        let layer = layer_for(&config);
        assert_eq!(layer.dim(), config.d_model);
        assert_eq!(layer.num_heads(), config.num_heads);
        let query = vec![0.1_f32; config.d_model];
        let first = vec![0.2_f32; config.d_model];
        let second = vec![0.3_f32; config.d_model];
        let inputs = [&first[..], &second[..]];
        let out = layer.compute(&query, &inputs, &inputs).unwrap();
        assert_eq!(out.len(), config.d_model);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_empty_cache_ratio() {
        let cache = MLACache::new(&cfg());
        assert_eq!(cache.reduction_ratio(), 0.0);
        assert!(cache.is_empty());
    }
}