1use crate::error::{NeuralDecoderError, Result};
7use crate::graph::DetectorGraph;
8use ndarray::{Array1, Array2, ArrayView1};
9use rand::Rng;
10use rand_distr::{Distribution, Normal};
11use serde::{Deserialize, Serialize};
12
/// Hyperparameters for the graph attention (GNN) encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Dimension of the raw per-node input feature vectors.
    pub input_dim: usize,
    /// Dimension of node embeddings inside the attention layers.
    /// Must be evenly divisible by `num_heads`.
    pub embed_dim: usize,
    /// Dimension of the final per-node output embeddings.
    pub hidden_dim: usize,
    /// Number of stacked attention (message-passing) layers.
    pub num_layers: usize,
    /// Number of attention heads per layer.
    pub num_heads: usize,
    /// Dropout probability. NOTE(review): not applied anywhere in the
    /// forward passes visible in this file — confirm whether training
    /// code elsewhere consumes it.
    pub dropout: f32,
}
29
30impl Default for GNNConfig {
31 fn default() -> Self {
32 Self {
33 input_dim: 5,
34 embed_dim: 64,
35 hidden_dim: 128,
36 num_layers: 3,
37 num_heads: 4,
38 dropout: 0.1,
39 }
40 }
41}
42
/// A dense (fully connected) layer computing `y = W x + b`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Linear {
    // Weight matrix with shape (output_dim, input_dim).
    weights: Array2<f32>,
    // Bias vector with length output_dim.
    bias: Array1<f32>,
}
49
50impl Linear {
51 pub fn new(input_dim: usize, output_dim: usize) -> Self {
53 let mut rng = rand::thread_rng();
54 let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
55 let normal = Normal::new(0.0, scale as f64).unwrap();
56
57 let weights = Array2::from_shape_fn(
58 (output_dim, input_dim),
59 |_| normal.sample(&mut rng) as f32
60 );
61 let bias = Array1::zeros(output_dim);
62
63 Self { weights, bias }
64 }
65
66 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
68 let x = ArrayView1::from(input);
69 let output = self.weights.dot(&x) + &self.bias;
70 output.to_vec()
71 }
72
73 pub fn output_dim(&self) -> usize {
75 self.weights.shape()[0]
76 }
77}
78
/// Layer normalization with learnable scale (`gamma`) and shift (`beta`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    // Per-dimension scale, initialized to ones.
    gamma: Array1<f32>,
    // Per-dimension shift, initialized to zeros.
    beta: Array1<f32>,
    // Small constant added to the variance for numerical stability.
    eps: f32,
}
86
87impl LayerNorm {
88 pub fn new(dim: usize, eps: f32) -> Self {
90 Self {
91 gamma: Array1::ones(dim),
92 beta: Array1::zeros(dim),
93 eps,
94 }
95 }
96
97 pub fn forward(&self, input: &[f32]) -> Vec<f32> {
99 let x = ArrayView1::from(input);
100 let mean = x.mean().unwrap_or(0.0);
101 let variance = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / x.len() as f32;
102
103 let normalized = x.mapv(|v| (v - mean) / (variance + self.eps).sqrt());
104 let output = &self.gamma * &normalized + &self.beta;
105 output.to_vec()
106 }
107}
108
/// A single multi-head attention layer with residual connection and
/// layer normalization, used as one round of message passing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionLayer {
    // Number of attention heads.
    num_heads: usize,
    // Per-head dimension: embed_dim / num_heads.
    head_dim: usize,
    // Learned projections (each embed_dim -> embed_dim).
    q_linear: Linear,
    k_linear: Linear,
    v_linear: Linear,
    out_linear: Linear,
    // Applied after the residual connection.
    norm: LayerNorm,
}
120
impl AttentionLayer {
    /// Creates a multi-head attention layer over `embed_dim`-dimensional
    /// vectors split across `num_heads` heads.
    ///
    /// # Errors
    /// Fails when `embed_dim` is not evenly divisible by `num_heads`.
    pub fn new(embed_dim: usize, num_heads: usize) -> Result<Self> {
        if embed_dim % num_heads != 0 {
            return Err(NeuralDecoderError::attention_heads(embed_dim, num_heads));
        }

        let head_dim = embed_dim / num_heads;

        Ok(Self {
            num_heads,
            head_dim,
            q_linear: Linear::new(embed_dim, embed_dim),
            k_linear: Linear::new(embed_dim, embed_dim),
            v_linear: Linear::new(embed_dim, embed_dim),
            out_linear: Linear::new(embed_dim, embed_dim),
            norm: LayerNorm::new(embed_dim, 1e-5),
        })
    }

