// ruqu_neural_decoder/lib.rs

1//! # Neural Quantum Error Decoder (NQED)
2//!
3//! This crate implements a neural-network-based quantum error decoder that combines
4//! Graph Neural Networks (GNN) with Mamba state-space models for efficient syndrome
5//! decoding.
6//!
7//! ## Architecture
8//!
9//! The NQED pipeline consists of:
10//! 1. **Syndrome Graph Construction**: Converts syndrome bitmaps to graph structures
11//! 2. **GNN Encoding**: Multi-layer graph attention for syndrome representation
12//! 3. **Mamba Decoder**: State-space model for sequential decoding
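//! 4. **Feature Fusion** (optional, `DecoderConfig::use_mincut_fusion`): Integrates min-cut
//!    structural information (the stage is configured but not yet applied during decoding)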
14//!
15//! ## Quick Start
16//!
17//! ```rust,ignore
18//! use ruqu_neural_decoder::{NeuralDecoder, DecoderConfig};
19//!
20//! let config = DecoderConfig::default();
//! let mut decoder = NeuralDecoder::new(config)?;
22//!
23//! // Create syndrome from measurements
24//! let syndrome = vec![true, false, true, false, false];
25//! let correction = decoder.decode(&syndrome)?;
26//! ```
27
28#![deny(missing_docs)]
29#![warn(clippy::all)]
30#![allow(clippy::module_name_repetitions)]
31
32pub mod error;
33pub mod graph;
34pub mod gnn;
35pub mod mamba;
36pub mod fusion;
37
38// Re-exports
39pub use error::{NeuralDecoderError, Result};
40pub use graph::{DetectorGraph, GraphBuilder, Node, Edge};
41pub use gnn::{GNNEncoder, GNNConfig, AttentionLayer};
42pub use mamba::{MambaDecoder, MambaConfig, MambaState};
43pub use fusion::{FeatureFusion, FusionConfig};
44
45use ndarray::{Array1, Array2};
46use serde::{Deserialize, Serialize};
47
48/// Configuration for the neural decoder
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct DecoderConfig {
51    /// Code distance (determines graph size)
52    pub distance: usize,
53    /// Embedding dimension for node features
54    pub embed_dim: usize,
55    /// Hidden dimension for internal representations
56    pub hidden_dim: usize,
57    /// Number of GNN layers
58    pub num_gnn_layers: usize,
59    /// Number of attention heads
60    pub num_heads: usize,
61    /// Mamba state dimension
62    pub mamba_state_dim: usize,
63    /// Whether to use min-cut fusion
64    pub use_mincut_fusion: bool,
65    /// Dropout rate (0.0 to 1.0)
66    pub dropout: f32,
67}
68
69impl Default for DecoderConfig {
70    fn default() -> Self {
71        Self {
72            distance: 5,
73            embed_dim: 64,
74            hidden_dim: 128,
75            num_gnn_layers: 3,
76            num_heads: 4,
77            mamba_state_dim: 64,
78            use_mincut_fusion: false,
79            dropout: 0.1,
80        }
81    }
82}
83
84/// Correction output from the decoder
85#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct Correction {
87    /// X-type corrections (bit flips)
88    pub x_corrections: Vec<usize>,
89    /// Z-type corrections (phase flips)
90    pub z_corrections: Vec<usize>,
91    /// Confidence score (0.0 to 1.0)
92    pub confidence: f64,
93    /// Decode time in nanoseconds
94    pub decode_time_ns: u64,
95}
96
97/// Neural Quantum Error Decoder
98///
99/// Combines GNN-based syndrome encoding with Mamba state-space decoding.
100pub struct NeuralDecoder {
101    config: DecoderConfig,
102    gnn: GNNEncoder,
103    mamba: MambaDecoder,
104    fusion: Option<FeatureFusion>,
105}
106
107impl NeuralDecoder {
108    /// Create a new neural decoder with the given configuration
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if `embed_dim` is not divisible by `num_heads`.
113    pub fn new(config: DecoderConfig) -> Result<Self> {
114        let gnn_config = GNNConfig {
115            input_dim: 5, // Node features: [fired, row_norm, col_norm, node_type_x, node_type_z]
116            embed_dim: config.embed_dim,
117            hidden_dim: config.hidden_dim,
118            num_layers: config.num_gnn_layers,
119            num_heads: config.num_heads,
120            dropout: config.dropout,
121        };
122
123        let mamba_config = MambaConfig {
124            input_dim: config.hidden_dim,
125            state_dim: config.mamba_state_dim,
126            output_dim: config.distance * config.distance,
127        };
128
129        let fusion = if config.use_mincut_fusion {
130            let fusion_config = FusionConfig {
131                gnn_dim: config.hidden_dim,
132                mincut_dim: 16,
133                output_dim: config.hidden_dim,
134                gnn_weight: 0.5,
135                mincut_weight: 0.3,
136                boundary_weight: 0.2,
137                adaptive_weights: true,
138                temperature: 1.0,
139            };
140            FeatureFusion::new(fusion_config).ok()
141        } else {
142            None
143        };
144
145        Ok(Self {
146            config,
147            gnn: GNNEncoder::new(gnn_config)?,
148            mamba: MambaDecoder::new(mamba_config),
149            fusion,
150        })
151    }
152
153    /// Decode a syndrome bitmap and produce corrections
154    pub fn decode(&mut self, syndrome: &[bool]) -> Result<Correction> {
155        let start = std::time::Instant::now();
156
157        // Build detector graph from syndrome
158        let graph = GraphBuilder::from_surface_code(self.config.distance)
159            .with_syndrome(syndrome)?
160            .build()?;
161
162        // GNN encoding
163        let node_embeddings = self.gnn.encode(&graph)?;
164
165        // Optional: fuse with min-cut features (requires graph with edge weights and positions)
166        // For now, use the raw GNN embeddings. Full fusion requires:
167        // fusion.fuse(&node_embeddings, &mincut_features, &boundary_features, confidences)
168        let fused = node_embeddings;
169
170        // Mamba decoding
171        let output = self.mamba.decode(&fused)?;
172
173        // Convert output to corrections
174        let corrections = self.output_to_corrections(&output)?;
175
176        let elapsed = start.elapsed();
177
178        Ok(Correction {
179            x_corrections: corrections.0,
180            z_corrections: corrections.1,
181            confidence: corrections.2,
182            decode_time_ns: elapsed.as_nanos() as u64,
183        })
184    }
185
186    /// Convert model output to correction indices
187    fn output_to_corrections(&self, output: &Array1<f32>) -> Result<(Vec<usize>, Vec<usize>, f64)> {
188        let threshold = 0.5;
189        let mut x_corrections = Vec::new();
190
191        for (i, &val) in output.iter().enumerate() {
192            if val > threshold {
193                x_corrections.push(i);
194            }
195        }
196
197        // Compute confidence as average certainty (guard against empty output)
198        let confidence = if output.is_empty() {
199            0.0
200        } else {
201            output.iter()
202                .map(|&v| (v - 0.5).abs() * 2.0)
203                .sum::<f32>() / output.len() as f32
204        };
205
206        Ok((x_corrections, Vec::new(), confidence as f64))
207    }
208
209    /// Get the decoder configuration
210    #[must_use]
211    pub fn config(&self) -> &DecoderConfig {
212        &self.config
213    }
214
215    /// Reset the decoder state
216    pub fn reset(&mut self) {
217        self.mamba.reset();
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Small distance-3 configuration that keeps the decode tests fast.
    fn distance3_config() -> DecoderConfig {
        DecoderConfig {
            distance: 3,
            ..Default::default()
        }
    }

    #[test]
    fn test_decoder_config_default() {
        let config = DecoderConfig::default();
        assert_eq!(config.distance, 5);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.hidden_dim, 128);
        assert!((0.0..=1.0).contains(&config.dropout));
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = NeuralDecoder::new(DecoderConfig::default())
            .expect("default config should build a decoder");
        assert_eq!(decoder.config().distance, 5);
    }

    #[test]
    fn test_decoder_creation_with_mincut_fusion() {
        let config = DecoderConfig {
            use_mincut_fusion: true,
            ..Default::default()
        };
        let decoder = NeuralDecoder::new(config)
            .expect("default config with min-cut fusion should build a decoder");
        assert!(decoder.config().use_mincut_fusion);
    }

    #[test]
    fn test_correction_default() {
        let correction = Correction::default();
        assert!(correction.x_corrections.is_empty());
        assert!(correction.z_corrections.is_empty());
        assert_eq!(correction.confidence, 0.0);
        assert_eq!(correction.decode_time_ns, 0);
    }

    #[test]
    fn test_decoder_empty_syndrome() {
        let mut decoder =
            NeuralDecoder::new(distance3_config()).expect("distance-3 config should be valid");

        // Empty syndrome (all zeros)
        let syndrome = vec![false; 9];
        let result = decoder.decode(&syndrome);

        // Should succeed even with an empty syndrome
        assert!(
            result.is_ok(),
            "decoding an all-zero syndrome failed: {:?}",
            result
        );
    }

    #[test]
    fn test_decoder_reset_between_decodes() {
        let mut decoder =
            NeuralDecoder::new(distance3_config()).expect("distance-3 config should be valid");
        let syndrome = vec![false; 9];

        decoder
            .decode(&syndrome)
            .expect("first decode should succeed");
        decoder.reset();
        decoder
            .decode(&syndrome)
            .expect("decode after reset should succeed");
    }
}