1use crate::error::{AttentionError, AttentionResult};
15use crate::traits::Attention;
16
/// Configuration for a Multi-head Latent Attention (MLA) layer.
///
/// MLA caches a low-rank latent of `latent_dim` floats per position (plus
/// `rope_dim` RoPE floats) instead of the full `num_heads * head_dim` keys
/// and values, trading extra up-projection compute for cache memory.
#[derive(Clone, Debug)]
pub struct MLAConfig {
    /// Model (input/output) dimensionality.
    pub d_model: usize,
    /// Width of the shared compressed KV latent.
    pub latent_dim: usize,
    /// Optional separate latent width for the query path;
    /// `latent_dim` is used when `None`.
    pub latent_dim_q: Option<usize>,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-head dimensionality.
    pub head_dim: usize,
    /// Leading per-head dimensions carrying rotary position information
    /// (0 disables RoPE; must be even otherwise).
    pub rope_dim: usize,
}
28
29impl MLAConfig {
30 pub fn validate(&self) -> AttentionResult<()> {
31 let err = |msg: &str| Err(AttentionError::InvalidConfig(msg.into()));
32 if self.d_model == 0 { return err("d_model must be > 0"); }
33 if self.num_heads == 0 { return err("num_heads must be > 0"); }
34 if self.head_dim == 0 { return err("head_dim must be > 0"); }
35 if self.latent_dim == 0 { return err("latent_dim must be > 0"); }
36 if self.latent_dim >= self.full_kv_dim() {
37 return err("latent_dim must be < num_heads * head_dim");
38 }
39 if self.rope_dim > self.head_dim {
40 return err("rope_dim must be <= head_dim");
41 }
42 if self.rope_dim > 0 && self.rope_dim % 2 != 0 {
43 return err("rope_dim must be even (RoPE operates on pairs)");
44 }
45 Ok(())
46 }
47
48 pub fn effective_latent_dim_q(&self) -> usize {
49 self.latent_dim_q.unwrap_or(self.latent_dim)
50 }
51
52 pub fn full_kv_dim(&self) -> usize {
53 self.num_heads * self.head_dim
54 }
55}
56
/// Per-sequence KV cache for MLA: stores one compressed latent vector and
/// one (pre-rotation) RoPE key per position instead of full per-head K/V.
#[derive(Clone, Debug)]
pub struct MLACache {
    /// Compressed KV latents, one `latent_dim`-float vector per position.
    pub latent_vectors: Vec<Vec<f32>>,
    /// Shared RoPE key projections, one `rope_dim`-float vector per
    /// position (rotation is applied when the position is read back).
    pub rope_keys: Vec<Vec<f32>>,
    // Dimensions copied from the config at construction so that size
    // accounting works without keeping a config reference around.
    latent_dim: usize,
    rope_dim: usize,
    num_heads: usize,
    head_dim: usize,
}
67
68impl MLACache {
69 pub fn new(config: &MLAConfig) -> Self {
70 Self {
71 latent_vectors: Vec::new(), rope_keys: Vec::new(),
72 latent_dim: config.latent_dim, rope_dim: config.rope_dim,
73 num_heads: config.num_heads, head_dim: config.head_dim,
74 }
75 }
76
77 pub fn push(&mut self, latent: Vec<f32>, rope_key: Vec<f32>) {
78 self.latent_vectors.push(latent);
79 self.rope_keys.push(rope_key);
80 }
81
82 pub fn len(&self) -> usize { self.latent_vectors.len() }
83 pub fn is_empty(&self) -> bool { self.latent_vectors.is_empty() }
84
85 pub fn cache_size(&self) -> usize {
87 self.len() * (self.latent_dim + self.rope_dim)
88 }
89
90 pub fn mha_equivalent_size(&self) -> usize {
92 self.len() * 2 * self.num_heads * self.head_dim
93 }
94
95 pub fn reduction_ratio(&self) -> f32 {
97 if self.len() == 0 { return 0.0; }
98 1.0 - (self.cache_size() as f32 / self.mha_equivalent_size() as f32)
99 }
100}
101
/// A Multi-head Latent Attention layer with deterministic pseudo-weights.
///
/// Weight matrices are stored row-major as flat `Vec<f32>`s
/// (`out_dim` rows by `in_dim` columns, matching `matvec`).
pub struct MLALayer {
    config: MLAConfig,
    w_dkv: Vec<f32>,  // d_model -> latent_dim (shared KV down-projection)
    w_uk: Vec<f32>,   // latent_dim -> num_heads*head_dim (key up-projection)
    w_uv: Vec<f32>,   // latent_dim -> num_heads*head_dim (value up-projection)
    w_dq: Vec<f32>,   // d_model -> latent_dim_q (query down-projection)
    w_uq: Vec<f32>,   // latent_dim_q -> num_heads*head_dim (query up-projection)
    w_rope: Vec<f32>, // d_model -> rope_dim (shared RoPE key projection)
    w_out: Vec<f32>,  // num_heads*head_dim -> d_model (output projection)
}
113
114impl MLALayer {
115 pub fn new(config: MLAConfig) -> AttentionResult<Self> {
117 config.validate()?;
118 let fd = config.full_kv_dim();
119 let lq = config.effective_latent_dim_q();
120 Ok(Self {
121 w_dkv: init_weight(config.d_model, config.latent_dim),
122 w_uk: init_weight(config.latent_dim, fd),
123 w_uv: init_weight(config.latent_dim, fd),
124 w_dq: init_weight(config.d_model, lq),
125 w_uq: init_weight(lq, fd),
126 w_rope: init_weight(config.d_model, config.rope_dim),
127 w_out: init_weight(fd, config.d_model),
128 config,
129 })
130 }
131
132 pub fn config(&self) -> &MLAConfig { &self.config }
133
134 pub fn compress_kv(&self, x: &[f32]) -> Vec<f32> {
136 matvec(&self.w_dkv, x, self.config.d_model, self.config.latent_dim)
137 }
138
139 pub fn decompress_keys(&self, c: &[f32]) -> Vec<f32> {
141 matvec(&self.w_uk, c, self.config.latent_dim, self.config.full_kv_dim())
142 }
143
144 pub fn decompress_values(&self, c: &[f32]) -> Vec<f32> {
146 matvec(&self.w_uv, c, self.config.latent_dim, self.config.full_kv_dim())
147 }
148
149 fn compute_rope_keys(&self, x: &[f32]) -> Vec<f32> {
150 if self.config.rope_dim == 0 { return Vec::new(); }
151 matvec(&self.w_rope, x, self.config.d_model, self.config.rope_dim)
152 }
153
154 fn compute_query(&self, x: &[f32]) -> Vec<f32> {
155 let lq = self.config.effective_latent_dim_q();
156 let c_q = matvec(&self.w_dq, x, self.config.d_model, lq);
157 matvec(&self.w_uq, &c_q, lq, self.config.full_kv_dim())
158 }
159
160 fn apply_rope(v: &mut [f32], position: usize) {
162 let dim = v.len();
163 for i in (0..dim).step_by(2) {
164 if i + 1 >= dim { break; }
165 let freq = 1.0 / (10000.0_f32).powf(i as f32 / dim as f32);
166 let theta = position as f32 * freq;
167 let (cos_t, sin_t) = (theta.cos(), theta.sin());
168 let (x0, x1) = (v[i], v[i + 1]);
169 v[i] = x0 * cos_t - x1 * sin_t;
170 v[i + 1] = x0 * sin_t + x1 * cos_t;
171 }
172 }
173
174 fn attend(
176 &self, q_full: &[f32], all_keys: &[Vec<f32>], all_values: &[Vec<f32>],
177 ) -> Vec<f32> {
178 let (nh, hd) = (self.config.num_heads, self.config.head_dim);
179 let scale = (hd as f32).sqrt();
180 let mut out = vec![0.0_f32; nh * hd];
181 for h in 0..nh {
182 let off = h * hd;
183 let qh = &q_full[off..off + hd];
184 let mut scores: Vec<f32> = all_keys
185 .iter()
186 .map(|k| dot(&k[off..off + hd], qh) / scale)
187 .collect();
188 softmax_inplace(&mut scores);
189 for (si, &w) in scores.iter().enumerate() {
190 let vh = &all_values[si][off..off + hd];
191 for d in 0..hd { out[off + d] += w * vh[d]; }
192 }
193 }
194 matvec(&self.w_out, &out, self.config.full_kv_dim(), self.config.d_model)
195 }
196
197 fn prepare_query(&self, input: &[f32], pos: usize) -> Vec<f32> {
199 let mut q = self.compute_query(input);
200 let (nh, hd, rd) = (self.config.num_heads, self.config.head_dim, self.config.rope_dim);
201 if rd > 0 {
202 for h in 0..nh { Self::apply_rope(&mut q[h * hd..h * hd + rd], pos); }
203 }
204 q
205 }
206
207 fn decompress_position(
209 &self, latent: &[f32], rope: &[f32], pos: usize,
210 ) -> (Vec<f32>, Vec<f32>) {
211 let mut keys = self.decompress_keys(latent);
212 let values = self.decompress_values(latent);
213 let (nh, hd, rd) = (self.config.num_heads, self.config.head_dim, self.config.rope_dim);
214 if rd > 0 {
215 let mut rp = rope.to_vec();
216 Self::apply_rope(&mut rp, pos);
217 for h in 0..nh { keys[h * hd..h * hd + rd].copy_from_slice(&rp); }
218 }
219 (keys, values)
220 }
221
222 pub fn forward(
224 &self, query_input: &[f32], kv_inputs: &[&[f32]],
225 query_pos: usize, kv_positions: &[usize],
226 ) -> AttentionResult<Vec<f32>> {
227 if query_input.len() != self.config.d_model {
228 return Err(AttentionError::DimensionMismatch {
229 expected: self.config.d_model, actual: query_input.len(),
230 });
231 }
232 if kv_inputs.is_empty() {
233 return Err(AttentionError::EmptyInput("kv_inputs".into()));
234 }
235 if kv_inputs.len() != kv_positions.len() {
236 return Err(AttentionError::DimensionMismatch {
237 expected: kv_inputs.len(), actual: kv_positions.len(),
238 });
239 }
240 let q_full = self.prepare_query(query_input, query_pos);
241 let mut all_k = Vec::with_capacity(kv_inputs.len());
242 let mut all_v = Vec::with_capacity(kv_inputs.len());
243 for (i, &kv) in kv_inputs.iter().enumerate() {
244 if kv.len() != self.config.d_model {
245 return Err(AttentionError::DimensionMismatch {
246 expected: self.config.d_model, actual: kv.len(),
247 });
248 }
249 let c = self.compress_kv(kv);
250 let rope = self.compute_rope_keys(kv);
251 let (k, v) = self.decompress_position(&c, &rope, kv_positions[i]);
252 all_k.push(k);
253 all_v.push(v);
254 }
255 Ok(self.attend(&q_full, &all_k, &all_v))
256 }
257
258 pub fn forward_cached(
260 &self, query_input: &[f32], new_kv_input: &[f32],
261 query_pos: usize, cache: &mut MLACache,
262 ) -> AttentionResult<Vec<f32>> {
263 if new_kv_input.len() != self.config.d_model {
264 return Err(AttentionError::DimensionMismatch {
265 expected: self.config.d_model, actual: new_kv_input.len(),
266 });
267 }
268 cache.push(self.compress_kv(new_kv_input), self.compute_rope_keys(new_kv_input));
269 let q_full = self.prepare_query(query_input, query_pos);
270 let mut all_k = Vec::with_capacity(cache.len());
271 let mut all_v = Vec::with_capacity(cache.len());
272 for pos in 0..cache.len() {
273 let (k, v) = self.decompress_position(
274 &cache.latent_vectors[pos], &cache.rope_keys[pos], pos,
275 );
276 all_k.push(k);
277 all_v.push(v);
278 }
279 Ok(self.attend(&q_full, &all_k, &all_v))
280 }
281
282 pub fn memory_comparison(&self, seq_len: usize) -> MemoryComparison {
284 let mha = seq_len * 2 * self.config.num_heads * self.config.head_dim;
285 let mla = seq_len * (self.config.latent_dim + self.config.rope_dim);
286 MemoryComparison {
287 seq_len, mha_cache_floats: mha, mla_cache_floats: mla,
288 mha_cache_bytes: mha * 4, mla_cache_bytes: mla * 4,
289 reduction_ratio: 1.0 - (mla as f32 / mha as f32),
290 }
291 }
292}
293
/// Cache-memory report comparing standard MHA with MLA for one sequence.
#[derive(Clone, Debug)]
pub struct MemoryComparison {
    /// Sequence length the figures were computed for.
    pub seq_len: usize,
    /// Floats a standard MHA KV cache would hold
    /// (2 * num_heads * head_dim per position).
    pub mha_cache_floats: usize,
    /// Floats the MLA cache holds (latent_dim + rope_dim per position).
    pub mla_cache_floats: usize,
    /// MHA cache size in bytes (4 bytes per f32).
    pub mha_cache_bytes: usize,
    /// MLA cache size in bytes (4 bytes per f32).
    pub mla_cache_bytes: usize,
    /// Fraction of MHA cache memory saved by MLA.
    pub reduction_ratio: f32,
}
304
impl Attention for MLALayer {
    /// Trait adapter over `MLALayer::forward`.
    ///
    /// `values` is deliberately ignored: MLA derives both keys and values
    /// from the same compressed latent, so `keys` serves as the raw KV
    /// inputs. The query is placed at position 0 and keys at 0..len.
    /// NOTE(review): a query at position 0 attending later positions is
    /// unusual — confirm the trait's position semantics with its docs.
    fn compute(
        &self, query: &[f32], keys: &[&[f32]], values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        let _ = values; let positions: Vec<usize> = (0..keys.len()).collect();
        self.forward(query, keys, 0, &positions)
    }

