// ruqu_neural_decoder/mamba.rs

//! Mamba State-Space Decoder
//!
//! Implements a Mamba-style state-space model for sequential decoding of
//! syndrome representations into error corrections.
//!
//! ## State Space Model
//!
//! The Mamba decoder uses selective state spaces with data-dependent parameters:
//! - Input-dependent state transition (A matrix selection)
//! - Input-dependent input projection (B matrix selection)
//! - Gated output with residual connection

13use crate::error::{NeuralDecoderError, Result};
14use ndarray::{Array1, Array2, ArrayView1};
15use rand::Rng;
16use rand_distr::{Distribution, Normal};
17use serde::{Deserialize, Serialize};
18
/// Configuration for the Mamba decoder
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaConfig {
    /// Input dimension (width of each node embedding coming from the GNN)
    pub input_dim: usize,
    /// State dimension (length of the internal recurrent state vector)
    pub state_dim: usize,
    /// Output dimension (number of correction probabilities produced)
    pub output_dim: usize,
}
29
30impl Default for MambaConfig {
31    fn default() -> Self {
32        Self {
33            input_dim: 128,
34            state_dim: 64,
35            output_dim: 25, // 5x5 surface code
36        }
37    }
38}
39
/// The recurrent state of the Mamba decoder
#[derive(Debug, Clone)]
pub struct MambaState {
    /// Hidden state vector
    pub hidden: Vec<f32>,
    /// State dimension
    pub dim: usize,
    /// Number of steps processed
    pub steps: usize,
}

impl MambaState {
    /// Create a new state of length `dim`, zero-filled with no steps taken.
    pub fn new(dim: usize) -> Self {
        MambaState {
            hidden: vec![0.0; dim],
            dim,
            steps: 0,
        }
    }

    /// Zero the hidden vector and clear the step counter.
    pub fn reset(&mut self) {
        for v in self.hidden.iter_mut() {
            *v = 0.0;
        }
        self.steps = 0;
    }

    /// Borrow the current hidden state.
    pub fn get(&self) -> &[f32] {
        self.hidden.as_slice()
    }