    /// Attends a single `query` vector over `keys`/`values`, then applies a
    /// residual connection with the raw query and layer normalization.
    ///
    /// With no neighbors (empty `keys` or `values`) the query is only
    /// normalized and returned.
    pub fn forward(&self, query: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() || values.is_empty() {
            return self.norm.forward(query);
        }

        // Project into query/key/value spaces.
        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();
        let v: Vec<Vec<f32>> = values.iter().map(|v| self.v_linear.forward(v)).collect();

        // Split each projected vector into num_heads contiguous chunks.
        let q_heads = self.split_heads(&q);
        let k_heads: Vec<Vec<Vec<f32>>> = k.iter().map(|kv| self.split_heads(kv)).collect();
        let v_heads: Vec<Vec<Vec<f32>>> = v.iter().map(|vv| self.split_heads(vv)).collect();

        // Run scaled dot-product attention independently per head.
        let mut head_outputs = Vec::new();
        for h in 0..self.num_heads {
            let q_h = &q_heads[h];
            let k_h: Vec<&Vec<f32>> = k_heads.iter().map(|heads| &heads[h]).collect();
            let v_h: Vec<&Vec<f32>> = v_heads.iter().map(|heads| &heads[h]).collect();

            let head_output = self.scaled_dot_product_attention(q_h, &k_h, &v_h);
            head_outputs.push(head_output);
        }

        // Concatenate head outputs back to embed_dim and project.
        let concat: Vec<f32> = head_outputs.into_iter().flatten().collect();

        let projected = self.out_linear.forward(&concat);
        // Residual connection: add the un-projected query element-wise.
        let residual: Vec<f32> = query.iter().zip(projected.iter())
            .map(|(q, p)| q + p)
            .collect();

        self.norm.forward(&residual)
    }

    // Slices an embed_dim vector into num_heads contiguous head_dim chunks.
    // Panics if `x` is shorter than num_heads * head_dim.
    fn split_heads(&self, x: &[f32]) -> Vec<Vec<f32>> {
        (0..self.num_heads)
            .map(|h| {
                let start = h * self.head_dim;
                let end = start + self.head_dim;
                x[start..end].to_vec()
            })
            .collect()
    }

