// ruqu_neural_decoder/gnn.rs

//! GNN Encoder for Syndrome Graphs
//!
//! Implements graph neural network layers for encoding detector graphs
//! into fixed-dimensional representations suitable for the Mamba decoder.

use crate::error::{NeuralDecoderError, Result};
use crate::graph::DetectorGraph;
use ndarray::{Array1, Array2, ArrayView1};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
12
/// Configuration for the GNN encoder
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Input feature dimension (length of each node's raw feature vector)
    pub input_dim: usize,
    /// Embedding dimension; must be divisible by `num_heads`
    /// (validated in `GNNEncoder::new` and `AttentionLayer::new`)
    pub embed_dim: usize,
    /// Hidden dimension of the encoder's final output projection
    pub hidden_dim: usize,
    /// Number of GNN message-passing layers
    pub num_layers: usize,
    /// Number of attention heads per layer
    pub num_heads: usize,
    /// Dropout rate
    /// NOTE(review): never read by the forward passes in this file —
    /// confirm it is consumed by training code elsewhere
    pub dropout: f32,
}
29
30impl Default for GNNConfig {
31    fn default() -> Self {
32        Self {
33            input_dim: 5,
34            embed_dim: 64,
35            hidden_dim: 128,
36            num_layers: 3,
37            num_heads: 4,
38            dropout: 0.1,
39        }
40    }
41}
42
/// Linear layer for projections
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    /// Weight matrix, shape `(output_dim, input_dim)`.
    weights: Array2<f32>,
    /// Bias vector, length `output_dim`.
    bias: Array1<f32>,
}
49
50impl Linear {
51    /// Create a new linear layer with Xavier initialization
52    pub fn new(input_dim: usize, output_dim: usize) -> Self {
53        let mut rng = rand::thread_rng();
54        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
55        let normal = Normal::new(0.0, scale as f64).unwrap();
56
57        let weights = Array2::from_shape_fn(
58            (output_dim, input_dim),
59            |_| normal.sample(&mut rng) as f32
60        );
61        let bias = Array1::zeros(output_dim);
62
63        Self { weights, bias }
64    }
65
66    /// Forward pass
67    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
68        let x = ArrayView1::from(input);
69        let output = self.weights.dot(&x) + &self.bias;
70        output.to_vec()
71    }
72
73    /// Get output dimension
74    pub fn output_dim(&self) -> usize {
75        self.weights.shape()[0]
76    }
77}
78
/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Learnable scale, initialized to ones.
    gamma: Array1<f32>,
    /// Learnable shift, initialized to zeros.
    beta: Array1<f32>,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
}
86
87impl LayerNorm {
88    /// Create new layer normalization
89    pub fn new(dim: usize, eps: f32) -> Self {
90        Self {
91            gamma: Array1::ones(dim),
92            beta: Array1::zeros(dim),
93            eps,
94        }
95    }
96
97    /// Forward pass
98    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
99        let x = ArrayView1::from(input);
100        let mean = x.mean().unwrap_or(0.0);
101        let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
102
103        let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
104        let output = &self.gamma * &normalized + &self.beta;
105        output.to_vec()
106    }
107}
108
/// Multi-head attention layer for graph attention
///
/// Combines Q/K/V projections, per-head scaled dot-product attention,
/// an output projection with a residual connection, and a final layer norm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionLayer {
    /// Number of attention heads; the embedding is split evenly across them.
    num_heads: usize,
    /// Per-head dimension: `embed_dim / num_heads`.
    head_dim: usize,
    /// Query projection (embed_dim -> embed_dim).
    q_linear: Linear,
    /// Key projection (embed_dim -> embed_dim).
    k_linear: Linear,
    /// Value projection (embed_dim -> embed_dim).
    v_linear: Linear,
    /// Output projection applied to the concatenated head outputs.
    out_linear: Linear,
    /// Layer norm applied after the residual connection.
    norm: LayerNorm,
}
120
impl AttentionLayer {
    /// Create a new attention layer.
    ///
    /// # Errors
    ///
    /// Returns an error if `embed_dim` is not divisible by `num_heads`
    /// (the embedding must split evenly across heads).
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(NeuralDecoderError::attention_heads(embed_dim, num_heads));
        }

        let head_dim = embed_dim / num_heads;

        Ok(Self {
            num_heads,
            head_dim,
            // All four projections map embed_dim -> embed_dim.
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
            norm: LayerNorm::new(embed_dim, 1e-5),
        })
    }

    /// Forward pass with attention.
    ///
    /// `query` is a single node's embedding; `keys`/`values` are its
    /// neighbors' embeddings. Returns the attended, residual-added,
    /// layer-normalized result (same length as `query`).
    ///
    /// With no neighbors the attention sub-layer is skipped entirely and
    /// only the layer norm is applied to `query` (no projections, no
    /// residual add).
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return self.norm.forward(query);
        }

        // Project query, keys, values
        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        // Split each projected vector into num_heads chunks of head_dim.
        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|kv| self.split_heads(kv)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|vv| self.split_heads(vv)).collect();

        // Run scaled dot-product attention independently per head.
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        // Concatenate heads back to embed_dim.
        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();

        // Output projection, then residual add with the *unprojected* query.
        let projected = self.out_linear.forward(&concat);
        let residual: Vec<f32> = query.iter().zip(projected.iter())
            .map(|(q, p)| q + p)
            .collect();

        self.norm.forward(&residual)
    }

    /// Split vector into `num_heads` contiguous chunks of `head_dim`.
    ///
    /// Panics if `x.len() < num_heads * head_dim`.
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        (0..self.num_heads)
            .map(|h| {
                let start = h * self.head_dim;
                let end = start + self.head_dim;
                x[start..end].to_vec()
            })
            .collect()
    }

    /// Scaled dot-product attention for a single head:
    /// `softmax(q·k / sqrt(head_dim))`-weighted sum of `values`.
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        // Defensive: callers already guard against empty neighbor lists,
        // but return the query unchanged if reached with no keys.
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        // Raw attention logits: one dot product per key.
        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Numerically stable softmax (subtract max before exp);
        // sum clamped to 1e-10 to avoid division by zero.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        // Attention-weighted sum of the value vectors.
        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }

    /// Get attention scores (for interpretation).
    ///
    /// Computes a single softmax over full-embedding dot products rather
    /// than per-head weights, so these scores are a summary, not the exact
    /// weights used in `forward`.
    /// NOTE(review): the extra `num_heads` factor in the scale looks like a
    /// heuristic to compensate for using full-dim dot products — confirm
    /// this is intentional.
    pub fn attention_scores(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() {
            return Vec::new();
        }

        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();

        let scale = (self.head_dim as f32).sqrt() * (self.num_heads as f32);

        let scores: Vec<f32> = k
            .iter()
            .map(|kv| {
                let dot: f32 = q.iter().zip(kv.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Softmax (stable form, clamped denominator) — sums to 1.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        exp_scores.iter().map(|&e| e / sum_exp).collect()
    }
}
255
/// GNN Encoder for syndrome graphs
///
/// Pipeline: per-node input projection -> `num_layers` attention-based
/// message-passing layers -> per-node output projection to `hidden_dim`.
#[derive(Debug, Clone)]
pub struct GNNEncoder {
    /// Configuration the encoder was built from.
    config: GNNConfig,
    /// Projects raw node features (`input_dim`) to `embed_dim`.
    input_projection: Linear,
    /// Stack of graph-attention message-passing layers.
    layers: Vec<AttentionLayer>,
    /// Projects final node embeddings (`embed_dim`) to `hidden_dim`.
    output_projection: Linear,
}
264
265impl GNNEncoder {
266    /// Create a new GNN encoder
267    ///
268    /// # Errors
269    ///
270    /// Returns an error if `embed_dim` is not divisible by `num_heads`.
271    pub fn new(config: GNNConfig) -> Result<Self> {
272        // Validate config before creating layers
273        if config.embed_dim % config.num_heads != 0 {
274            return Err(NeuralDecoderError::attention_heads(
275                config.embed_dim,
276                config.num_heads,
277            ));
278        }
279
280        let input_projection = Linear::new(config.input_dim, config.embed_dim);
281
282        let layers: Vec<AttentionLayer> = (0..config.num_layers)
283            .map(|_| AttentionLayer::new(config.embed_dim, config.num_heads))
284            .collect::<Result<Vec<_>>>()?;
285
286        let output_projection = Linear::new(config.embed_dim, config.hidden_dim);
287
288        Ok(Self {
289            config,
290            input_projection,
291            layers,
292            output_projection,
293        })
294    }
295
296    /// Encode a detector graph
297    pub fn encode(&self, graph: &DetectorGraph) -> Result<Array2<f32>> {
298        if graph.nodes.is_empty() {
299            return Err(NeuralDecoderError::EmptyGraph);
300        }
301
302        let num_nodes = graph.num_nodes();
303
304        // Project input features
305        let mut embeddings: Vec<Vec<f32>> = graph.nodes
306            .iter()
307            .map(|n| self.input_projection.forward(&n.features))
308            .collect();
309
310        // Message passing layers
311        for layer in &self.layers {
312            let mut new_embeddings = Vec::with_capacity(num_nodes);
313
314            for (node_id, embedding) in embeddings.iter().enumerate() {
315                // Get neighbor embeddings
316                let neighbor_ids = graph.neighbors(node_id)
317                    .map(|v| v.as_slice())
318                    .unwrap_or(&[]);
319
320                let neighbor_embeddings: Vec<Vec<f32>> = neighbor_ids
321                    .iter()
322                    .filter_map(|&nid| embeddings.get(nid).cloned())
323                    .collect();
324
325                // Apply attention
326                let updated = layer.forward(embedding, &neighbor_embeddings, &neighbor_embeddings);
327                new_embeddings.push(updated);
328            }
329
330            embeddings = new_embeddings;
331        }
332
333        // Output projection
334        let output_embeddings: Vec<Vec<f32>> = embeddings
335            .iter()
336            .map(|e| self.output_projection.forward(e))
337            .collect();
338
339        // Convert to array
340        let mut result = Array2::zeros((num_nodes, self.config.hidden_dim));
341        for (i, emb) in output_embeddings.iter().enumerate() {
342            for (j, &val) in emb.iter().enumerate() {
343                result[[i, j]] = val;
344            }
345        }
346
347        Ok(result)
348    }
349
350    /// Get node embeddings without output projection (for debugging)
351    pub fn get_intermediate_embeddings(&self, graph: &DetectorGraph, layer_idx: usize) -> Result<Vec<Vec<f32>>> {
352        if graph.nodes.is_empty() {
353            return Err(NeuralDecoderError::EmptyGraph);
354        }
355
356        let num_nodes = graph.num_nodes();
357        let layer_count = layer_idx.min(self.layers.len());
358
359        // Project input features
360        let mut embeddings: Vec<Vec<f32>> = graph.nodes
361            .iter()
362            .map(|n| self.input_projection.forward(&n.features))
363            .collect();
364
365        // Message passing layers up to layer_idx
366        for layer in self.layers.iter().take(layer_count) {
367            let mut new_embeddings = Vec::with_capacity(num_nodes);
368
369            for (node_id, embedding) in embeddings.iter().enumerate() {
370                let neighbor_ids = graph.neighbors(node_id)
371                    .map(|v| v.as_slice())
372                    .unwrap_or(&[]);
373
374                let neighbor_embeddings: Vec<Vec<f32>> = neighbor_ids
375                    .iter()
376                    .filter_map(|&nid| embeddings.get(nid).cloned())
377                    .collect();
378
379                let updated = layer.forward(embedding, &neighbor_embeddings, &neighbor_embeddings);
380                new_embeddings.push(updated);
381            }
382
383            embeddings = new_embeddings;
384        }
385
386        Ok(embeddings)
387    }
388
389    /// Get the configuration
390    pub fn config(&self) -> &GNNConfig {
391        &self.config
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;

    // Defaults should match the documented values in GNNConfig::default.
    #[test]
    fn test_gnn_config_default() {
        let config = GNNConfig::default();
        assert_eq!(config.input_dim, 5);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.num_heads, 4);
    }

    // Output length of a linear layer equals its output dimension.
    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(4, 8);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 8);
    }

    // With identity affine params, layer norm output has ~zero mean.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);
        assert_eq!(output.len(), 4);

        // Check zero mean (approximately)
        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    // embed_dim must divide evenly across heads.
    #[test]
    fn test_attention_layer_creation() {
        let layer = AttentionLayer::new(64, 4);
        assert!(layer.is_ok());

        // Invalid: embed_dim not divisible by num_heads
        let layer = AttentionLayer::new(64, 3);
        assert!(layer.is_err());
    }

    // Attention output length matches the query (embed_dim).
    #[test]
    fn test_attention_forward() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = layer.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    // No neighbors: forward degrades to layer-norm of the query only.
    #[test]
    fn test_attention_empty_neighbors() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys: Vec<Vec<f32>> = vec![];
        let values: Vec<Vec<f32>> = vec![];

        let output = layer.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    // Interpretation scores are a softmax: one per key, summing to 1.
    #[test]
    fn test_attention_scores() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];

        let scores = layer.attention_scores(&query, &keys);
        assert_eq!(scores.len(), 2);

        // Scores should sum to 1.0
        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_gnn_encoder_creation() {
        let config = GNNConfig::default();
        let encoder = GNNEncoder::new(config).unwrap();
        assert_eq!(encoder.config().num_layers, 3);
    }

    // Distance-3 surface code graph: 9 nodes, output shape (9, hidden_dim).
    #[test]
    fn test_gnn_encode_small_graph() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = GraphBuilder::from_surface_code(3)
            .build()
            .unwrap();

        let embeddings = encoder.encode(&graph).unwrap();
        assert_eq!(embeddings.shape(), &[9, 32]);
    }

    // Syndrome bits change node features but not the output shape.
    #[test]
    fn test_gnn_encode_with_syndrome() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let syndrome = vec![true, false, true, false, false, false, true, false, false];
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        let embeddings = encoder.encode(&graph).unwrap();
        assert_eq!(embeddings.shape(), &[9, 32]);
    }

    // Encoding an empty graph must fail, not panic.
    #[test]
    fn test_gnn_encode_empty_graph() {
        let config = GNNConfig::default();
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = crate::graph::DetectorGraph::new(3);
        let result = encoder.encode(&graph);
        assert!(result.is_err());
    }

    // Intermediate embeddings stay at embed_dim (no output projection).
    #[test]
    fn test_intermediate_embeddings() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 3,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = GraphBuilder::from_surface_code(3)
            .build()
            .unwrap();

        // Get embeddings at different layers
        let layer0 = encoder.get_intermediate_embeddings(&graph, 0).unwrap();
        let layer1 = encoder.get_intermediate_embeddings(&graph, 1).unwrap();
        let layer2 = encoder.get_intermediate_embeddings(&graph, 2).unwrap();

        assert_eq!(layer0.len(), 9);
        assert_eq!(layer1.len(), 9);
        assert_eq!(layer2.len(), 9);

        // Each embedding should have embed_dim dimensions
        assert_eq!(layer0[0].len(), 16);
        assert_eq!(layer1[0].len(), 16);
        assert_eq!(layer2[0].len(), 16);
    }

    #[test]
    fn test_gnn_deterministic_structure() {
        // Test that different syndromes produce different embeddings
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let syndrome1 = vec![true, false, false, false, false, false, false, false, false];
        let syndrome2 = vec![false, false, false, false, true, false, false, false, false];

        let graph1 = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome1)
            .unwrap()
            .build()
            .unwrap();

        let graph2 = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome2)
            .unwrap()
            .build()
            .unwrap();

        let emb1 = encoder.encode(&graph1).unwrap();
        let emb2 = encoder.encode(&graph2).unwrap();

        // Embeddings should differ
        let diff: f32 = (emb1.clone() - emb2.clone())
            .iter()
            .map(|x| x.abs())
            .sum();
        assert!(diff > 0.0);
    }
}