// ruvector_attention/attention/ssm.rs

/// Hyper-parameters of a selective state-space (Mamba-style) layer.
#[derive(Debug, Clone)]
pub struct SSMConfig {
    /// Width of each input/output token vector.
    pub d_model: usize,
    /// Size of the recurrent state kept per inner channel.
    pub d_state: usize,
    /// Kernel width of the causal depthwise convolution.
    pub d_conv: usize,
    /// Multiplier from `d_model` to the inner width (`d_inner = d_model * expand_factor`).
    pub expand_factor: usize,
    /// Rank of the step-size (`dt`) projection.
    pub dt_rank: usize,
}

impl SSMConfig {
    /// Default configuration for a model width: `d_state = 16`, `d_conv = 4`,
    /// `expand_factor = 2` and `dt_rank = ceil(d_model / 16)`.
    pub fn new(d_model: usize) -> Self {
        // Integer ceiling division: ceil(d_model / 16).
        let dt_rank = (d_model + 15) / 16;
        SSMConfig {
            dt_rank,
            expand_factor: 2,
            d_conv: 4,
            d_state: 16,
            d_model,
        }
    }

    /// Width of the inner (expanded) channel dimension.
    pub fn d_inner(&self) -> usize {
        self.expand_factor * self.d_model
    }

