ruqu_neural_decoder/
lib.rs1#![deny(missing_docs)]
29#![warn(clippy::all)]
30#![allow(clippy::module_name_repetitions)]
31
32pub mod error;
33pub mod graph;
34pub mod gnn;
35pub mod mamba;
36pub mod fusion;
37
38pub use error::{NeuralDecoderError, Result};
40pub use graph::{DetectorGraph, GraphBuilder, Node, Edge};
41pub use gnn::{GNNEncoder, GNNConfig, AttentionLayer};
42pub use mamba::{MambaDecoder, MambaConfig, MambaState};
43pub use fusion::{FeatureFusion, FusionConfig};
44
45use ndarray::{Array1, Array2};
46use serde::{Deserialize, Serialize};
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct DecoderConfig {
51 pub distance: usize,
53 pub embed_dim: usize,
55 pub hidden_dim: usize,
57 pub num_gnn_layers: usize,
59 pub num_heads: usize,
61 pub mamba_state_dim: usize,
63 pub use_mincut_fusion: bool,
65 pub dropout: f32,
67}
68
69impl Default for DecoderConfig {
70 fn default() -> Self {
71 Self {
72 distance: 5,
73 embed_dim: 64,
74 hidden_dim: 128,
75 num_gnn_layers: 3,
76 num_heads: 4,
77 mamba_state_dim: 64,
78 use_mincut_fusion: false,
79 dropout: 0.1,
80 }
81 }
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
86pub struct Correction {
87 pub x_corrections: Vec<usize>,
89 pub z_corrections: Vec<usize>,
91 pub confidence: f64,
93 pub decode_time_ns: u64,
95}
96
97pub struct NeuralDecoder {
101 config: DecoderConfig,
102 gnn: GNNEncoder,
103 mamba: MambaDecoder,
104 fusion: Option<FeatureFusion>,
105}
106
107impl NeuralDecoder {
108 pub fn new(config: DecoderConfig) -> Result<Self> {
114 let gnn_config = GNNConfig {
115 input_dim: 5, embed_dim: config.embed_dim,
117 hidden_dim: config.hidden_dim,
118 num_layers: config.num_gnn_layers,
119 num_heads: config.num_heads,
120 dropout: config.dropout,
121 };
122
123 let mamba_config = MambaConfig {
124 input_dim: config.hidden_dim,
125 state_dim: config.mamba_state_dim,
126 output_dim: config.distance * config.distance,
127 };
128
129 let fusion = if config.use_mincut_fusion {
130 let fusion_config = FusionConfig {
131 gnn_dim: config.hidden_dim,
132 mincut_dim: 16,
133 output_dim: config.hidden_dim,
134 gnn_weight: 0.5,
135 mincut_weight: 0.3,
136 boundary_weight: 0.2,
137 adaptive_weights: true,
138 temperature: 1.0,
139 };
140 FeatureFusion::new(fusion_config).ok()
141 } else {
142 None
143 };
144
145 Ok(Self {
146 config,
147 gnn: GNNEncoder::new(gnn_config)?,
148 mamba: MambaDecoder::new(mamba_config),
149 fusion,
150 })
151 }
152
153 pub fn decode(&mut self, syndrome: &[bool]) -> Result<Correction> {
155 let start = std::time::Instant::now();
156
157 let graph = GraphBuilder::from_surface_code(self.config.distance)
159 .with_syndrome(syndrome)?
160 .build()?;
161
162 let node_embeddings = self.gnn.encode(&graph)?;
164
165 let fused = node_embeddings;
169
170 let output = self.mamba.decode(&fused)?;
172
173 let corrections = self.output_to_corrections(&output)?;
175
176 let elapsed = start.elapsed();
177
178 Ok(Correction {
179 x_corrections: corrections.0,
180 z_corrections: corrections.1,
181 confidence: corrections.2,
182 decode_time_ns: elapsed.as_nanos() as u64,
183 })
184 }
185
186 fn output_to_corrections(&self, output: &Array1<f32>) -> Result<(Vec<usize>, Vec<usize>, f64)> {
188 let threshold = 0.5;
189 let mut x_corrections = Vec::new();
190
191 for (i, &val) in output.iter().enumerate() {
192 if val > threshold {
193 x_corrections.push(i);
194 }
195 }
196
197 let confidence = if output.is_empty() {
199 0.0
200 } else {
201 output.iter()
202 .map(|&v| (v - 0.5).abs() * 2.0)
203 .sum::<f32>() / output.len() as f32
204 };
205
206 Ok((x_corrections, Vec::new(), confidence as f64))
207 }
208
209 #[must_use]
211 pub fn config(&self) -> &DecoderConfig {
212 &self.config
213 }
214
215 pub fn reset(&mut self) {
217 self.mamba.reset();
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decoder_config_default() {
        let DecoderConfig {
            distance,
            embed_dim,
            hidden_dim,
            dropout,
            ..
        } = DecoderConfig::default();

        assert_eq!(distance, 5);
        assert_eq!(embed_dim, 64);
        assert_eq!(hidden_dim, 128);
        assert!((0.0..=1.0).contains(&dropout));
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = NeuralDecoder::new(DecoderConfig::default())
            .expect("default config should build a decoder");
        assert_eq!(decoder.config().distance, 5);
    }

    #[test]
    fn test_correction_default() {
        let Correction {
            x_corrections,
            z_corrections,
            confidence,
            ..
        } = Correction::default();

        assert!(x_corrections.is_empty());
        assert!(z_corrections.is_empty());
        assert_eq!(confidence, 0.0);
    }

    #[test]
    fn test_decoder_empty_syndrome() {
        let small_code = DecoderConfig {
            distance: 3,
            ..Default::default()
        };
        let mut decoder = NeuralDecoder::new(small_code).expect("distance-3 decoder should build");

        // No detector fired.
        let quiet_syndrome = [false; 9];
        let _correction = decoder
            .decode(&quiet_syndrome)
            .expect("decoding an all-zero syndrome should succeed");
    }
}