    /// Replace the hidden state and increment the step counter.
    ///
    /// Panics if `new_state.len()` differs from `self.dim`.
    pub fn update(&mut self, new_state: Vec<f32>) {
        assert_eq!(new_state.len(), self.dim);
        self.hidden = new_state;
        self.steps += 1;
    }
}
79
/// Linear layer with bias
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Linear {
    /// Weight matrix, shape (output_dim, input_dim); one row per output unit.
    weights: Array2<f32>,
    /// Per-output bias, length output_dim (zero-initialized by `new`).
    bias: Array1<f32>,
}
86
87impl Linear {
88    fn new(input_dim: usize, output_dim: usize) -> Self {
89        let mut rng = rand::thread_rng();
90        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
91        let normal = Normal::new(0.0, scale as f64).unwrap();
92
93        let weights = Array2::from_shape_fn(
94            (output_dim, input_dim),
95            |_| normal.sample(&mut rng) as f32
96        );
97        let bias = Array1::zeros(output_dim);
98
99        Self { weights, bias }
100    }
101
102    fn forward(&self, input: &[f32]) -> Vec<f32> {
103        let x = ArrayView1::from(input);
104        let output = self.weights.dot(&x) + &self.bias;
105        output.to_vec()
106    }
107}
108
/// Selective scan block (core Mamba operation)
///
/// All three projections are input-dependent ("selective"): delta, B and C
/// are recomputed from the current input at every step.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SelectiveScan {
    /// Projects input to delta (determines discretization step size)
    delta_proj: Linear,
    /// Projects input to B (input matrix)
    b_proj: Linear,
    /// Projects input to C (output matrix)
    c_proj: Linear,
    /// Discretization scale applied to raw delta before the softplus
    delta_scale: f32,
    /// State dimension
    state_dim: usize,
}
123
124impl SelectiveScan {
125    fn new(input_dim: usize, state_dim: usize) -> Self {
126        Self {
127            delta_proj: Linear::new(input_dim, state_dim),
128            b_proj: Linear::new(input_dim, state_dim),
129            c_proj: Linear::new(input_dim, state_dim),
130            delta_scale: 0.1,
131            state_dim,
132        }
133    }
134
135    /// Perform one step of selective scan
136    fn step(&self, input: &[f32], state: &[f32]) -> (Vec<f32>, Vec<f32>) {
137        // Compute data-dependent parameters
138        let delta_raw = self.delta_proj.forward(input);
139        let b = self.b_proj.forward(input);
140        let c = self.c_proj.forward(input);
141
142        // Softplus for delta (ensures positive)
143        let delta: Vec<f32> = delta_raw.iter()
144            .map(|&x| (1.0 + (x * self.delta_scale).exp()).ln())
145            .collect();
146
147        // Discretized state transition: x = exp(-delta) * x + delta * B * u
148        // Simplified: x = (1 - delta) * x + delta * B * input_proj
149        let input_norm: f32 = input.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-6);
150
151        let mut new_state = vec![0.0; self.state_dim];
152        for i in 0..self.state_dim {
153            let decay = (-delta[i]).exp();
154            let input_contrib = delta[i] * b[i] * (input_norm / (self.state_dim as f32).sqrt());
155            new_state[i] = decay * state[i] + input_contrib;
156        }
157
158        // Output: y = C * x
159        let output: f32 = c.iter().zip(new_state.iter())
160            .map(|(ci, xi)| ci * xi)
161            .sum();
162
163        // Expand output to match input dimension for residual
164        let output_vec = vec![output / (self.state_dim as f32).sqrt(); input.len()];
165
166        (new_state, output_vec)
167    }
168}
169
/// Mamba block combining selective scan with gating
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MambaBlock {
    /// Input projection (input_dim -> state_dim)
    in_proj: Linear,
    /// Selective scan module operating in the projected state space
    ssm: SelectiveScan,
    /// Gate projection (input_dim -> state_dim); sigmoid is applied in `forward`
    gate_proj: Linear,
    /// Output projection (state_dim -> input_dim) for the residual path
    out_proj: Linear,
    /// Layer-norm scale vector (initialized to ones; not read by `forward` yet)
    norm: Array1<f32>,
    /// State dimension
    state_dim: usize,
}
186
187impl MambaBlock {
188    fn new(input_dim: usize, state_dim: usize) -> Self {
189        Self {
190            in_proj: Linear::new(input_dim, state_dim),
191            ssm: SelectiveScan::new(state_dim, state_dim),
192            gate_proj: Linear::new(input_dim, state_dim),
193            out_proj: Linear::new(state_dim, input_dim),
194            norm: Array1::ones(state_dim),
195            state_dim,
196        }
197    }
198
199    fn forward(&self, input: &[f32], state: &[f32]) -> (Vec<f32>, Vec<f32>) {
200        // Project input
201        let x = self.in_proj.forward(input);
202
203        // Selective scan
204        let (new_state, ssm_out) = self.ssm.step(&x, state);
205
206        // Gating
207        let gate_raw = self.gate_proj.forward(input);
208        let gate: Vec<f32> = gate_raw.iter()
209            .map(|&g| 1.0 / (1.0 + (-g).exp()))
210            .collect();
211
212        // Apply gate
213        let gated: Vec<f32> = ssm_out.iter().zip(gate.iter().cycle())
214            .map(|(s, g)| s * g)
215            .collect();
216
217        // Output projection
218        let output_raw = self.out_proj.forward(&gated[..self.state_dim.min(gated.len())]);
219
220        // Residual connection
221        let output: Vec<f32> = input.iter().zip(output_raw.iter().cycle())
222            .map(|(i, o)| i + o)
223            .collect();
224
225        (new_state, output)
226    }
227}
228
/// Mamba decoder for syndrome-to-correction mapping
///
/// Holds a single Mamba block plus a final linear head, and carries the
/// recurrent state across `decode`/`decode_step` calls until `reset`.
#[derive(Debug, Clone)]
pub struct MambaDecoder {
    /// Dimensions (input/state/output) fixed at construction.
    config: MambaConfig,
    /// The single Mamba block used for the sequential scan.
    block: MambaBlock,
    /// Final projection from input_dim to output_dim logits.
    output_proj: Linear,
    /// Recurrent state, mutated by `decode`/`decode_step`.
    state: MambaState,
}
237
238impl MambaDecoder {
239    /// Create a new Mamba decoder
240    pub fn new(config: MambaConfig) -> Self {
241        let block = MambaBlock::new(config.input_dim, config.state_dim);
242        let output_proj = Linear::new(config.input_dim, config.output_dim);
243        let state = MambaState::new(config.state_dim);
244
245        Self {
246            config,
247            block,
248            output_proj,
249            state,
250        }
251    }
252
253    /// Decode node embeddings to correction probabilities
254    pub fn decode(&mut self, embeddings: &Array2<f32>) -> Result<Array1<f32>> {
255        if embeddings.shape()[0] == 0 {
256            return Err(NeuralDecoderError::EmptyGraph);
257        }
258
259        let expected_dim = self.config.input_dim;
260        let actual_dim = embeddings.shape()[1];
261
262        if actual_dim != expected_dim {
263            return Err(NeuralDecoderError::embed_dim(expected_dim, actual_dim));
264        }
265
266        // Process each node embedding sequentially
267        let mut aggregated = vec![0.0; self.config.input_dim];
268
269        for row in embeddings.rows() {
270            let input: Vec<f32> = row.to_vec();
271
272            // Mamba block forward
273            let (new_state, output) = self.block.forward(&input, self.state.get());
274            self.state.update(new_state);
275
276            // Aggregate outputs
277            for (agg, out) in aggregated.iter_mut().zip(output.iter()) {
278                *agg += out;
279            }
280        }
281
282        // Normalize by number of nodes
283        let num_nodes = embeddings.shape()[0] as f32;
284        for val in &mut aggregated {
285            *val /= num_nodes;
286        }
287
288        // Project to output dimension
289        let logits = self.output_proj.forward(&aggregated);
290
291        // Sigmoid activation for probabilities
292        let probs: Vec<f32> = logits.iter()
293            .map(|&x| 1.0 / (1.0 + (-x).exp()))
294            .collect();
295
296        Ok(Array1::from_vec(probs))
297    }
298
299    /// Decode with explicit state management
300    pub fn decode_step(&mut self, embedding: &[f32]) -> Result<Vec<f32>> {
301        if embedding.len() != self.config.input_dim {
302            return Err(NeuralDecoderError::embed_dim(
303                self.config.input_dim,
304                embedding.len()
305            ));
306        }
307
308        let (new_state, output) = self.block.forward(embedding, self.state.get());
309        self.state.update(new_state);
310
311        Ok(output)
312    }
313
314    /// Get the current state
315    pub fn state(&self) -> &MambaState {
316        &self.state
317    }
318
319    /// Reset the decoder state
320    pub fn reset(&mut self) {
321        self.state.reset();
322    }
323
324    /// Get the configuration
325    pub fn config(&self) -> &MambaConfig {
326        &self.config
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    // Default config matches the documented 5x5 surface-code sizing.
    #[test]
    fn test_mamba_config_default() {
        let config = MambaConfig::default();
        assert_eq!(config.input_dim, 128);
        assert_eq!(config.state_dim, 64);
        assert_eq!(config.output_dim, 25);
    }

    // A fresh state is zero-filled with no steps recorded.
    #[test]
    fn test_mamba_state_creation() {
        let state = MambaState::new(64);
        assert_eq!(state.dim, 64);
        assert_eq!(state.steps, 0);
        assert_eq!(state.get().len(), 64);

        // All zeros initially
        for &val in state.get() {
            assert_eq!(val, 0.0);
        }
    }

    // update() replaces the hidden vector and increments the step count.
    #[test]
    fn test_mamba_state_update() {
        let mut state = MambaState::new(4);
        let new_values = vec![1.0, 2.0, 3.0, 4.0];
        state.update(new_values.clone());

        assert_eq!(state.steps, 1);
        assert_eq!(state.get(), &new_values[..]);
    }

    // reset() zeroes both the hidden vector and the step counter.
    #[test]
    fn test_mamba_state_reset() {
        let mut state = MambaState::new(4);
        state.update(vec![1.0, 2.0, 3.0, 4.0]);
        state.update(vec![5.0, 6.0, 7.0, 8.0]);

        assert_eq!(state.steps, 2);

        state.reset();

        assert_eq!(state.steps, 0);
        for &val in state.get() {
            assert_eq!(val, 0.0);
        }
    }

    // Construction wires config dims into the internal recurrent state.
    #[test]
    fn test_mamba_decoder_creation() {
        let config = MambaConfig::default();
        let decoder = MambaDecoder::new(config);

        assert_eq!(decoder.config().input_dim, 128);
        assert_eq!(decoder.state().dim, 64);
    }

    // End-to-end decode returns one probability per output dimension.
    #[test]
    fn test_mamba_decode() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        // Create embeddings for 9 nodes
        let embeddings = Array2::from_shape_fn((9, 32), |(_i, _j)| 0.5);

        let output = decoder.decode(&embeddings).unwrap();
        assert_eq!(output.len(), 9);

        // Output should be probabilities (0 to 1)
        for &prob in output.iter() {
            assert!(prob >= 0.0 && prob <= 1.0);
        }
    }

    // decode() advances the recurrent state once per node row.
    #[test]
    fn test_mamba_decode_updates_state() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        let embeddings = Array2::from_shape_fn((9, 32), |(_i, _j)| 0.5);

        assert_eq!(decoder.state().steps, 0);

        decoder.decode(&embeddings).unwrap();

        // State should be updated (9 steps for 9 nodes)
        assert_eq!(decoder.state().steps, 9);
    }

    // Single-step decode preserves the input dimension (residual path).
    #[test]
    fn test_mamba_decode_step() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        let embedding = vec![0.5; 32];
        let output = decoder.decode_step(&embedding).unwrap();

        assert_eq!(output.len(), 32); // Same as input_dim for residual
        assert_eq!(decoder.state().steps, 1);
    }

    // Mismatched embedding width must be rejected with an error.
    #[test]
    fn test_mamba_decode_wrong_dimension() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        // Wrong input dimension
        let embeddings = Array2::from_shape_fn((9, 64), |(_i, _j)| 0.5);
        let result = decoder.decode(&embeddings);

        assert!(result.is_err());
    }

    // Zero-node input must be rejected (EmptyGraph).
    #[test]
    fn test_mamba_decode_empty() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        let embeddings: Array2<f32> = Array2::zeros((0, 32));
        let result = decoder.decode(&embeddings);

        assert!(result.is_err());
    }

    // reset() clears the step count accumulated by a previous decode.
    #[test]
    fn test_mamba_reset() {
        let config = MambaConfig {
            input_dim: 32,
            state_dim: 16,
            output_dim: 9,
        };
        let mut decoder = MambaDecoder::new(config);

        let embeddings = Array2::from_shape_fn((9, 32), |(_i, _j)| 0.5);
        decoder.decode(&embeddings).unwrap();

        assert_eq!(decoder.state().steps, 9);

        decoder.reset();

        assert_eq!(decoder.state().steps, 0);
    }

    // decode_step can be driven node-by-node, accumulating state steps.
    #[test]
    fn test_mamba_sequential_decode() {
        let config = MambaConfig {
            input_dim: 16,
            state_dim: 8,
            output_dim: 4,
        };
        let mut decoder = MambaDecoder::new(config);

        // Process nodes one by one
        let embeddings: Vec<Vec<f32>> = (0..5)
            .map(|i| vec![i as f32 * 0.1; 16])
            .collect();

        let mut outputs = Vec::new();
        for emb in &embeddings {
            let out = decoder.decode_step(emb).unwrap();
            outputs.push(out);
        }

        assert_eq!(outputs.len(), 5);
        assert_eq!(decoder.state().steps, 5);
    }

    // Different inputs should drive the hidden state to different values.
    #[test]
    fn test_mamba_state_evolution() {
        let config = MambaConfig {
            input_dim: 8,
            state_dim: 4,
            output_dim: 2,
        };
        let mut decoder = MambaDecoder::new(config);

        let emb1 = vec![1.0; 8];
        let emb2 = vec![0.0; 8];

        decoder.decode_step(&emb1).unwrap();
        let state1 = decoder.state().get().to_vec();

        decoder.decode_step(&emb2).unwrap();
        let state2 = decoder.state().get().to_vec();

        // States should differ
        let diff: f32 = state1.iter().zip(state2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff > 0.0);
    }

    // Scan step shapes: state_dim-length state, input-length output vector.
    #[test]
    fn test_selective_scan_step() {
        let ssm = SelectiveScan::new(8, 4);
        let input = vec![0.5; 8];
        let state = vec![0.0; 4];

        let (new_state, output) = ssm.step(&input, &state);

        assert_eq!(new_state.len(), 4);
        assert_eq!(output.len(), 8);
    }

    // Block forward shapes: new state has state_dim, output matches input.
    #[test]
    fn test_mamba_block_forward() {
        let block = MambaBlock::new(8, 4);
        let input = vec![0.5; 8];
        let state = vec![0.0; 4];

        let (new_state, output) = block.forward(&input, &state);

        assert_eq!(new_state.len(), 4);
        assert_eq!(output.len(), 8);
    }
}