1use crate::activations::Activation;
12use crate::error::{NeuralError, Result};
13use crate::layers::{Dense, Dropout, Layer, LayerNorm};
14use scirs2_core::ndarray::{s, Array, Array1, Array2, Array3, IxDyn, ScalarOperand, Zip};
15use scirs2_core::numeric::{Float, NumAssign};
16use scirs2_core::random::{Rng, RngExt, SeedableRng};
17use scirs2_core::simd_ops::SimdUnifiedOps;
18use serde::{Deserialize, Serialize};
19use std::fmt::Debug;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct MambaConfig {
24 pub d_model: usize,
26 pub d_state: usize,
28 pub d_conv: usize,
30 pub expand: usize,
32 pub n_layers: usize,
34 pub dropout_prob: f64,
36 pub vocab_size: Option<usize>,
38 pub num_classes: Option<usize>,
40 pub dt_rank: Option<usize>,
42 pub bias: bool,
44 pub dt_min: f64,
46 pub dt_max: f64,
47}
48
49impl Default for MambaConfig {
50 fn default() -> Self {
51 Self {
52 d_model: 256,
53 d_state: 16,
54 d_conv: 4,
55 expand: 2,
56 n_layers: 4,
57 dropout_prob: 0.1,
58 vocab_size: None,
59 num_classes: None,
60 dt_rank: None, bias: false,
62 dt_min: 0.001,
63 dt_max: 0.1,
64 }
65 }
66}
67
68impl MambaConfig {
69 pub fn new(d_model: usize) -> Self {
71 Self {
72 d_model,
73 ..Default::default()
74 }
75 }
76
77 pub fn with_d_state(mut self, d_state: usize) -> Self {
79 self.d_state = d_state;
80 self
81 }
82
83 pub fn with_n_layers(mut self, n_layers: usize) -> Self {
85 self.n_layers = n_layers;
86 self
87 }
88
89 pub fn with_expand(mut self, expand: usize) -> Self {
91 self.expand = expand;
92 self
93 }
94
95 pub fn with_dropout(mut self, dropout_prob: f64) -> Self {
97 self.dropout_prob = dropout_prob;
98 self
99 }
100
101 pub fn with_vocab_size(mut self, vocab_size: usize) -> Self {
103 self.vocab_size = Some(vocab_size);
104 self
105 }
106
107 pub fn with_num_classes(mut self, num_classes: usize) -> Self {
109 self.num_classes = Some(num_classes);
110 self
111 }
112
113 pub fn d_inner(&self) -> usize {
115 self.d_model * self.expand
116 }
117
118 pub fn get_dt_rank(&self) -> usize {
120 self.dt_rank.unwrap_or_else(|| self.d_model.div_ceil(16)) }
122}
123
124#[derive(Debug)]
133pub struct SelectiveSSM<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
134 d_state: usize,
136 d_inner: usize,
138 a_log: Array2<F>,
140 d: Array1<F>,
142 dt_proj: Dense<F>,
144 x_proj_b: Dense<F>,
146 x_proj_c: Dense<F>,
148 x_proj_dt: Dense<F>,
150}
151
152impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> SelectiveSSM<F> {
153 pub fn new<R: Rng>(
155 d_inner: usize,
156 d_state: usize,
157 dt_rank: usize,
158 rng: &mut R,
159 ) -> Result<Self> {
160 let mut a_log = Array2::<F>::zeros((d_inner, d_state));
163 for i in 0..d_inner {
164 for j in 0..d_state {
165 let val = (j as f64 + 1.0).ln();
167 a_log[[i, j]] = F::from(val).expect("Failed to convert to float");
168 }
169 }
170
171 let d = Array1::<F>::from_elem(d_inner, F::one());
173
174 let dt_proj = Dense::<F>::new(dt_rank, d_inner, Some("dt_proj"), rng)?;
176
177 let x_proj_b = Dense::<F>::new(d_inner, d_state, Some("x_proj_b"), rng)?;
179 let x_proj_c = Dense::<F>::new(d_inner, d_state, Some("x_proj_c"), rng)?;
180 let x_proj_dt = Dense::<F>::new(d_inner, dt_rank, Some("x_proj_dt"), rng)?;
181
182 Ok(Self {
183 d_state,
184 d_inner,
185 a_log,
186 d,
187 dt_proj,
188 x_proj_b,
189 x_proj_c,
190 x_proj_dt,
191 })
192 }
193
194 pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
202 if x.ndim() != 3 {
203 return Err(NeuralError::InvalidArchitecture(format!(
204 "SelectiveSSM expects 3D input, got {}D",
205 x.ndim()
206 )));
207 }
208
209 let shape = x.shape();
210 let batch_size = shape[0];
211 let seq_len = shape[1];
212 let d_inner = shape[2];
213
214 if d_inner != self.d_inner {
215 return Err(NeuralError::InvalidArchitecture(format!(
216 "Input dimension {} doesn't match d_inner {}",
217 d_inner, self.d_inner
218 )));
219 }
220
221 let a_neg = self.a_log.mapv(|v| -v.exp());
223
224 let x_2d = x
227 .clone()
228 .into_shape_with_order(IxDyn(&[batch_size * seq_len, d_inner]))
229 .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
230
231 let b_proj = self.x_proj_b.forward(&x_2d)?;
233 let c_proj = self.x_proj_c.forward(&x_2d)?;
234 let dt_proj_input = self.x_proj_dt.forward(&x_2d)?;
235 let delta_proj = self.dt_proj.forward(&dt_proj_input)?;
236
237 let delta = delta_proj.mapv(|v: F| {
239 if v > F::from(20.0).expect("Failed to convert constant to float") {
240 v
241 } else {
242 (F::one() + v.exp()).ln()
243 }
244 });
245
246 let b = b_proj
248 .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_state]))
249 .map_err(|e| NeuralError::InferenceError(format!("B reshape error: {}", e)))?;
250
251 let c = c_proj
252 .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_state]))
253 .map_err(|e| NeuralError::InferenceError(format!("C reshape error: {}", e)))?;
254
255 let delta_3d = delta
256 .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_inner]))
257 .map_err(|e| NeuralError::InferenceError(format!("Delta reshape error: {}", e)))?;
258
259 let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, d_inner]));
261
262 for batch_idx in 0..batch_size {
263 let mut h = Array2::<F>::zeros((d_inner, self.d_state));
265
266 for t in 0..seq_len {
267 let dt = delta_3d.slice(s![batch_idx, t, ..]);
269
270 let b_t = b.slice(s![batch_idx, t, ..]);
272 let c_t = c.slice(s![batch_idx, t, ..]);
273
274 let x_t = x.slice(s![batch_idx, t, ..]);
276
277 for i in 0..d_inner {
283 let dt_i = dt[i];
284
285 for j in 0..self.d_state {
286 let a_bar = (dt_i * a_neg[[i, j]]).exp();
288 let b_bar = dt_i * b_t[j];
290
291 h[[i, j]] = a_bar * h[[i, j]] + b_bar * x_t[i];
293 }
294 }
295
296 for i in 0..d_inner {
298 let mut y_i = F::zero();
299 for j in 0..self.d_state {
300 y_i += c_t[j] * h[[i, j]];
301 }
302 output[[batch_idx, t, i]] = y_i + self.d[[i]] * x_t[i];
304 }
305 }
306 }
307
308 Ok(output)
309 }
310}
311
312#[derive(Debug)]
314struct Conv1D<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
315 weights: Array2<F>,
317 bias: Array1<F>,
319 kernel_size: usize,
321 channels: usize,
323}
324
325impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Conv1D<F> {
326 fn new<R: Rng>(channels: usize, kernel_size: usize, rng: &mut R) -> Result<Self> {
327 let std = (F::from(2.0).expect("Failed to convert constant to float")
328 / F::from(channels * kernel_size).expect("Failed to convert to float"))
329 .sqrt();
330
331 let mut weights = Array2::<F>::zeros((channels, kernel_size));
332 for w in weights.iter_mut() {
333 let u1: f64 = rng.random();
334 let u2: f64 = rng.random();
335 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
336 *w = F::from(z).expect("Failed to convert to float") * std;
337 }
338
339 let bias = Array1::<F>::zeros(channels);
340
341 Ok(Self {
342 weights,
343 bias,
344 kernel_size,
345 channels,
346 })
347 }
348
349 fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
350 let shape = x.shape();
352 let batch_size = shape[0];
353 let seq_len = shape[1];
354 let channels = shape[2];
355
356 if channels != self.channels {
357 return Err(NeuralError::InvalidArchitecture(format!(
358 "Channel mismatch: {} vs {}",
359 channels, self.channels
360 )));
361 }
362
363 let pad = self.kernel_size - 1;
365 let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, channels]));
366
367 for b in 0..batch_size {
368 for t in 0..seq_len {
369 for c in 0..channels {
370 let mut sum = self.bias[c];
371 for k in 0..self.kernel_size {
372 let input_idx = t as isize + k as isize - pad as isize;
373 if input_idx >= 0 && (input_idx as usize) < seq_len {
374 sum += self.weights[[c, k]] * x[[b, input_idx as usize, c]];
375 }
376 }
377 output[[b, t, c]] = sum;
378 }
379 }
380 }
381
382 Ok(output)
383 }
384}
385
386#[derive(Debug, Clone, Copy)]
388struct SiLU;
389
390impl SiLU {
391 fn forward<F: Float>(&self, x: F) -> F {
392 x * (F::one() / (F::one() + (-x).exp()))
393 }
394}
395
396#[derive(Debug)]
398pub struct MambaBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
399where
400 F: SimdUnifiedOps,
401{
402 d_model: usize,
404 d_inner: usize,
405 in_proj: Dense<F>,
407 conv1d: Conv1D<F>,
409 ssm: SelectiveSSM<F>,
411 out_proj: Dense<F>,
413 norm: LayerNorm<F>,
415}
416
417impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
418 MambaBlock<F>
419{
420 pub fn new<R: Rng>(config: &MambaConfig, rng: &mut R) -> Result<Self> {
422 let d_inner = config.d_inner();
423 let dt_rank = config.get_dt_rank();
424
425 let in_proj = Dense::<F>::new(config.d_model, d_inner * 2, Some("in_proj"), rng)?;
427
428 let conv1d = Conv1D::new(d_inner, config.d_conv, rng)?;
430
431 let ssm = SelectiveSSM::new(d_inner, config.d_state, dt_rank, rng)?;
433
434 let out_proj = Dense::<F>::new(d_inner, config.d_model, Some("out_proj"), rng)?;
436
437 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
439 let norm = LayerNorm::<F>::new(config.d_model, 1e-5, &mut rng)?;
440
441 Ok(Self {
442 d_model: config.d_model,
443 d_inner,
444 in_proj,
445 conv1d,
446 ssm,
447 out_proj,
448 norm,
449 })
450 }
451
452 pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
454 let residual = x.clone();
456
457 let normed = self.norm.forward(x)?;
459
460 let shape = normed.shape();
461 let batch_size = shape[0];
462 let seq_len = shape[1];
463
464 let x_2d = normed
466 .clone()
467 .into_shape_with_order(IxDyn(&[batch_size * seq_len, self.d_model]))
468 .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
469
470 let proj = self.in_proj.forward(&x_2d)?;
471
472 let proj_3d = proj
473 .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_inner * 2]))
474 .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
475
476 let x_branch = proj_3d
478 .slice(s![.., .., ..self.d_inner])
479 .to_owned()
480 .into_dyn();
481 let z_branch = proj_3d
482 .slice(s![.., .., self.d_inner..])
483 .to_owned()
484 .into_dyn();
485
486 let x_conv = self.conv1d.forward(&x_branch)?;
488
489 let silu = SiLU;
490 let x_silu = x_conv.mapv(|v| silu.forward(v));
491
492 let x_ssm = self.ssm.forward(&x_silu)?;
494
495 let z_silu = z_branch.mapv(|v| silu.forward(v));
497
498 let gated = &x_ssm * &z_silu;
500
501 let gated_2d = gated
503 .into_shape_with_order(IxDyn(&[batch_size * seq_len, self.d_inner]))
504 .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
505
506 let output = self.out_proj.forward(&gated_2d)?;
507
508 let output_3d = output
509 .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_model]))
510 .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
511
512 Ok(&residual + &output_3d)
514 }
515}
516
517#[derive(Debug)]
549pub struct Mamba<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
550where
551 F: SimdUnifiedOps,
552{
553 config: MambaConfig,
555 blocks: Vec<MambaBlock<F>>,
557 final_norm: LayerNorm<F>,
559 classifier: Option<Dense<F>>,
561}
562
563impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
564 Mamba<F>
565{
566 pub fn new<R: Rng>(config: MambaConfig, rng: &mut R) -> Result<Self> {
568 let mut blocks = Vec::with_capacity(config.n_layers);
570 for _ in 0..config.n_layers {
571 blocks.push(MambaBlock::new(&config, rng)?);
572 }
573
574 let mut rng_final = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
576 let final_norm = LayerNorm::<F>::new(config.d_model, 1e-5, &mut rng_final)?;
577
578 let classifier = if let Some(num_classes) = config.num_classes {
580 Some(Dense::<F>::new(
581 config.d_model,
582 num_classes,
583 Some("classifier"),
584 rng,
585 )?)
586 } else {
587 None
588 };
589
590 Ok(Self {
591 config,
592 blocks,
593 final_norm,
594 classifier,
595 })
596 }
597
598 pub fn config(&self) -> &MambaConfig {
600 &self.config
601 }
602
603 pub fn num_layers(&self) -> usize {
605 self.blocks.len()
606 }
607}
608
609impl<F> Layer<F> for Mamba<F>
610where
611 F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static + SimdUnifiedOps,
612{
613 fn as_any(&self) -> &dyn std::any::Any {
614 self
615 }
616
617 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
618 self
619 }
620
621 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
622 if input.ndim() != 3 {
624 return Err(NeuralError::InvalidArchitecture(format!(
625 "Mamba expects 3D input [batch, seq_len, d_model], got {}D",
626 input.ndim()
627 )));
628 }
629
630 let shape = input.shape();
631 let batch_size = shape[0];
632 let seq_len = shape[1];
633 let d_model = shape[2];
634
635 if d_model != self.config.d_model {
636 return Err(NeuralError::InvalidArchitecture(format!(
637 "Input dimension {} doesn't match config d_model {}",
638 d_model, self.config.d_model
639 )));
640 }
641
642 let mut hidden = input.clone();
644 for block in &self.blocks {
645 hidden = block.forward(&hidden)?;
646 }
647
648 let normed = self.final_norm.forward(&hidden)?;
650
651 if let Some(ref classifier) = self.classifier {
653 let mut pooled = Array::zeros(IxDyn(&[batch_size, self.config.d_model]));
655 let seq_len_f = F::from(seq_len).expect("Failed to convert to float");
656
657 for b in 0..batch_size {
658 for d in 0..self.config.d_model {
659 let mut sum = F::zero();
660 for t in 0..seq_len {
661 sum += normed[[b, t, d]];
662 }
663 pooled[[b, d]] = sum / seq_len_f;
664 }
665 }
666
667 classifier.forward(&pooled)
668 } else {
669 Ok(normed)
670 }
671 }
672
673 fn backward(
674 &self,
675 _input: &Array<F, IxDyn>,
676 _grad_output: &Array<F, IxDyn>,
677 ) -> Result<Array<F, IxDyn>> {
678 Err(NeuralError::NotImplemented(
679 "Mamba backward pass not yet implemented".to_string(),
680 ))
681 }
682
683 fn update(&mut self, _learning_rate: F) -> Result<()> {
684 Ok(())
685 }
686}
687
688#[derive(Debug)]
693pub struct S4Layer<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
694 d_model: usize,
696 d_state: usize,
698 a: Array2<F>,
700 b: Array2<F>,
702 c: Array2<F>,
704 d: Array1<F>,
706 delta: F,
708}
709
710impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> S4Layer<F> {
711 pub fn new<R: Rng>(d_model: usize, d_state: usize, rng: &mut R) -> Result<Self> {
713 let mut a = Array2::<F>::zeros((d_state, d_state));
716 for i in 0..d_state {
717 for j in 0..d_state {
718 let val = if i > j {
719 -((2.0 * i as f64 + 1.0) * (2.0 * j as f64 + 1.0)).sqrt()
720 } else if i == j {
721 -(i as f64 + 1.0)
722 } else {
723 0.0
724 };
725 a[[i, j]] = F::from(val).expect("Failed to convert to float");
726 }
727 }
728
729 let mut b = Array2::<F>::zeros((d_state, d_model));
731 for i in 0..d_state {
732 let val = (2.0 * i as f64 + 1.0).sqrt();
733 for j in 0..d_model {
734 let u: f64 = rng.random();
735 b[[i, j]] = F::from(val * (u - 0.5) * 0.1).expect("Operation failed");
736 }
737 }
738
739 let mut c = Array2::<F>::zeros((d_model, d_state));
741 let std = (2.0 / (d_model + d_state) as f64).sqrt();
742 for i in 0..d_model {
743 for j in 0..d_state {
744 let u1: f64 = rng.random();
745 let u2: f64 = rng.random();
746 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
747 c[[i, j]] = F::from(z * std).expect("Failed to convert to float");
748 }
749 }
750
751 let d = Array1::<F>::from_elem(d_model, F::one());
753
754 let delta = F::from(0.001).expect("Failed to convert constant to float");
756
757 Ok(Self {
758 d_model,
759 d_state,
760 a,
761 b,
762 c,
763 d,
764 delta,
765 })
766 }
767
768 pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
770 if x.ndim() != 3 {
772 return Err(NeuralError::InvalidArchitecture(format!(
773 "S4Layer expects 3D input, got {}D",
774 x.ndim()
775 )));
776 }
777
778 let shape = x.shape();
779 let batch_size = shape[0];
780 let seq_len = shape[1];
781 let d_model = shape[2];
782
783 if d_model != self.d_model {
784 return Err(NeuralError::InvalidArchitecture(format!(
785 "Input dimension {} doesn't match d_model {}",
786 d_model, self.d_model
787 )));
788 }
789
790 let mut a_bar = Array2::<F>::eye(self.d_state);
794 for i in 0..self.d_state {
795 for j in 0..self.d_state {
796 a_bar[[i, j]] += self.delta * self.a[[i, j]];
797 }
798 }
799
800 let b_bar = &self.b * self.delta;
802
803 let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, d_model]));
805
806 for b in 0..batch_size {
807 let mut state = Array1::<F>::zeros(self.d_state);
809
810 for t in 0..seq_len {
811 let x_t: Array1<F> = x
813 .slice(s![b, t, ..])
814 .to_owned()
815 .into_shape_with_order(d_model)
816 .map_err(|_| {
817 NeuralError::InferenceError("Failed to reshape input".to_string())
818 })?;
819
820 let new_state = a_bar.dot(&state) + b_bar.dot(&x_t);
822 state = new_state;
823
824 let y_t = self.c.dot(&state) + &self.d * &x_t;
826
827 for d in 0..d_model {
828 output[[b, t, d]] = y_t[d];
829 }
830 }
831 }
832
833 Ok(output)
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builder methods set the requested fields and `d_inner` is derived from them.
    #[test]
    fn test_mamba_config() {
        let cfg = MambaConfig::new(256)
            .with_n_layers(4)
            .with_d_state(16)
            .with_expand(2);

        assert_eq!(cfg.d_model, 256);
        assert_eq!(cfg.n_layers, 4);
        assert_eq!(cfg.d_state, 16);
        assert_eq!(cfg.d_inner(), 512);
    }

    /// A small model can be constructed without error.
    #[test]
    fn test_mamba_creation() {
        let mut rng = scirs2_core::random::rng();
        let cfg = MambaConfig::new(64).with_n_layers(2).with_d_state(8);

        let built = Mamba::<f64>::new(cfg, &mut rng);
        assert!(built.is_ok());
    }

    /// Without a classifier the model keeps the `[batch, seq_len, d_model]` shape.
    #[test]
    fn test_mamba_forward() {
        let mut rng = scirs2_core::random::rng();
        let cfg = MambaConfig::new(32)
            .with_n_layers(2)
            .with_d_state(8)
            .with_expand(2);

        let model = Mamba::<f64>::new(cfg, &mut rng).expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 8, 32), 0.1).into_dyn();
        let result = model.forward(&input);

        assert!(result.is_ok());
        let out = result.expect("Operation failed");
        assert_eq!(out.shape(), &[2, 8, 32]);
    }

    /// With a classifier the model returns `[batch, num_classes]`.
    #[test]
    fn test_mamba_with_classifier() {
        let mut rng = scirs2_core::random::rng();
        let cfg = MambaConfig::new(32)
            .with_n_layers(2)
            .with_d_state(8)
            .with_num_classes(10);

        let model = Mamba::<f64>::new(cfg, &mut rng).expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 8, 32), 0.1).into_dyn();
        let result = model.forward(&input);

        assert!(result.is_ok());
        let logits = result.expect("Operation failed");
        assert_eq!(logits.shape(), &[2, 10]);
    }

    /// The selective SSM preserves the input shape.
    #[test]
    fn test_selective_ssm() {
        let mut rng = scirs2_core::random::rng();
        let (d_inner, d_state, dt_rank) = (16, 4, 2);

        let ssm = SelectiveSSM::<f64>::new(d_inner, d_state, dt_rank, &mut rng)
            .expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 4, d_inner), 0.1).into_dyn();
        let result = ssm.forward(&input);

        assert!(result.is_ok());
        assert_eq!(result.expect("Operation failed").shape(), &[2, 4, d_inner]);
    }

    /// The S4 layer preserves the input shape.
    #[test]
    fn test_s4_layer() {
        let mut rng = scirs2_core::random::rng();
        let (d_model, d_state) = (16, 8);

        let layer = S4Layer::<f64>::new(d_model, d_state, &mut rng).expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 8, d_model), 0.1).into_dyn();
        let result = layer.forward(&input);

        assert!(result.is_ok());
        assert_eq!(result.expect("Operation failed").shape(), &[2, 8, d_model]);
    }

    /// A single block preserves the input shape.
    #[test]
    fn test_mamba_block() {
        let mut rng = scirs2_core::random::rng();
        let cfg = MambaConfig::new(32).with_d_state(8);

        let block = MambaBlock::<f64>::new(&cfg, &mut rng).expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 4, 32), 0.1).into_dyn();
        let result = block.forward(&input);

        assert!(result.is_ok());
        assert_eq!(result.expect("Operation failed").shape(), &[2, 4, 32]);
    }

    /// A mixed-sign, non-constant input produces only finite outputs.
    #[test]
    fn test_mamba_numerical_stability() {
        let mut rng = scirs2_core::random::rng();
        let cfg = MambaConfig::new(16).with_n_layers(1).with_d_state(4);

        let model = Mamba::<f64>::new(cfg, &mut rng).expect("Operation failed");

        // Values ramp along both the time axis (centred on step 4) and the feature axis.
        let input = Array3::<f64>::from_shape_fn((1, 8, 16), |(_, i, j)| {
            (i as f64 - 4.0) * 0.1 + j as f64 * 0.01
        });

        let result = model.forward(&input.into_dyn());
        assert!(result.is_ok());

        for val in result.expect("Operation failed").iter() {
            assert!(val.is_finite(), "Output contains non-finite values");
        }
    }

    /// The causal convolution preserves the input shape.
    #[test]
    fn test_conv1d() {
        let mut rng = scirs2_core::random::rng();
        let conv = Conv1D::<f64>::new(8, 3, &mut rng).expect("Operation failed");

        let input = Array3::<f64>::from_elem((2, 4, 8), 0.1).into_dyn();
        let result = conv.forward(&input);

        assert!(result.is_ok());
        assert_eq!(result.expect("Operation failed").shape(), &[2, 4, 8]);
    }
}