scirs2_series/neural_forecasting/
mamba.rs1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::lstm::LSTMCell;
11use crate::error::Result; #[derive(Debug)]
15pub struct MambaBlock<F: Float + Debug> {
16 #[allow(dead_code)]
18 state_dim: usize,
19 #[allow(dead_code)]
21 input_dim: usize,
22 #[allow(dead_code)]
24 selection_weights: Array2<F>,
25 #[allow(dead_code)]
27 state_matrix: Array2<F>,
28 #[allow(dead_code)]
30 input_projection: Array2<F>,
31 #[allow(dead_code)]
33 output_projection: Array2<F>,
34}
35
36impl<F: Float + Debug + Clone + FromPrimitive> MambaBlock<F> {
37 pub fn new(input_dim: usize, state_dim: usize) -> Self {
39 let scale = F::from(2.0).unwrap() / F::from(input_dim).unwrap();
40 let std_dev = scale.sqrt();
41
42 Self {
43 state_dim,
44 input_dim,
45 selection_weights: LSTMCell::random_matrix(state_dim, input_dim, std_dev),
46 state_matrix: LSTMCell::random_matrix(state_dim, state_dim, std_dev),
47 input_projection: LSTMCell::random_matrix(state_dim, input_dim, std_dev),
48 output_projection: LSTMCell::random_matrix(input_dim, state_dim, std_dev),
49 }
50 }
51
52 pub fn forward(&self, input: &Array2<F>) -> Result<Array2<F>> {
54 Ok(input.clone())
56 }
57}