Skip to main content

scirs2_neural/models/architectures/
mamba.rs

//! Mamba architecture implementation
//!
//! This module implements the Mamba architecture as described in:
//! "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
//! by Albert Gu and Tri Dao (<https://arxiv.org/abs/2312.00752>)
//!
//! Mamba is a selective state space model (SSM) that achieves linear-time
//! sequence modeling with performance competitive with Transformers while
//! being significantly more efficient for long sequences.
10
11use crate::activations::Activation;
12use crate::error::{NeuralError, Result};
13use crate::layers::{Dense, Dropout, Layer, LayerNorm};
14use scirs2_core::ndarray::{s, Array, Array1, Array2, Array3, IxDyn, ScalarOperand, Zip};
15use scirs2_core::numeric::{Float, NumAssign};
16use scirs2_core::random::{Rng, RngExt, SeedableRng};
17use scirs2_core::simd_ops::SimdUnifiedOps;
18use serde::{Deserialize, Serialize};
19use std::fmt::Debug;
20
21/// Configuration for the Mamba model
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct MambaConfig {
24    /// Model dimension (d_model)
25    pub d_model: usize,
26    /// State dimension (n)
27    pub d_state: usize,
28    /// Convolution kernel size
29    pub d_conv: usize,
30    /// Expansion factor for the intermediate dimension
31    pub expand: usize,
32    /// Number of Mamba blocks
33    pub n_layers: usize,
34    /// Dropout probability
35    pub dropout_prob: f64,
36    /// Vocabulary size (for embedding)
37    pub vocab_size: Option<usize>,
38    /// Number of output classes (for classification)
39    pub num_classes: Option<usize>,
40    /// Rank for delta projection (dt_rank)
41    pub dt_rank: Option<usize>,
42    /// Whether to use bias in projections
43    pub bias: bool,
44    /// Initialization range for delta
45    pub dt_min: f64,
46    pub dt_max: f64,
47}
48
49impl Default for MambaConfig {
50    fn default() -> Self {
51        Self {
52            d_model: 256,
53            d_state: 16,
54            d_conv: 4,
55            expand: 2,
56            n_layers: 4,
57            dropout_prob: 0.1,
58            vocab_size: None,
59            num_classes: None,
60            dt_rank: None, // Auto-computed as ceil(d_model / 16)
61            bias: false,
62            dt_min: 0.001,
63            dt_max: 0.1,
64        }
65    }
66}
67
68impl MambaConfig {
69    /// Create a new MambaConfig with the specified model dimension
70    pub fn new(d_model: usize) -> Self {
71        Self {
72            d_model,
73            ..Default::default()
74        }
75    }
76
77    /// Set the state dimension
78    pub fn with_d_state(mut self, d_state: usize) -> Self {
79        self.d_state = d_state;
80        self
81    }
82
83    /// Set the number of layers
84    pub fn with_n_layers(mut self, n_layers: usize) -> Self {
85        self.n_layers = n_layers;
86        self
87    }
88
89    /// Set the expansion factor
90    pub fn with_expand(mut self, expand: usize) -> Self {
91        self.expand = expand;
92        self
93    }
94
95    /// Set dropout probability
96    pub fn with_dropout(mut self, dropout_prob: f64) -> Self {
97        self.dropout_prob = dropout_prob;
98        self
99    }
100
101    /// Set vocabulary size
102    pub fn with_vocab_size(mut self, vocab_size: usize) -> Self {
103        self.vocab_size = Some(vocab_size);
104        self
105    }
106
107    /// Set number of classes for classification
108    pub fn with_num_classes(mut self, num_classes: usize) -> Self {
109        self.num_classes = Some(num_classes);
110        self
111    }
112
113    /// Get the inner dimension
114    pub fn d_inner(&self) -> usize {
115        self.d_model * self.expand
116    }
117
118    /// Get the dt_rank (auto-computed if not set)
119    pub fn get_dt_rank(&self) -> usize {
120        self.dt_rank.unwrap_or_else(|| self.d_model.div_ceil(16)) // ceil division
121    }
122}
123
124/// Selective State Space Model (S6) - the core component of Mamba
125///
126/// Implements the discretized SSM:
127/// h_t = A_bar * h_{t-1} + B_bar * x_t
128/// y_t = C * h_t
129///
130/// Where A_bar and B_bar are computed from continuous-time parameters
131/// using the zero-order hold (ZOH) discretization.
132#[derive(Debug)]
133pub struct SelectiveSSM<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
134    /// State dimension
135    d_state: usize,
136    /// Inner dimension
137    d_inner: usize,
138    /// Continuous-time A matrix (diagonal, initialized to special values)
139    a_log: Array2<F>,
140    /// D parameter (skip connection)
141    d: Array1<F>,
142    /// Delta projection weights
143    dt_proj: Dense<F>,
144    /// Projection for B
145    x_proj_b: Dense<F>,
146    /// Projection for C
147    x_proj_c: Dense<F>,
148    /// Projection for delta
149    x_proj_dt: Dense<F>,
150}
151
152impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> SelectiveSSM<F> {
153    /// Create a new SelectiveSSM
154    pub fn new<R: Rng>(
155        d_inner: usize,
156        d_state: usize,
157        dt_rank: usize,
158        rng: &mut R,
159    ) -> Result<Self> {
160        // Initialize A with the S4D-Real initialization
161        // A = -exp(uniformly spaced values from ln(1) to ln(state_dim))
162        let mut a_log = Array2::<F>::zeros((d_inner, d_state));
163        for i in 0..d_inner {
164            for j in 0..d_state {
165                // Log-spaced from 1 to d_state
166                let val = (j as f64 + 1.0).ln();
167                a_log[[i, j]] = F::from(val).expect("Failed to convert to float");
168            }
169        }
170
171        // D initialization (skip connection, initialized to 1)
172        let d = Array1::<F>::from_elem(d_inner, F::one());
173
174        // Delta projection from dt_rank to d_inner
175        let dt_proj = Dense::<F>::new(dt_rank, d_inner, Some("dt_proj"), rng)?;
176
177        // Projections for B, C, and dt from input
178        let x_proj_b = Dense::<F>::new(d_inner, d_state, Some("x_proj_b"), rng)?;
179        let x_proj_c = Dense::<F>::new(d_inner, d_state, Some("x_proj_c"), rng)?;
180        let x_proj_dt = Dense::<F>::new(d_inner, dt_rank, Some("x_proj_dt"), rng)?;
181
182        Ok(Self {
183            d_state,
184            d_inner,
185            a_log,
186            d,
187            dt_proj,
188            x_proj_b,
189            x_proj_c,
190            x_proj_dt,
191        })
192    }
193
194    /// Compute the SSM output using selective scan
195    ///
196    /// # Arguments
197    /// * `x` - Input tensor [batch, seq_len, d_inner]
198    ///
199    /// # Returns
200    /// * Output tensor [batch, seq_len, d_inner]
201    pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
202        if x.ndim() != 3 {
203            return Err(NeuralError::InvalidArchitecture(format!(
204                "SelectiveSSM expects 3D input, got {}D",
205                x.ndim()
206            )));
207        }
208
209        let shape = x.shape();
210        let batch_size = shape[0];
211        let seq_len = shape[1];
212        let d_inner = shape[2];
213
214        if d_inner != self.d_inner {
215            return Err(NeuralError::InvalidArchitecture(format!(
216                "Input dimension {} doesn't match d_inner {}",
217                d_inner, self.d_inner
218            )));
219        }
220
221        // Compute A from a_log: A = -exp(a_log)
222        let a_neg = self.a_log.mapv(|v| -v.exp());
223
224        // Project input to get B, C, delta
225        // Reshape for dense layer
226        let x_2d = x
227            .clone()
228            .into_shape_with_order(IxDyn(&[batch_size * seq_len, d_inner]))
229            .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
230
231        // Get B, C, dt projections
232        let b_proj = self.x_proj_b.forward(&x_2d)?;
233        let c_proj = self.x_proj_c.forward(&x_2d)?;
234        let dt_proj_input = self.x_proj_dt.forward(&x_2d)?;
235        let delta_proj = self.dt_proj.forward(&dt_proj_input)?;
236
237        // Apply softplus to delta: delta = softplus(delta_proj)
238        let delta = delta_proj.mapv(|v: F| {
239            if v > F::from(20.0).expect("Failed to convert constant to float") {
240                v
241            } else {
242                (F::one() + v.exp()).ln()
243            }
244        });
245
246        // Reshape B, C, delta back to [batch, seq, ...]
247        let b = b_proj
248            .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_state]))
249            .map_err(|e| NeuralError::InferenceError(format!("B reshape error: {}", e)))?;
250
251        let c = c_proj
252            .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_state]))
253            .map_err(|e| NeuralError::InferenceError(format!("C reshape error: {}", e)))?;
254
255        let delta_3d = delta
256            .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_inner]))
257            .map_err(|e| NeuralError::InferenceError(format!("Delta reshape error: {}", e)))?;
258
259        // Perform selective scan
260        let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, d_inner]));
261
262        for batch_idx in 0..batch_size {
263            // Initialize state: [d_inner, d_state]
264            let mut h = Array2::<F>::zeros((d_inner, self.d_state));
265
266            for t in 0..seq_len {
267                // Get delta for this timestep: [d_inner]
268                let dt = delta_3d.slice(s![batch_idx, t, ..]);
269
270                // Get B and C for this timestep: [d_state]
271                let b_t = b.slice(s![batch_idx, t, ..]);
272                let c_t = c.slice(s![batch_idx, t, ..]);
273
274                // Get input for this timestep: [d_inner]
275                let x_t = x.slice(s![batch_idx, t, ..]);
276
277                // Discretize A and B using zero-order hold
278                // A_bar = exp(delta * A)
279                // B_bar = (A_bar - I) * A^(-1) * B ≈ delta * B (simplified)
280
281                // Update state for each dimension
282                for i in 0..d_inner {
283                    let dt_i = dt[i];
284
285                    for j in 0..self.d_state {
286                        // A_bar[i,j] = exp(dt[i] * A[i,j])
287                        let a_bar = (dt_i * a_neg[[i, j]]).exp();
288                        // B_bar[i,j] ≈ dt[i] * B[j] (simplified discretization)
289                        let b_bar = dt_i * b_t[j];
290
291                        // State update: h = A_bar * h + B_bar * x
292                        h[[i, j]] = a_bar * h[[i, j]] + b_bar * x_t[i];
293                    }
294                }
295
296                // Compute output: y = C * h + D * x
297                for i in 0..d_inner {
298                    let mut y_i = F::zero();
299                    for j in 0..self.d_state {
300                        y_i += c_t[j] * h[[i, j]];
301                    }
302                    // Add skip connection
303                    output[[batch_idx, t, i]] = y_i + self.d[[i]] * x_t[i];
304                }
305            }
306        }
307
308        Ok(output)
309    }
310}
311
312/// 1D Convolution layer for Mamba
313#[derive(Debug)]
314struct Conv1D<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
315    /// Convolution weights [out_channels, kernel_size]
316    weights: Array2<F>,
317    /// Bias [out_channels]
318    bias: Array1<F>,
319    /// Kernel size
320    kernel_size: usize,
321    /// Number of channels
322    channels: usize,
323}
324
325impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Conv1D<F> {
326    fn new<R: Rng>(channels: usize, kernel_size: usize, rng: &mut R) -> Result<Self> {
327        let std = (F::from(2.0).expect("Failed to convert constant to float")
328            / F::from(channels * kernel_size).expect("Failed to convert to float"))
329        .sqrt();
330
331        let mut weights = Array2::<F>::zeros((channels, kernel_size));
332        for w in weights.iter_mut() {
333            let u1: f64 = rng.random();
334            let u2: f64 = rng.random();
335            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
336            *w = F::from(z).expect("Failed to convert to float") * std;
337        }
338
339        let bias = Array1::<F>::zeros(channels);
340
341        Ok(Self {
342            weights,
343            bias,
344            kernel_size,
345            channels,
346        })
347    }
348
349    fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
350        // x: [batch, seq_len, channels]
351        let shape = x.shape();
352        let batch_size = shape[0];
353        let seq_len = shape[1];
354        let channels = shape[2];
355
356        if channels != self.channels {
357            return Err(NeuralError::InvalidArchitecture(format!(
358                "Channel mismatch: {} vs {}",
359                channels, self.channels
360            )));
361        }
362
363        // Causal 1D convolution with padding
364        let pad = self.kernel_size - 1;
365        let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, channels]));
366
367        for b in 0..batch_size {
368            for t in 0..seq_len {
369                for c in 0..channels {
370                    let mut sum = self.bias[c];
371                    for k in 0..self.kernel_size {
372                        let input_idx = t as isize + k as isize - pad as isize;
373                        if input_idx >= 0 && (input_idx as usize) < seq_len {
374                            sum += self.weights[[c, k]] * x[[b, input_idx as usize, c]];
375                        }
376                    }
377                    output[[b, t, c]] = sum;
378                }
379            }
380        }
381
382        Ok(output)
383    }
384}
385
386/// SiLU (Swish) activation function
387#[derive(Debug, Clone, Copy)]
388struct SiLU;
389
390impl SiLU {
391    fn forward<F: Float>(&self, x: F) -> F {
392        x * (F::one() / (F::one() + (-x).exp()))
393    }
394}
395
396/// A single Mamba block
397#[derive(Debug)]
398pub struct MambaBlock<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
399where
400    F: SimdUnifiedOps,
401{
402    /// Configuration
403    d_model: usize,
404    d_inner: usize,
405    /// Input projection [d_model -> d_inner * 2]
406    in_proj: Dense<F>,
407    /// 1D convolution
408    conv1d: Conv1D<F>,
409    /// Selective SSM
410    ssm: SelectiveSSM<F>,
411    /// Output projection [d_inner -> d_model]
412    out_proj: Dense<F>,
413    /// Layer normalization
414    norm: LayerNorm<F>,
415}
416
417impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
418    MambaBlock<F>
419{
420    /// Create a new MambaBlock
421    pub fn new<R: Rng>(config: &MambaConfig, rng: &mut R) -> Result<Self> {
422        let d_inner = config.d_inner();
423        let dt_rank = config.get_dt_rank();
424
425        // Input projection: d_model -> d_inner * 2 (for x and z branches)
426        let in_proj = Dense::<F>::new(config.d_model, d_inner * 2, Some("in_proj"), rng)?;
427
428        // 1D causal convolution
429        let conv1d = Conv1D::new(d_inner, config.d_conv, rng)?;
430
431        // Selective SSM
432        let ssm = SelectiveSSM::new(d_inner, config.d_state, dt_rank, rng)?;
433
434        // Output projection
435        let out_proj = Dense::<F>::new(d_inner, config.d_model, Some("out_proj"), rng)?;
436
437        // Layer norm
438        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
439        let norm = LayerNorm::<F>::new(config.d_model, 1e-5, &mut rng)?;
440
441        Ok(Self {
442            d_model: config.d_model,
443            d_inner,
444            in_proj,
445            conv1d,
446            ssm,
447            out_proj,
448            norm,
449        })
450    }
451
452    /// Forward pass
453    pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
454        // x: [batch, seq_len, d_model]
455        let residual = x.clone();
456
457        // Layer norm
458        let normed = self.norm.forward(x)?;
459
460        let shape = normed.shape();
461        let batch_size = shape[0];
462        let seq_len = shape[1];
463
464        // Input projection: [batch, seq, d_model] -> [batch, seq, d_inner * 2]
465        let x_2d = normed
466            .clone()
467            .into_shape_with_order(IxDyn(&[batch_size * seq_len, self.d_model]))
468            .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
469
470        let proj = self.in_proj.forward(&x_2d)?;
471
472        let proj_3d = proj
473            .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_inner * 2]))
474            .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
475
476        // Split into x and z branches
477        let x_branch = proj_3d
478            .slice(s![.., .., ..self.d_inner])
479            .to_owned()
480            .into_dyn();
481        let z_branch = proj_3d
482            .slice(s![.., .., self.d_inner..])
483            .to_owned()
484            .into_dyn();
485
486        // Apply convolution and SiLU to x branch
487        let x_conv = self.conv1d.forward(&x_branch)?;
488
489        let silu = SiLU;
490        let x_silu = x_conv.mapv(|v| silu.forward(v));
491
492        // Apply SSM
493        let x_ssm = self.ssm.forward(&x_silu)?;
494
495        // Gate with z branch (SiLU activation)
496        let z_silu = z_branch.mapv(|v| silu.forward(v));
497
498        // Element-wise multiplication
499        let gated = &x_ssm * &z_silu;
500
501        // Output projection
502        let gated_2d = gated
503            .into_shape_with_order(IxDyn(&[batch_size * seq_len, self.d_inner]))
504            .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
505
506        let output = self.out_proj.forward(&gated_2d)?;
507
508        let output_3d = output
509            .into_shape_with_order(IxDyn(&[batch_size, seq_len, self.d_model]))
510            .map_err(|e| NeuralError::InferenceError(format!("Reshape error: {}", e)))?;
511
512        // Residual connection
513        Ok(&residual + &output_3d)
514    }
515}
516
517/// The Mamba model
518///
519/// A state-space model that achieves linear-time sequence modeling
520/// with selective state spaces.
521///
522/// # Architecture
523///
524/// - Optional embedding layer (for language modeling)
525/// - Stack of Mamba blocks
526/// - Final layer norm
527/// - Optional classification head
528///
529/// # Examples
530///
531/// ```rust
532/// use scirs2_neural::models::architectures::{Mamba, MambaConfig};
533/// use scirs2_neural::layers::Layer;
534/// use scirs2_core::ndarray::Array3;
535/// use scirs2_core::random::rng;
536///
537/// let mut rng = rng();
538/// let config = MambaConfig::new(256)
539///     .with_n_layers(4)
540///     .with_d_state(16);
541///
542/// let mamba = Mamba::<f64>::new(config, &mut rng).expect("Operation failed");
543///
544/// // Input: [batch, seq_len, d_model]
545/// let input = Array3::<f64>::from_elem((2, 32, 256), 0.1).into_dyn();
546/// let output = mamba.forward(&input).expect("Operation failed");
547/// ```
548#[derive(Debug)]
549pub struct Mamba<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
550where
551    F: SimdUnifiedOps,
552{
553    /// Configuration
554    config: MambaConfig,
555    /// Stack of Mamba blocks
556    blocks: Vec<MambaBlock<F>>,
557    /// Final layer normalization
558    final_norm: LayerNorm<F>,
559    /// Optional classification head
560    classifier: Option<Dense<F>>,
561}
562
563impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
564    Mamba<F>
565{
566    /// Create a new Mamba model
567    pub fn new<R: Rng>(config: MambaConfig, rng: &mut R) -> Result<Self> {
568        // Create Mamba blocks
569        let mut blocks = Vec::with_capacity(config.n_layers);
570        for _ in 0..config.n_layers {
571            blocks.push(MambaBlock::new(&config, rng)?);
572        }
573
574        // Final layer norm
575        let mut rng_final = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
576        let final_norm = LayerNorm::<F>::new(config.d_model, 1e-5, &mut rng_final)?;
577
578        // Optional classifier
579        let classifier = if let Some(num_classes) = config.num_classes {
580            Some(Dense::<F>::new(
581                config.d_model,
582                num_classes,
583                Some("classifier"),
584                rng,
585            )?)
586        } else {
587            None
588        };
589
590        Ok(Self {
591            config,
592            blocks,
593            final_norm,
594            classifier,
595        })
596    }
597
598    /// Get the configuration
599    pub fn config(&self) -> &MambaConfig {
600        &self.config
601    }
602
603    /// Get the number of layers
604    pub fn num_layers(&self) -> usize {
605        self.blocks.len()
606    }
607}
608
609impl<F> Layer<F> for Mamba<F>
610where
611    F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static + SimdUnifiedOps,
612{
613    fn as_any(&self) -> &dyn std::any::Any {
614        self
615    }
616
617    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
618        self
619    }
620
621    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
622        // Input: [batch, seq_len, d_model]
623        if input.ndim() != 3 {
624            return Err(NeuralError::InvalidArchitecture(format!(
625                "Mamba expects 3D input [batch, seq_len, d_model], got {}D",
626                input.ndim()
627            )));
628        }
629
630        let shape = input.shape();
631        let batch_size = shape[0];
632        let seq_len = shape[1];
633        let d_model = shape[2];
634
635        if d_model != self.config.d_model {
636            return Err(NeuralError::InvalidArchitecture(format!(
637                "Input dimension {} doesn't match config d_model {}",
638                d_model, self.config.d_model
639            )));
640        }
641
642        // Pass through Mamba blocks
643        let mut hidden = input.clone();
644        for block in &self.blocks {
645            hidden = block.forward(&hidden)?;
646        }
647
648        // Final layer norm
649        let normed = self.final_norm.forward(&hidden)?;
650
651        // If classifier, apply it (use last token or mean pooling)
652        if let Some(ref classifier) = self.classifier {
653            // Mean pooling over sequence
654            let mut pooled = Array::zeros(IxDyn(&[batch_size, self.config.d_model]));
655            let seq_len_f = F::from(seq_len).expect("Failed to convert to float");
656
657            for b in 0..batch_size {
658                for d in 0..self.config.d_model {
659                    let mut sum = F::zero();
660                    for t in 0..seq_len {
661                        sum += normed[[b, t, d]];
662                    }
663                    pooled[[b, d]] = sum / seq_len_f;
664                }
665            }
666
667            classifier.forward(&pooled)
668        } else {
669            Ok(normed)
670        }
671    }
672
673    fn backward(
674        &self,
675        _input: &Array<F, IxDyn>,
676        _grad_output: &Array<F, IxDyn>,
677    ) -> Result<Array<F, IxDyn>> {
678        Err(NeuralError::NotImplemented(
679            "Mamba backward pass not yet implemented".to_string(),
680        ))
681    }
682
683    fn update(&mut self, _learning_rate: F) -> Result<()> {
684        Ok(())
685    }
686}
687
688/// Simplified State Space Model (S4) layer
689///
690/// This implements a basic structured state space model without
691/// the selective mechanism. Useful for comparison and simpler use cases.
692#[derive(Debug)]
693pub struct S4Layer<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
694    /// Dimension
695    d_model: usize,
696    /// State dimension
697    d_state: usize,
698    /// A matrix (HiPPO-initialized)
699    a: Array2<F>,
700    /// B matrix
701    b: Array2<F>,
702    /// C matrix
703    c: Array2<F>,
704    /// D matrix (skip connection)
705    d: Array1<F>,
706    /// Delta (step size)
707    delta: F,
708}
709
710impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> S4Layer<F> {
711    /// Create a new S4 layer with HiPPO initialization
712    pub fn new<R: Rng>(d_model: usize, d_state: usize, rng: &mut R) -> Result<Self> {
713        // HiPPO matrix initialization for A
714        // A[i,j] = -sqrt((2i+1)(2j+1)) if i > j else -(i+1) if i == j else 0
715        let mut a = Array2::<F>::zeros((d_state, d_state));
716        for i in 0..d_state {
717            for j in 0..d_state {
718                let val = if i > j {
719                    -((2.0 * i as f64 + 1.0) * (2.0 * j as f64 + 1.0)).sqrt()
720                } else if i == j {
721                    -(i as f64 + 1.0)
722                } else {
723                    0.0
724                };
725                a[[i, j]] = F::from(val).expect("Failed to convert to float");
726            }
727        }
728
729        // B initialization
730        let mut b = Array2::<F>::zeros((d_state, d_model));
731        for i in 0..d_state {
732            let val = (2.0 * i as f64 + 1.0).sqrt();
733            for j in 0..d_model {
734                let u: f64 = rng.random();
735                b[[i, j]] = F::from(val * (u - 0.5) * 0.1).expect("Operation failed");
736            }
737        }
738
739        // C initialization (learnable output projection)
740        let mut c = Array2::<F>::zeros((d_model, d_state));
741        let std = (2.0 / (d_model + d_state) as f64).sqrt();
742        for i in 0..d_model {
743            for j in 0..d_state {
744                let u1: f64 = rng.random();
745                let u2: f64 = rng.random();
746                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
747                c[[i, j]] = F::from(z * std).expect("Failed to convert to float");
748            }
749        }
750
751        // D (skip connection)
752        let d = Array1::<F>::from_elem(d_model, F::one());
753
754        // Default step size
755        let delta = F::from(0.001).expect("Failed to convert constant to float");
756
757        Ok(Self {
758            d_model,
759            d_state,
760            a,
761            b,
762            c,
763            d,
764            delta,
765        })
766    }
767
768    /// Forward pass using convolution mode (parallel over sequence)
769    pub fn forward(&self, x: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
770        // x: [batch, seq_len, d_model]
771        if x.ndim() != 3 {
772            return Err(NeuralError::InvalidArchitecture(format!(
773                "S4Layer expects 3D input, got {}D",
774                x.ndim()
775            )));
776        }
777
778        let shape = x.shape();
779        let batch_size = shape[0];
780        let seq_len = shape[1];
781        let d_model = shape[2];
782
783        if d_model != self.d_model {
784            return Err(NeuralError::InvalidArchitecture(format!(
785                "Input dimension {} doesn't match d_model {}",
786                d_model, self.d_model
787            )));
788        }
789
790        // Discretize using ZOH
791        // A_bar = exp(delta * A)
792        // For simplicity, use first-order approximation: A_bar ≈ I + delta * A
793        let mut a_bar = Array2::<F>::eye(self.d_state);
794        for i in 0..self.d_state {
795            for j in 0..self.d_state {
796                a_bar[[i, j]] += self.delta * self.a[[i, j]];
797            }
798        }
799
800        // B_bar ≈ delta * B
801        let b_bar = &self.b * self.delta;
802
803        // Run SSM
804        let mut output = Array::zeros(IxDyn(&[batch_size, seq_len, d_model]));
805
806        for b in 0..batch_size {
807            // State: [d_state]
808            let mut state = Array1::<F>::zeros(self.d_state);
809
810            for t in 0..seq_len {
811                // Get input: [d_model]
812                let x_t: Array1<F> = x
813                    .slice(s![b, t, ..])
814                    .to_owned()
815                    .into_shape_with_order(d_model)
816                    .map_err(|_| {
817                        NeuralError::InferenceError("Failed to reshape input".to_string())
818                    })?;
819
820                // State update: state = A_bar @ state + B_bar @ x_t
821                let new_state = a_bar.dot(&state) + b_bar.dot(&x_t);
822                state = new_state;
823
824                // Output: y_t = C @ state + D * x_t
825                let y_t = self.c.dot(&state) + &self.d * &x_t;
826
827                for d in 0..d_model {
828                    output[[b, t, d]] = y_t[d];
829                }
830            }
831        }
832
833        Ok(output)
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds a `[batch, seq_len, dim]` tensor with every element set to 0.1.
    fn constant_input(batch: usize, seq_len: usize, dim: usize) -> Array<f64, IxDyn> {
        Array3::<f64>::from_elem((batch, seq_len, dim), 0.1).into_dyn()
    }

    /// Builder methods store their values and `d_inner` is `d_model * expand`.
    #[test]
    fn test_mamba_config() {
        let config = MambaConfig::new(256)
            .with_n_layers(4)
            .with_d_state(16)
            .with_expand(2);

        assert_eq!(config.d_model, 256);
        assert_eq!(config.n_layers, 4);
        assert_eq!(config.d_state, 16);
        assert_eq!(config.d_inner(), 512);
    }

    /// A small model can be constructed without error.
    #[test]
    fn test_mamba_creation() {
        let mut rng = scirs2_core::random::rng();
        let config = MambaConfig::new(64).with_n_layers(2).with_d_state(8);

        assert!(Mamba::<f64>::new(config, &mut rng).is_ok());
    }

    /// Without a classifier the output keeps the input shape.
    #[test]
    fn test_mamba_forward() {
        let mut rng = scirs2_core::random::rng();
        let config = MambaConfig::new(32)
            .with_n_layers(2)
            .with_d_state(8)
            .with_expand(2);
        let mamba = Mamba::<f64>::new(config, &mut rng).expect("Operation failed");

        // Input: [batch=2, seq_len=8, d_model=32]
        let output = mamba.forward(&constant_input(2, 8, 32));

        assert!(output.is_ok());
        let output = output.expect("Operation failed");
        assert_eq!(output.shape(), &[2, 8, 32]);
    }

    /// With a classifier the output is `[batch, num_classes]`.
    #[test]
    fn test_mamba_with_classifier() {
        let mut rng = scirs2_core::random::rng();
        let config = MambaConfig::new(32)
            .with_n_layers(2)
            .with_d_state(8)
            .with_num_classes(10);
        let mamba = Mamba::<f64>::new(config, &mut rng).expect("Operation failed");

        let output = mamba.forward(&constant_input(2, 8, 32));

        assert!(output.is_ok());
        let output = output.expect("Operation failed");
        assert_eq!(output.shape(), &[2, 10]);
    }

    /// The selective SSM preserves the `[batch, seq_len, d_inner]` shape.
    #[test]
    fn test_selective_ssm() {
        let mut rng = scirs2_core::random::rng();
        let (d_inner, d_state, dt_rank) = (16, 4, 2);

        let ssm = SelectiveSSM::<f64>::new(d_inner, d_state, dt_rank, &mut rng)
            .expect("Operation failed");

        let output = ssm.forward(&constant_input(2, 4, d_inner));

        assert!(output.is_ok());
        assert_eq!(output.expect("Operation failed").shape(), &[2, 4, d_inner]);
    }

    /// The S4 layer preserves the `[batch, seq_len, d_model]` shape.
    #[test]
    fn test_s4_layer() {
        let mut rng = scirs2_core::random::rng();
        let (d_model, d_state) = (16, 8);

        let s4 = S4Layer::<f64>::new(d_model, d_state, &mut rng).expect("Operation failed");

        let output = s4.forward(&constant_input(2, 8, d_model));

        assert!(output.is_ok());
        assert_eq!(output.expect("Operation failed").shape(), &[2, 8, d_model]);
    }

    /// A single block preserves the `[batch, seq_len, d_model]` shape.
    #[test]
    fn test_mamba_block() {
        let mut rng = scirs2_core::random::rng();
        let config = MambaConfig::new(32).with_d_state(8);

        let block = MambaBlock::<f64>::new(&config, &mut rng).expect("Operation failed");

        let output = block.forward(&constant_input(2, 4, 32));

        assert!(output.is_ok());
        assert_eq!(output.expect("Operation failed").shape(), &[2, 4, 32]);
    }

    /// Inputs mixing negative and positive values produce only finite outputs.
    #[test]
    fn test_mamba_numerical_stability() {
        let mut rng = scirs2_core::random::rng();
        let config = MambaConfig::new(16).with_n_layers(1).with_d_state(4);
        let mamba = Mamba::<f64>::new(config, &mut rng).expect("Operation failed");

        // Values ramp along both the time and the feature axis.
        let input = Array3::<f64>::from_shape_fn((1, 8, 16), |(_, i, j)| {
            (i as f64 - 4.0) * 0.1 + j as f64 * 0.01
        });

        let output = mamba.forward(&input.into_dyn());
        assert!(output.is_ok());

        // Check all values are finite
        for val in output.expect("Operation failed").iter() {
            assert!(val.is_finite(), "Output contains non-finite values");
        }
    }

    /// The convolution preserves the `[batch, seq_len, channels]` shape.
    #[test]
    fn test_conv1d() {
        let mut rng = scirs2_core::random::rng();
        let conv = Conv1D::<f64>::new(8, 3, &mut rng).expect("Operation failed");

        let output = conv.forward(&constant_input(2, 4, 8));

        assert!(output.is_ok());
        assert_eq!(output.expect("Operation failed").shape(), &[2, 4, 8]);
    }
}