    /// Mask-aware variant; the mask is currently NOT applied — it is
    /// silently dropped and this delegates to `compute`.
    /// NOTE(review): confirm no caller relies on masking here.
    fn compute_with_mask(
        &self, query: &[f32], keys: &[&[f32]], values: &[&[f32]],
        _mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        self.compute(query, keys, values)
    }

    /// Model (input/output) dimensionality.
    fn dim(&self) -> usize { self.config.d_model }
    /// Number of attention heads.
    fn num_heads(&self) -> usize { self.config.num_heads }
}
324
/// Row-major matrix–vector product: `w` holds `out_d` rows of `in_d`
/// columns, `x` supplies at least `in_d` elements; returns `out_d` floats.
fn matvec(w: &[f32], x: &[f32], in_d: usize, out_d: usize) -> Vec<f32> {
    let mut y = Vec::with_capacity(out_d);
    for row in 0..out_d {
        let weights = &w[row * in_d..(row + 1) * in_d];
        let mut acc = 0.0_f32;
        for col in 0..in_d {
            acc += weights[col] * x[col];
        }
        y.push(acc);
    }
    y
}
335
/// Inner product of two slices; `zip` stops at the shorter one.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0_f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
339
/// Numerically stable in-place softmax: subtracts the running maximum
/// before exponentiating, then normalizes so the entries sum to 1.
fn softmax_inplace(s: &mut [f32]) {
    let max = s.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    for v in s.iter_mut() {
        *v = (*v - max).exp();
    }
    let total: f32 = s.iter().sum();
    for v in s.iter_mut() {
        *v /= total;
    }
}
346
/// Deterministic pseudo-weights with Xavier-style scale
/// `sqrt(2 / (in_d + out_d))`: values sweep a sawtooth ramp over
/// [-scale/2, scale/2) with period `in_d + out_d`.
fn init_weight(in_d: usize, out_d: usize) -> Vec<f32> {
    let fan = in_d + out_d;
    let scale = (2.0 / fan as f32).sqrt();
    let period = fan.max(1);
    let mut w = Vec::with_capacity(in_d * out_d);
    for i in 0..in_d * out_d {
        let phase = (i % period) as f32 / period as f32;
        w.push(scale * (phase - 0.5));
    }
    w
}
354
#[cfg(test)]
mod tests {
    use super::*;