    // Standard scaled dot-product attention for one head: softmax over
    // query-key dot products (scaled by sqrt(head_dim)), then a weighted
    // sum of the values. Returns the query unchanged when there are no keys.
    fn scaled_dot_product_attention(
        &self,
        query: &[f32],
        keys: &[&Vec<f32>],
        values: &[&Vec<f32>],
    ) -> Vec<f32> {
        if keys.is_empty() {
            return query.to_vec();
        }

        let scale = (self.head_dim as f32).sqrt();

        let scores: Vec<f32> = keys
            .iter()
            .map(|k| {
                let dot: f32 = query.iter().zip(k.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Numerically stable softmax: subtract the max before exponentiating;
        // the 1e-10 floor guards against division by zero.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        let weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();

        // Attention-weighted sum of values.
        let mut output = vec![0.0; self.head_dim];
        for (weight, value) in weights.iter().zip(values.iter()) {
            for (out, &val) in output.iter_mut().zip(value.iter()) {
                *out += weight * val;
            }
        }

        output
    }

    /// Returns the softmax attention weights of `query` over `keys`,
    /// aggregated across heads (for inspection/visualization).
    ///
    /// The full embed_dim dot product equals the sum of the per-head dot
    /// products, so the extra `num_heads` factor in the scale turns each
    /// score into an average per-head score before the softmax.
    pub fn attention_scores(&self, query: &[f32], keys: &[Vec<f32>]) -> Vec<f32> {
        if keys.is_empty() {
            return Vec::new();
        }

        let q = self.q_linear.forward(query);
        let k: Vec<Vec<f32>> = keys.iter().map(|k| self.k_linear.forward(k)).collect();

        let scale = (self.head_dim as f32).sqrt() * (self.num_heads as f32);

        let scores: Vec<f32> = k
            .iter()
            .map(|kv| {
                let dot: f32 = q.iter().zip(kv.iter()).map(|(q, k)| q * k).sum();
                dot / scale
            })
            .collect();

        // Stable softmax, as in scaled_dot_product_attention.
        let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum::<f32>().max(1e-10);
        exp_scores.iter().map(|&e| e / sum_exp).collect()
    }
}
255
/// Graph neural network encoder: projects node features into an embedding
/// space, runs stacked attention layers over each node's neighborhood, and
/// projects the result to the output (hidden) dimension.
#[derive(Debug, Clone)]
pub struct GNNEncoder {
    // Configuration this encoder was built from.
    config: GNNConfig,
    // Maps input_dim node features to embed_dim.
    input_projection: Linear,
    // Stacked message-passing layers (config.num_layers of them).
    layers: Vec<AttentionLayer>,
    // Maps embed_dim to hidden_dim for the final output.
    output_projection: Linear,
}
264
265impl GNNEncoder {
266 pub fn new(config: GNNConfig) -> Result<Self> {
272 if config.embed_dim % config.num_heads != 0 {
274 return Err(NeuralDecoderError::attention_heads(
275 config.embed_dim,
276 config.num_heads,
277 ));
278 }
279
280 let input_projection = Linear::new(config.input_dim, config.embed_dim);
281
282 let layers: Vec<AttentionLayer> = (0..config.num_layers)
283 .map(|_| AttentionLayer::new(config.embed_dim, config.num_heads))
284 .collect::<Result<Vec<_>>>()?;
285
286 let output_projection = Linear::new(config.embed_dim, config.hidden_dim);
287
288 Ok(Self {
289 config,
290 input_projection,
291 layers,
292 output_projection,
293 })
294 }
295
296 pub fn encode(&self, graph: &DetectorGraph) -> Result<Array2<f32>> {
298 if graph.nodes.is_empty() {
299 return Err(NeuralDecoderError::EmptyGraph);
300 }
301
302 let num_nodes = graph.num_nodes();
303
304 let mut embeddings: Vec<Vec<f32>> = graph.nodes
306 .iter()
307 .map(|n| self.input_projection.forward(&n.features))
308 .collect();
309
310 for layer in &self.layers {
312 let mut new_embeddings = Vec::with_capacity(num_nodes);
313
314 for (node_id, embedding) in embeddings.iter().enumerate() {
315 let neighbor_ids = graph.neighbors(node_id)
317 .map(|v| v.as_slice())
318 .unwrap_or(&[]);
319
320 let neighbor_embeddings: Vec<Vec<f32>> = neighbor_ids
321 .iter()
322 .filter_map(|&nid| embeddings.get(nid).cloned())
323 .collect();
324
325 let updated = layer.forward(embedding, &neighbor_embeddings, &neighbor_embeddings);
327 new_embeddings.push(updated);
328 }
329
330 embeddings = new_embeddings;
331 }
332
333 let output_embeddings: Vec<Vec<f32>> = embeddings
335 .iter()
336 .map(|e| self.output_projection.forward(e))
337 .collect();
338
339 let mut result = Array2::zeros((num_nodes, self.config.hidden_dim));
341 for (i, emb) in output_embeddings.iter().enumerate() {
342 for (j, &val) in emb.iter().enumerate() {
343 result[[i, j]] = val;
344 }
345 }
346
347 Ok(result)
348 }
349
350 pub fn get_intermediate_embeddings(&self, graph: &DetectorGraph, layer_idx: usize) -> Result<Vec<Vec<f32>>> {
352 if graph.nodes.is_empty() {
353 return Err(NeuralDecoderError::EmptyGraph);
354 }
355
356 let num_nodes = graph.num_nodes();
357 let layer_count = layer_idx.min(self.layers.len());
358
359 let mut embeddings: Vec<Vec<f32>> = graph.nodes
361 .iter()
362 .map(|n| self.input_projection.forward(&n.features))
363 .collect();
364
365 for layer in self.layers.iter().take(layer_count) {
367 let mut new_embeddings = Vec::with_capacity(num_nodes);
368
369 for (node_id, embedding) in embeddings.iter().enumerate() {
370 let neighbor_ids = graph.neighbors(node_id)
371 .map(|v| v.as_slice())
372 .unwrap_or(&[]);
373
374 let neighbor_embeddings: Vec<Vec<f32>> = neighbor_ids
375 .iter()
376 .filter_map(|&nid| embeddings.get(nid).cloned())
377 .collect();
378
379 let updated = layer.forward(embedding, &neighbor_embeddings, &neighbor_embeddings);
380 new_embeddings.push(updated);
381 }
382
383 embeddings = new_embeddings;
384 }
385
386 Ok(embeddings)
387 }
388
389 pub fn config(&self) -> &GNNConfig {
391 &self.config
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;

    // Default config carries the documented hyperparameters.
    #[test]
    fn test_gnn_config_default() {
        let config = GNNConfig::default();
        assert_eq!(config.input_dim, 5);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.num_heads, 4);
    }

    // A 4->8 linear layer yields 8 outputs.
    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(4, 8);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = linear.forward(&input);
        assert_eq!(output.len(), 8);
    }

    // LayerNorm preserves length and (with identity gamma/beta)
    // produces a zero-mean output.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-5);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);
        assert_eq!(output.len(), 4);

        let mean: f32 = output.iter().sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    // Head count must evenly divide the embedding dimension.
    #[test]
    fn test_attention_layer_creation() {
        let layer = AttentionLayer::new(64, 4);
        assert!(layer.is_ok());

        let layer = AttentionLayer::new(64, 3);
        assert!(layer.is_err());
    }

    // Forward pass preserves the embedding dimension.
    #[test]
    fn test_attention_forward() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];
        let values = vec![vec![0.2; 8], vec![0.8; 8]];

        let output = layer.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    // With no neighbors the layer still returns an embed_dim vector
    // (the normalized query).
    #[test]
    fn test_attention_empty_neighbors() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys: Vec<Vec<f32>> = vec![];
        let values: Vec<Vec<f32>> = vec![];

        let output = layer.forward(&query, &keys, &values);
        assert_eq!(output.len(), 8);
    }