    /// Checks that every dimension is non-zero.
    ///
    /// # Errors
    /// Returns a static message naming the first zero field, checked in the order
    /// `d_model`, `d_state`, `d_conv`, `expand_factor`, `dt_rank`.
    pub fn validate(&self) -> Result<(), &'static str> {
        let checks: [(usize, &'static str); 5] = [
            (self.d_model, "d_model must be > 0"),
            (self.d_state, "d_state must be > 0"),
            (self.d_conv, "d_conv must be > 0"),
            (self.expand_factor, "expand_factor must be > 0"),
            (self.dt_rank, "dt_rank must be > 0"),
        ];
        match checks.iter().find(|&&(value, _)| value == 0) {
            Some(&(_, message)) => Err(message),
            None => Ok(()),
        }
    }
}
80
/// Numerically careful softplus, `ln(1 + e^x)`.
///
/// For `x > 20` the result is `x` to within f32 precision, so `x` is returned
/// directly to avoid overflowing `exp`. Otherwise `ln_1p(e^x)` is used.
/// Unlike `(1 + e^x).ln()`, it keeps full relative precision in the negative
/// tail: for `x` below about -16, `1 + e^x` rounds to exactly 1 in f32 and the
/// naive form returns 0. Very negative inputs underflow smoothly to 0.
#[inline]
pub fn softplus(x: f32) -> f32 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}
96
/// SiLU / swish activation: `x * sigmoid(x)`, computed as `x / (1 + e^-x)`.
///
/// For large negative `x`, `e^-x` overflows to infinity and the result is a signed zero.
#[inline]
pub fn silu(x: f32) -> f32 {
    let denominator = (-x).exp() + 1.0;
    x / denominator
}
102
/// Root-mean-square normalisation with a per-element gain.
///
/// Each element is computed as `x[i] / sqrt(mean(x^2) + eps) * weight[i]`.
/// An empty input yields an empty output.
///
/// # Panics
/// Panics if `x` and `weight` have different lengths.
pub fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), weight.len(), "rms_norm: x and weight must match in size");
    let sum_sq: f32 = x.iter().map(|v| v * v).sum();
    let mean_sq = sum_sq / x.len() as f32;
    let inv_rms = (mean_sq + eps).sqrt().recip();

    let mut out = Vec::with_capacity(x.len());
    for (xi, wi) in x.iter().zip(weight) {
        out.push(xi * inv_rms * wi);
    }
    out
}
114
/// Dense matrix-vector product for a row-major `rows x cols` matrix.
///
/// # Panics
/// Panics if `matrix.len() != rows * cols` or `x.len() != cols`.
fn matvec(matrix: &[f32], x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(matrix.len(), rows * cols);
    assert_eq!(x.len(), cols);
    let mut out = Vec::with_capacity(rows);
    for r in 0..rows {
        let start = r * cols;
        let row = &matrix[start..start + cols];
        out.push(row.iter().zip(x).map(|(m, v)| m * v).sum::<f32>());
    }
    out
}
126
127pub struct SelectiveSSM {
139 config: SSMConfig,
140 a_log: Vec<f32>, conv_weight: Vec<f32>,
144 conv_bias: Vec<f32>, in_proj: Vec<f32>,
147 w_dt: Vec<f32>,
149 dt_bias: Vec<f32>, w_b: Vec<f32>,
152 w_c: Vec<f32>,
154 out_proj: Vec<f32>,
156}
157
158impl SelectiveSSM {
159 pub fn new(config: SSMConfig) -> Self {
161 config.validate().expect("invalid SSMConfig");
162 let d_inner = config.d_inner();
163 let d_state = config.d_state;
164 let d_model = config.d_model;
165 let d_conv = config.d_conv;
166 let dt_rank = config.dt_rank;
167
168 let a_log = vec![0.0_f32; d_inner * d_state];
170 let conv_weight = vec![1.0 / d_conv as f32; d_inner * d_conv];
171 let conv_bias = vec![0.0; d_inner];
172 let scale = 1.0 / (d_model as f32).sqrt();
174 let in_proj = vec![scale; 2 * d_inner * d_model];
175 let w_dt = vec![scale; d_inner * dt_rank];
176 let dt_bias = vec![0.0; d_inner];
177 let w_b = vec![scale; d_state * d_inner];
178 let w_c = vec![scale; d_state * d_inner];
179 let out_proj = vec![scale; d_model * d_inner];
180
181 Self {
182 config,
183 a_log,
184 conv_weight,
185 conv_bias,
186 in_proj,
187 w_dt,
188 dt_bias,
189 w_b,
190 w_c,
191 out_proj,
192 }
193 }
194
195 pub fn config(&self) -> &SSMConfig {
197 &self.config
198 }
199
200 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
205 let d_model = self.config.d_model;
206 let seq_len = input.len() / d_model;
207 assert_eq!(
208 input.len(),
209 seq_len * d_model,
210 "input not divisible by d_model"
211 );
212
213 let d_inner = self.config.d_inner();
214
215 let mut z_seq = Vec::with_capacity(seq_len * d_inner);
217 let mut xc_seq = Vec::with_capacity(seq_len * d_inner);
218 for t in 0..seq_len {
219 let x_t = &input[t * d_model..(t + 1) * d_model];
220 let projected = matvec(&self.in_proj, x_t, 2 * d_inner, d_model);
221 z_seq.extend_from_slice(&projected[..d_inner]);
222 xc_seq.extend_from_slice(&projected[d_inner..]);
223 }
224
225 let xc_conv = self.causal_conv(&xc_seq, seq_len, d_inner);
227
228 let y_seq = self.selective_scan(&xc_conv, seq_len, d_inner);
230
231 let mut output = Vec::with_capacity(seq_len * d_model);
233 for t in 0..seq_len {
234 let gated: Vec<f32> = (0..d_inner)
235 .map(|i| y_seq[t * d_inner + i] * silu(z_seq[t * d_inner + i]))
236 .collect();
237 let out_t = matvec(&self.out_proj, &gated, d_model, d_inner);
238 output.extend_from_slice(&out_t);
239 }
240 output
241 }
242
243 fn causal_conv(&self, xc: &[f32], seq_len: usize, d_inner: usize) -> Vec<f32> {
245 let d_conv = self.config.d_conv;
246 let mut out = vec![0.0; seq_len * d_inner];
247 for t in 0..seq_len {
248 for i in 0..d_inner {
249 let mut acc = self.conv_bias[i];
250 for k in 0..d_conv {
251 if t >= k {
252 let w = self.conv_weight[i * d_conv + k];
253 acc += w * xc[(t - k) * d_inner + i];
254 }
255 }
256 out[t * d_inner + i] = silu(acc);
257 }
258 }
259 out
260 }
261
262 fn selective_scan(&self, x: &[f32], seq_len: usize, d_inner: usize) -> Vec<f32> {
264 let d_state = self.config.d_state;
265 let mut h = vec![0.0_f32; d_inner * d_state];
266 let mut y_seq = Vec::with_capacity(seq_len * d_inner);
267
268 for t in 0..seq_len {
269 let x_t = &x[t * d_inner..(t + 1) * d_inner];
270 let dt_pre = matvec(&self.w_dt, x_t, self.config.dt_rank, d_inner);
272 let delta: Vec<f32> = (0..d_inner)
274 .map(|i| softplus(dt_pre[i % self.config.dt_rank] + self.dt_bias[i]))
275 .collect();
276 let b_t = matvec(&self.w_b, x_t, d_state, d_inner);
278 let c_t = matvec(&self.w_c, x_t, d_state, d_inner);
280
281 let mut y_t = vec![0.0_f32; d_inner];
283 for i in 0..d_inner {
284 for j in 0..d_state {
285 let a = -(-self.a_log[i * d_state + j]).exp(); let a_bar = (delta[i] * a).exp();
287 let b_bar = delta[i] * b_t[j];
288 let idx = i * d_state + j;
289 h[idx] = a_bar * h[idx] + b_bar * x_t[i];
290 y_t[i] += c_t[j] * h[idx];
291 }
292 }
293 y_seq.extend_from_slice(&y_t);
294 }
295 y_seq
296 }
297
298 pub fn init_state(&self) -> SSMState {
300 SSMState {
301 h: vec![0.0; self.config.d_inner() * self.config.d_state],
302 d_inner: self.config.d_inner(),
303 d_state: self.config.d_state,
304 }
305 }
306
307 pub fn step(&self, token: &[f32], state: &mut SSMState) -> Vec<f32> {
310 let d_model = self.config.d_model;
311 let d_inner = self.config.d_inner();
312 let d_state = self.config.d_state;
313 assert_eq!(token.len(), d_model);
314
315 let projected = matvec(&self.in_proj, token, 2 * d_inner, d_model);
317 let z = &projected[..d_inner];
318 let xc: Vec<f32> = (0..d_inner).map(|i| silu(projected[d_inner + i])).collect();
319
320 let dt_pre = matvec(&self.w_dt, &xc, self.config.dt_rank, d_inner);
322 let delta: Vec<f32> = (0..d_inner)
323 .map(|i| softplus(dt_pre[i % self.config.dt_rank] + self.dt_bias[i]))
324 .collect();
325 let b_t = matvec(&self.w_b, &xc, d_state, d_inner);
326 let c_t = matvec(&self.w_c, &xc, d_state, d_inner);
327
328 let mut y = vec![0.0_f32; d_inner];
330 for i in 0..d_inner {
331 for j in 0..d_state {
332 let a = -(-self.a_log[i * d_state + j]).exp();
333 let a_bar = (delta[i] * a).exp();
334 let b_bar = delta[i] * b_t[j];
335 let idx = i * d_state + j;
336 state.h[idx] = a_bar * state.h[idx] + b_bar * xc[i];
337 y[i] += c_t[j] * state.h[idx];
338 }
339 }
340
341 let gated: Vec<f32> = (0..d_inner).map(|i| y[i] * silu(z[i])).collect();
343 matvec(&self.out_proj, &gated, d_model, d_inner)
344 }
345}
346
347#[derive(Debug, Clone)]
349pub struct SSMState {
350 pub h: Vec<f32>,
352 d_inner: usize,
353 d_state: usize,
354}
355
356impl SSMState {
357 pub fn reset(&mut self) {
359 self.h.fill(0.0);
360 }
361
362 pub fn shape(&self) -> (usize, usize) {
364 (self.d_inner, self.d_state)
365 }
366}
367
368pub struct MambaBlock {
374 ssm: SelectiveSSM,
375 norm_weight: Vec<f32>,
376 norm_eps: f32,
377}
378
379impl MambaBlock {
380 pub fn new(config: SSMConfig) -> Self {
381 let d = config.d_model;
382 Self {
383 ssm: SelectiveSSM::new(config),
384 norm_weight: vec![1.0; d],
385 norm_eps: 1e-5,
386 }
387 }
388
389 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
391 let d = self.ssm.config().d_model;
392 let seq_len = input.len() / d;
393 let mut normed = Vec::with_capacity(input.len());
395 for t in 0..seq_len {
396 let tok = &input[t * d..(t + 1) * d];
397 normed.extend(rms_norm(tok, &self.norm_weight, self.norm_eps));
398 }
399 let ssm_out = self.ssm.forward(&normed);
400 input
402 .iter()
403 .zip(ssm_out.iter())
404 .map(|(a, b)| a + b)
405 .collect()
406 }
407
408 pub fn step(&self, token: &[f32], state: &mut SSMState) -> Vec<f32> {
410 let normed = rms_norm(token, &self.norm_weight, self.norm_eps);
411 let out = self.ssm.step(&normed, state);
412 token.iter().zip(out.iter()).map(|(a, b)| a + b).collect()
413 }
414}
415
/// Kind of layer at one position of a hybrid stack.
///
/// `Eq` and `Hash` are derived so schedules can be deduplicated or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LayerKind {
    /// A selective state-space layer (built as a `MambaBlock` by `HybridBlock`).
    SSM,
    /// An attention layer (currently a pass-through placeholder in `HybridBlock::forward`).
    Attention,
}
426
427#[derive(Debug, Clone)]
429pub struct HybridConfig {
430 pub ssm: SSMConfig,
431 pub num_layers: usize,
432 pub hybrid_ratio: f32,
434}
435
436impl HybridConfig {
437 pub fn layer_schedule(&self) -> Vec<LayerKind> {
439 (0..self.num_layers)
440 .map(|i| {
441 let attn_every = if self.hybrid_ratio <= 0.0 {
442 usize::MAX
443 } else {
444 (1.0 / self.hybrid_ratio).round().max(1.0) as usize
445 };
446 if attn_every < usize::MAX && i % attn_every == attn_every - 1 {
447 LayerKind::Attention
448 } else {
449 LayerKind::SSM
450 }
451 })
452 .collect()
453 }
454}
455
456pub struct HybridBlock {
461 schedule: Vec<LayerKind>,
462 ssm_layers: Vec<MambaBlock>,
464 num_attention_layers: usize,
467}
468
469impl HybridBlock {
470 pub fn new(config: HybridConfig) -> Self {
471 let schedule = config.layer_schedule();
472 let ssm_count = schedule.iter().filter(|k| **k == LayerKind::SSM).count();
473 let attn_count = schedule.len() - ssm_count;
474 let ssm_layers = (0..ssm_count)
475 .map(|_| MambaBlock::new(config.ssm.clone()))
476 .collect();
477 Self {
478 schedule,
479 ssm_layers,
480 num_attention_layers: attn_count,
481 }
482 }
483
484 pub fn schedule(&self) -> &[LayerKind] {
486 &self.schedule
487 }
488
489 pub fn attention_layer_count(&self) -> usize {
491 self.num_attention_layers
492 }
493
494 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
500 let mut x = input.to_vec();
501 let mut ssm_idx = 0;
502 for kind in &self.schedule {
503 match kind {
504 LayerKind::SSM => {
505 x = self.ssm_layers[ssm_idx].forward(&x);
506 ssm_idx += 1;
507 }
508 LayerKind::Attention => {
509 }
511 }
512 }
513 x
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let c = SSMConfig::new(64);
        assert_eq!(c.d_model, 64);
        assert_eq!(c.d_state, 16);
        assert_eq!(c.d_conv, 4);
        assert_eq!(c.expand_factor, 2);
        assert_eq!(c.d_inner(), 128);
        assert!(c.validate().is_ok());
    }

    #[test]
    fn test_config_validation_errors() {
        // Zeroing any single dimension must be rejected.
        let zeroers: [fn(&mut SSMConfig); 5] = [
            |c| c.d_model = 0,
            |c| c.d_state = 0,
            |c| c.d_conv = 0,
            |c| c.expand_factor = 0,
            |c| c.dt_rank = 0,
        ];
        for zero in zeroers.iter() {
            let mut c = SSMConfig::new(64);
            zero(&mut c);
            assert!(c.validate().is_err());
        }
    }

    #[test]
    fn test_softplus_values() {
        assert!((softplus(0.0) - 0.6931).abs() < 1e-3); // ln 2
        assert!((softplus(1.0) - 1.3133).abs() < 1e-3);
        assert!((softplus(25.0) - 25.0).abs() < 1e-3); // linear regime
        assert!(softplus(-25.0) < 1e-3); // vanishing regime
    }

    #[test]
    fn test_silu_values() {
        assert!(silu(0.0).abs() < 1e-6);
        assert!((silu(1.0) - 0.7311).abs() < 1e-3);
        assert!(silu(-5.0) < 0.0);
    }

    #[test]
    fn test_rms_norm() {
        let x = vec![3.0, 4.0];
        let rms = (12.5_f32).sqrt(); // sqrt((9 + 16) / 2)
        let normed = rms_norm(&x, &[1.0, 1.0], 1e-8);
        assert!((normed[0] - 3.0 / rms).abs() < 1e-4);
        assert!((normed[1] - 4.0 / rms).abs() < 1e-4);
        // The gain is applied element-wise after normalisation.
        let weighted = rms_norm(&x, &[2.0, 0.5], 1e-8);
        assert!((weighted[0] - 6.0 / rms).abs() < 1e-4);
        assert!((weighted[1] - 2.0 / rms).abs() < 1e-4);
    }

    #[test]
    fn test_selective_scan_single_step() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let output = ssm.forward(&vec![1.0; 4]);
        assert_eq!(output.len(), 4);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_selective_scan_sequence() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let seq_len = 5;
        let output = ssm.forward(&vec![0.5; seq_len * 4]);
        assert_eq!(output.len(), seq_len * 4);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    /// Only compares output lengths and finiteness of batch vs. recurrent mode;
    /// element-wise agreement is not asserted here.
    #[test]
    fn test_state_recurrence_consistency() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let token = vec![1.0; 4];
        let batch_out = ssm.forward(&token);
        let mut state = ssm.init_state();
        let step_out = ssm.step(&token, &mut state);
        assert_eq!(batch_out.len(), step_out.len());
        assert!(step_out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_mamba_block_forward() {
        let block = MambaBlock::new(SSMConfig::new(8));
        let output = block.forward(&vec![1.0; 3 * 8]);
        assert_eq!(output.len(), 3 * 8);
        assert!(output.iter().all(|v| v.is_finite()));
        assert!(output.iter().any(|v| *v != 0.0));
    }

    #[test]
    fn test_mamba_block_step() {
        let block = MambaBlock::new(SSMConfig::new(8));
        let mut state = block.ssm.init_state();
        let token = vec![1.0; 8];
        let out = block.step(&token, &mut state);
        assert_eq!(out.len(), 8);
        assert!(out.iter().all(|v| v.is_finite()));
        // The SSM branch contributes on top of the residual, so the output differs from the input.
        assert!(out.iter().zip(&token).any(|(o, t)| o != t));
    }

    #[test]
    fn test_hybrid_routing() {
        let hc = HybridConfig {
            ssm: SSMConfig::new(4),
            num_layers: 8,
            hybrid_ratio: 0.25,
        };
        let schedule = hc.layer_schedule();
        assert_eq!(schedule.len(), 8);
        let attn_count = schedule
            .iter()
            .filter(|k| **k == LayerKind::Attention)
            .count();
        assert_eq!(attn_count, 2);
        assert_eq!(schedule[3], LayerKind::Attention);
        assert_eq!(schedule[7], LayerKind::Attention);
    }

    #[test]
    fn test_hybrid_routing_ratio_extremes() {
        let schedule_for = |ratio: f32| {
            HybridConfig {
                ssm: SSMConfig::new(4),
                num_layers: 4,
                hybrid_ratio: ratio,
            }
            .layer_schedule()
        };
        assert!(schedule_for(0.0).iter().all(|k| *k == LayerKind::SSM));
        assert!(schedule_for(1.0).iter().all(|k| *k == LayerKind::Attention));
    }

    #[test]
    fn test_hybrid_block_forward() {
        let hc = HybridConfig {
            ssm: SSMConfig::new(4),
            num_layers: 4,
            hybrid_ratio: 0.25,
        };
        let block = HybridBlock::new(hc);
        assert_eq!(block.attention_layer_count(), 1);
        let output = block.forward(&vec![1.0; 2 * 4]);
        assert_eq!(output.len(), 2 * 4);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_inference_step_updates_state() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let mut state = ssm.init_state();
        assert!(state.h.iter().all(|v| *v == 0.0));

        let token = vec![1.0; 4];
        let _ = ssm.step(&token, &mut state);
        assert!(state.h.iter().any(|v| *v != 0.0));

        // A second identical token must still move the state.
        let h_after_1 = state.h.clone();
        let _ = ssm.step(&token, &mut state);
        assert_ne!(state.h, h_after_1);
    }

    #[test]
    fn test_ssm_state_reset() {
        let ssm = SelectiveSSM::new(SSMConfig::new(4));
        let mut state = ssm.init_state();
        let _ = ssm.step(&vec![1.0; 4], &mut state);
        assert!(state.h.iter().any(|v| *v != 0.0));
        state.reset();
        assert!(state.h.iter().all(|v| *v == 0.0));
        // d_inner = 4 * 2, d_state = 16.
        assert_eq!(state.shape(), (8, 16));
    }
}