    // Small, valid baseline configuration shared by most tests:
    // full KV dim = 4 * 8 = 32, compressed to an 8-float latent + 4 RoPE floats.
    fn cfg() -> MLAConfig {
        MLAConfig {
            d_model: 32, latent_dim: 8, latent_dim_q: None,
            num_heads: 4, head_dim: 8, rope_dim: 4,
        }
    }

    #[test]
    fn test_config_valid() { assert!(cfg().validate().is_ok()); }

    // latent_dim must stay strictly below num_heads * head_dim.
    #[test]
    fn test_config_latent_too_large() {
        let mut c = cfg(); c.latent_dim = 999;
        assert!(c.validate().is_err());
    }

    // RoPE rotates pairs, so an odd rope_dim is rejected.
    #[test]
    fn test_config_rope_dim_odd() {
        let mut c = cfg(); c.rope_dim = 3;
        assert!(c.validate().is_err());
    }

    #[test]
    fn test_config_zero_heads() {
        let mut c = cfg(); c.num_heads = 0;
        assert!(c.validate().is_err());
    }

    // forward() must return a d_model-sized output vector.
    #[test]
    fn test_forward_output_shape() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer.forward(&q, &[&kv1, &kv2], 0, &[0, 1]).unwrap();
        assert_eq!(out.len(), c.d_model);
    }

    // A query shorter than d_model is an error, not a panic.
    #[test]
    fn test_forward_dimension_mismatch() {
        let layer = MLALayer::new(cfg()).unwrap();
        let bad_q = vec![0.1_f32; 5];
        let kv = vec![0.2_f32; 32];
        assert!(layer.forward(&bad_q, &[&kv[..]], 0, &[0]).is_err());
    }

    // With cfg(): per position MLA stores 8+4=12 floats vs MHA's 2*4*8=64,
    // so 10 positions give 120 vs 640 floats, a 1 - 120/640 = 0.8125 saving.
    #[test]
    fn test_cache_size_reduction() {
        let c = cfg();
        let mut cache = MLACache::new(&c);
        for _ in 0..10 { cache.push(vec![0.0; c.latent_dim], vec![0.0; c.rope_dim]); }
        assert_eq!(cache.len(), 10);
        assert_eq!(cache.cache_size(), 120);
        assert_eq!(cache.mha_equivalent_size(), 640);
        assert!((cache.reduction_ratio() - 0.8125).abs() < 1e-4);
    }

    // DeepSeek-like dims: MHA = 1024*2*16*128 floats, MLA = 1024*256 floats,
    // i.e. a 16x (0.9375) reduction with rope_dim = 0.
    #[test]
    fn test_memory_comparison_report() {
        let c = MLAConfig {
            d_model: 2048, latent_dim: 256, latent_dim_q: None,
            num_heads: 16, head_dim: 128, rope_dim: 0,
        };
        let layer = MLALayer::new(c).unwrap();
        let r = layer.memory_comparison(1024);
        assert_eq!(r.mha_cache_floats, 4_194_304);
        assert_eq!(r.mla_cache_floats, 262_144);
        assert!((r.reduction_ratio - 0.9375).abs() < 1e-4);
    }

    // Incremental decoding: each call appends one position and attends over
    // the whole cache; outputs stay d_model-sized and finite.
    #[test]
    fn test_cached_forward_multi_position() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let mut cache = MLACache::new(&c);
        let q = vec![0.1_f32; c.d_model];
        for pos in 0..3 {
            let kv = vec![(pos as f32 + 1.0) * 0.1; c.d_model];
            let out = layer.forward_cached(&q, &kv, pos, &mut cache).unwrap();
            assert_eq!(out.len(), c.d_model);
        }
        assert_eq!(cache.len(), 3);
        let kv_last = vec![0.4_f32; c.d_model];
        let out = layer.forward_cached(&q, &kv_last, 3, &mut cache).unwrap();
        assert!(out.iter().all(|v| v.is_finite()));
        assert_eq!(cache.len(), 4);
    }

    // RoPE at position 0 rotates by angle 0, i.e. the identity.
    #[test]
    fn test_rope_identity_at_zero() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let orig = v.clone();
        MLALayer::apply_rope(&mut v, 0);
        for (a, b) in v.iter().zip(&orig) { assert!((a - b).abs() < 1e-6); }
    }

    // Rotations are orthogonal, so the vector norm is preserved.
    #[test]
    fn test_rope_preserves_norm() {
        let mut v = vec![1.0, 2.0, 3.0, 4.0];
        let norm_before: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        MLALayer::apply_rope(&mut v, 42);
        let norm_after: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm_before - norm_after).abs() < 1e-5);
    }

    // Compression maps d_model -> latent_dim; both up-projections map back
    // to the full num_heads * head_dim width.
    #[test]
    fn test_compress_decompress_dimensions() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        let x = vec![0.5_f32; c.d_model];
        let ckv = layer.compress_kv(&x);
        assert_eq!(ckv.len(), c.latent_dim);
        assert_eq!(layer.decompress_keys(&ckv).len(), c.full_kv_dim());
        assert_eq!(layer.decompress_values(&ckv).len(), c.full_kv_dim());
    }

    // The Attention trait adapter: keys double as the raw KV inputs.
    #[test]
    fn test_attention_trait() {
        let c = cfg();
        let layer = MLALayer::new(c.clone()).unwrap();
        assert_eq!(layer.dim(), c.d_model);
        assert_eq!(layer.num_heads(), c.num_heads);
        let q = vec![0.1_f32; c.d_model];
        let kv1 = vec![0.2_f32; c.d_model];
        let kv2 = vec![0.3_f32; c.d_model];
        let out = layer.compute(&q, &[&kv1[..], &kv2[..]], &[&kv1[..], &kv2[..]]).unwrap();
        assert_eq!(out.len(), c.d_model);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    // An empty cache reports a 0.0 ratio rather than dividing by zero.
    #[test]
    fn test_empty_cache_ratio() {
        let cache = MLACache::new(&cfg());
        assert_eq!(cache.reduction_ratio(), 0.0);
        assert!(cache.is_empty());
    }
}