    // Attention scores form a probability distribution (sum to 1).
    #[test]
    fn test_attention_scores() {
        let layer = AttentionLayer::new(8, 2).unwrap();
        let query = vec![0.5; 8];
        let keys = vec![vec![0.3; 8], vec![0.7; 8]];

        let scores = layer.attention_scores(&query, &keys);
        assert_eq!(scores.len(), 2);

        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_gnn_encoder_creation() {
        let config = GNNConfig::default();
        let encoder = GNNEncoder::new(config).unwrap();
        assert_eq!(encoder.config().num_layers, 3);
    }

    // Distance-3 surface code graph: 9 detector nodes, each encoded to
    // hidden_dim = 32.
    #[test]
    fn test_gnn_encode_small_graph() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = GraphBuilder::from_surface_code(3)
            .build()
            .unwrap();

        let embeddings = encoder.encode(&graph).unwrap();
        assert_eq!(embeddings.shape(), &[9, 32]);
    }

    // Attaching a syndrome changes node features but not the output shape.
    #[test]
    fn test_gnn_encode_with_syndrome() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let syndrome = vec![true, false, true, false, false, false, true, false, false];
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        let embeddings = encoder.encode(&graph).unwrap();
        assert_eq!(embeddings.shape(), &[9, 32]);
    }

    // Encoding a graph with no nodes must fail (EmptyGraph).
    #[test]
    fn test_gnn_encode_empty_graph() {
        let config = GNNConfig::default();
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = crate::graph::DetectorGraph::new(3);
        let result = encoder.encode(&graph);
        assert!(result.is_err());
    }

    // Intermediate embeddings stay at embed_dim (16) for every depth,
    // one vector per node.
    #[test]
    fn test_intermediate_embeddings() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 3,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let graph = GraphBuilder::from_surface_code(3)
            .build()
            .unwrap();

        let layer0 = encoder.get_intermediate_embeddings(&graph, 0).unwrap();
        let layer1 = encoder.get_intermediate_embeddings(&graph, 1).unwrap();
        let layer2 = encoder.get_intermediate_embeddings(&graph, 2).unwrap();

        assert_eq!(layer0.len(), 9);
        assert_eq!(layer1.len(), 9);
        assert_eq!(layer2.len(), 9);

        assert_eq!(layer0[0].len(), 16);
        assert_eq!(layer1[0].len(), 16);
        assert_eq!(layer2[0].len(), 16);
    }

    // Different syndromes must produce different embeddings: the encoder
    // actually uses the node features, not just the graph structure.
    #[test]
    fn test_gnn_deterministic_structure() {
        let config = GNNConfig {
            input_dim: 5,
            embed_dim: 16,
            hidden_dim: 32,
            num_layers: 2,
            num_heads: 4,
            dropout: 0.0,
        };
        let encoder = GNNEncoder::new(config).unwrap();

        let syndrome1 = vec![true, false, false, false, false, false, false, false, false];
        let syndrome2 = vec![false, false, false, false, true, false, false, false, false];

        let graph1 = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome1)
            .unwrap()
            .build()
            .unwrap();

        let graph2 = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome2)
            .unwrap()
            .build()
            .unwrap();

        let emb1 = encoder.encode(&graph1).unwrap();
        let emb2 = encoder.encode(&graph2).unwrap();

        // L1 distance between the two embedding matrices must be nonzero.
        let diff: f32 = (emb1.clone() - emb2.clone())
            .iter()
            .map(|x| x.abs())
            .sum();
        assert!(diff > 0.0);
    }
}