1use crate::error::{NeuralDecoderError, Result};
13use ndarray::{Array1, Array2};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
/// A single vertex of a [`DetectorGraph`].
///
/// A node carries its grid position, whether it fired, its type, and a
/// free-form feature vector that [`DetectorGraph::node_features`] stacks
/// into a matrix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Identifier of the node. It keys the graph's adjacency map and is the
    /// value that [`Edge::from`] / [`Edge::to`] refer to.
    pub id: usize,
    /// Row coordinate of the node.
    pub row: usize,
    /// Column coordinate of the node.
    pub col: usize,
    /// Whether this node fired. [`DetectorGraph::add_node`] counts fired
    /// nodes into `DetectorGraph::num_fired`.
    pub fired: bool,
    /// Category of the node.
    pub node_type: NodeType,
    /// Per-node feature vector; one row of the matrix returned by
    /// [`DetectorGraph::node_features`].
    pub features: Vec<f32>,
}
33
/// Category of a [`Node`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeType {
    /// X-type stabilizer node.
    XStabilizer,
    /// Z-type stabilizer node.
    ZStabilizer,
    /// Boundary node.
    ///
    /// NOTE(review): `GraphBuilder::build` in this file never creates a
    /// boundary node; presumably it is used by other construction paths —
    /// confirm.
    Boundary,
}
44
/// A weighted connection between two nodes of a [`DetectorGraph`].
///
/// [`DetectorGraph::add_edge`] registers the connection in both directions,
/// so `from` / `to` do not imply a direction for neighbor lookups.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Id of the first endpoint.
    pub from: usize,
    /// Id of the second endpoint.
    pub to: usize,
    /// Weight of the edge.
    pub weight: f32,
    /// Category of the edge.
    pub edge_type: EdgeType,
}
57
/// Category of an [`Edge`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EdgeType {
    /// Edge between two nodes in the same row and adjacent columns (this is
    /// how `GraphBuilder::build` uses it).
    Horizontal,
    /// Edge between two nodes in the same column and adjacent rows (this is
    /// how `GraphBuilder::build` uses it).
    Vertical,
    /// NOTE(review): not created anywhere in this file; presumably links the
    /// same detector across measurement rounds — confirm.
    Temporal,
    /// NOTE(review): not created anywhere in this file; presumably links a
    /// node to a boundary node — confirm.
    Boundary,
}
70
/// Graph of [`Node`]s connected by weighted [`Edge`]s.
///
/// Nodes and edges are stored in insertion order. A private adjacency map is
/// kept in sync by [`DetectorGraph::add_node`] and
/// [`DetectorGraph::add_edge`]; pushing into the public vectors directly
/// bypasses that bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectorGraph {
    /// All nodes, in insertion order.
    pub nodes: Vec<Node>,
    /// All edges, in insertion order.
    pub edges: Vec<Edge>,
    // Node id -> ids of the nodes it shares an edge with (both directions
    // are recorded for every edge).
    adjacency: HashMap<usize, Vec<usize>>,
    /// Distance value the graph was created with; only stored by
    /// [`DetectorGraph::new`].
    pub distance: usize,
    /// Number of fired nodes added through [`DetectorGraph::add_node`].
    pub num_fired: usize,
}
85
86impl DetectorGraph {
87 pub fn new(distance: usize) -> Self {
89 Self {
90 nodes: Vec::new(),
91 edges: Vec::new(),
92 adjacency: HashMap::new(),
93 distance,
94 num_fired: 0,
95 }
96 }
97
98 pub fn add_node(&mut self, node: Node) {
100 let id = node.id;
101 if node.fired {
102 self.num_fired += 1;
103 }
104 self.nodes.push(node);
105 self.adjacency.entry(id).or_default();
106 }
107
108 pub fn add_edge(&mut self, edge: Edge) {
110 self.adjacency.entry(edge.from).or_default().push(edge.to);
111 self.adjacency.entry(edge.to).or_default().push(edge.from);
112 self.edges.push(edge);
113 }
114
115 pub fn neighbors(&self, node_id: usize) -> Option<&Vec<usize>> {
117 self.adjacency.get(&node_id)
118 }
119
120 pub fn node_features(&self) -> Array2<f32> {
122 if self.nodes.is_empty() {
123 return Array2::zeros((0, 1));
124 }
125
126 let feature_dim = self.nodes[0].features.len();
127 let mut features = Array2::zeros((self.nodes.len(), feature_dim));
128
129 for (i, node) in self.nodes.iter().enumerate() {
130 for (j, &f) in node.features.iter().enumerate() {
131 features[[i, j]] = f;
132 }
133 }
134
135 features
136 }
137
138 pub fn adjacency_matrix(&self) -> Array2<f32> {
140 let n = self.nodes.len();
141 let mut adj = Array2::zeros((n, n));
142
143 for edge in &self.edges {
144 adj[[edge.from, edge.to]] = edge.weight;
145 adj[[edge.to, edge.from]] = edge.weight;
146 }
147
148 adj
149 }
150
151 pub fn edge_weights(&self) -> Array1<f32> {
153 Array1::from_iter(self.edges.iter().map(|e| e.weight))
154 }
155
156 pub fn fired_indices(&self) -> Vec<usize> {
158 self.nodes
159 .iter()
160 .filter(|n| n.fired)
161 .map(|n| n.id)
162 .collect()
163 }
164
165 pub fn validate(&self) -> Result<()> {
167 if self.nodes.is_empty() {
168 return Err(NeuralDecoderError::EmptyGraph);
169 }
170
171 for edge in &self.edges {
173 if edge.from >= self.nodes.len() || edge.to >= self.nodes.len() {
174 return Err(NeuralDecoderError::InvalidDetector(
175 edge.from.max(edge.to)
176 ));
177 }
178 }
179
180 Ok(())
181 }
182
183 pub fn num_nodes(&self) -> usize {
185 self.nodes.len()
186 }
187
188 pub fn num_edges(&self) -> usize {
190 self.edges.len()
191 }
192}
193
194pub struct GraphBuilder {
196 distance: usize,
197 syndrome: Option<Vec<bool>>,
198 node_type_pattern: NodeTypePattern,
199 error_rate: f64,
200}
201
/// Rule that `GraphBuilder::build` uses to assign a [`NodeType`] to each
/// grid cell.
#[derive(Debug, Clone, Copy)]
pub enum NodeTypePattern {
    /// Cells whose `row + col` is even become `XStabilizer`, the others
    /// `ZStabilizer`.
    Checkerboard,
    /// Every cell becomes `XStabilizer`.
    AllX,
    /// Every cell becomes `ZStabilizer`.
    AllZ,
}
212
213impl GraphBuilder {
214 pub fn from_surface_code(distance: usize) -> Self {
216 Self {
217 distance,
218 syndrome: None,
219 node_type_pattern: NodeTypePattern::Checkerboard,
220 error_rate: 0.001,
221 }
222 }
223
224 pub fn with_syndrome(mut self, syndrome: &[bool]) -> Result<Self> {
226 let expected = self.distance * self.distance;
227 if syndrome.len() != expected {
228 return Err(NeuralDecoderError::syndrome_dim(
229 self.distance,
230 syndrome.len(),
231 1,
232 ));
233 }
234 self.syndrome = Some(syndrome.to_vec());
235 Ok(self)
236 }
237
238 pub fn with_node_pattern(mut self, pattern: NodeTypePattern) -> Self {
240 self.node_type_pattern = pattern;
241 self
242 }
243
244 pub fn with_error_rate(mut self, rate: f64) -> Self {
246 self.error_rate = rate;
247 self
248 }
249
250 pub fn build(self) -> Result<DetectorGraph> {
252 let d = self.distance;
253 let mut graph = DetectorGraph::new(d);
254
255 let syndrome = self.syndrome.unwrap_or_else(|| vec![false; d * d]);
257
258 for row in 0..d {
260 for col in 0..d {
261 let id = row * d + col;
262 let fired = syndrome.get(id).copied().unwrap_or(false);
263
264 let node_type = match self.node_type_pattern {
265 NodeTypePattern::Checkerboard => {
266 if (row + col) % 2 == 0 {
267 NodeType::XStabilizer
268 } else {
269 NodeType::ZStabilizer
270 }
271 }
272 NodeTypePattern::AllX => NodeType::XStabilizer,
273 NodeTypePattern::AllZ => NodeType::ZStabilizer,
274 };
275
276 let features = vec![
278 if fired { 1.0 } else { 0.0 },
279 row as f32 / d as f32,
280 col as f32 / d as f32,
281 if node_type == NodeType::XStabilizer { 1.0 } else { 0.0 },
282 if node_type == NodeType::ZStabilizer { 1.0 } else { 0.0 },
283 ];
284
285 graph.add_node(Node {
286 id,
287 row,
288 col,
289 fired,
290 node_type,
291 features,
292 });
293 }
294 }
295
296 let weight = (-self.error_rate.ln()) as f32;
298
299 for row in 0..d {
300 for col in 0..d {
301 let id = row * d + col;
302
303 if col + 1 < d {
305 let neighbor = row * d + (col + 1);
306 graph.add_edge(Edge {
307 from: id,
308 to: neighbor,
309 weight,
310 edge_type: EdgeType::Horizontal,
311 });
312 }
313
314 if row + 1 < d {
316 let neighbor = (row + 1) * d + col;
317 graph.add_edge(Edge {
318 from: id,
319 to: neighbor,
320 weight,
321 edge_type: EdgeType::Vertical,
322 });
323 }
324 }
325 }
326
327 graph.validate()?;
328 Ok(graph)
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph of the given distance with every builder default.
    fn plain_graph(distance: usize) -> DetectorGraph {
        GraphBuilder::from_surface_code(distance).build().unwrap()
    }

    /// Builds a graph of the given distance from an explicit syndrome.
    fn graph_from_syndrome(distance: usize, syndrome: &[bool]) -> DetectorGraph {
        GraphBuilder::from_surface_code(distance)
            .with_syndrome(syndrome)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Builds a distance-3 graph using the given node-type pattern.
    fn patterned_graph(pattern: NodeTypePattern) -> DetectorGraph {
        GraphBuilder::from_surface_code(3)
            .with_node_pattern(pattern)
            .build()
            .unwrap()
    }

    #[test]
    fn test_node_creation() {
        let fired_node = Node {
            features: vec![1.0],
            node_type: NodeType::XStabilizer,
            fired: true,
            col: 0,
            row: 0,
            id: 0,
        };
        assert_eq!(fired_node.id, 0);
        assert!(fired_node.fired);
    }

    #[test]
    fn test_edge_creation() {
        let link = Edge {
            edge_type: EdgeType::Horizontal,
            weight: 1.5,
            to: 1,
            from: 0,
        };
        assert_eq!(link.from, 0);
        assert_eq!(link.to, 1);
    }

    #[test]
    fn test_graph_construction_d3() {
        let g = plain_graph(3);

        // 3 x 3 cells.
        assert_eq!(g.num_nodes(), 9);
        // 3 rows * 2 horizontal links + 3 columns * 2 vertical links.
        assert_eq!(g.num_edges(), 12);
    }

    #[test]
    fn test_graph_construction_d5() {
        let g = plain_graph(5);

        // 5 x 5 cells.
        assert_eq!(g.num_nodes(), 25);
        // 5 rows * 4 horizontal links + 5 columns * 4 vertical links.
        assert_eq!(g.num_edges(), 40);
    }

    #[test]
    fn test_graph_with_syndrome() {
        let bits = [true, false, true, false, true, false, false, false, true];
        let g = graph_from_syndrome(3, &bits);

        assert_eq!(g.num_fired, 4);
        assert_eq!(g.fired_indices(), vec![0, 2, 4, 8]);
    }

    #[test]
    fn test_graph_syndrome_dimension_mismatch() {
        // Three flags cannot describe a 3 x 3 grid.
        let too_short = [true, false, true];
        let outcome = GraphBuilder::from_surface_code(3).with_syndrome(&too_short);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_graph_adjacency() {
        let g = plain_graph(3);

        // Node 0 is a corner (2 neighbours); node 4 is the centre (4).
        let expected_degrees: [(usize, usize); 2] = [(0, 2), (4, 4)];
        for &(node_id, degree) in expected_degrees.iter() {
            let adjacent = g.neighbors(node_id).unwrap();
            assert_eq!(adjacent.len(), degree);
        }
    }

    #[test]
    fn test_node_features_matrix() {
        let matrix = plain_graph(3).node_features();
        assert_eq!(matrix.shape(), &[9, 5]);
    }

    #[test]
    fn test_adjacency_matrix() {
        let matrix = plain_graph(3).adjacency_matrix();
        assert_eq!(matrix.shape(), &[9, 9]);

        // The matrix must equal its own transpose.
        for (r, c) in (0..9).flat_map(|r| (0..9).map(move |c| (r, c))) {
            assert_eq!(matrix[[r, c]], matrix[[c, r]]);
        }
    }

    #[test]
    fn test_edge_weights() {
        let g = GraphBuilder::from_surface_code(3)
            .with_error_rate(0.01)
            .build()
            .unwrap();

        let all_weights = g.edge_weights();
        assert_eq!(all_weights.len(), 12);

        // -ln(0.01) is positive, so every weight has to be as well.
        assert!(all_weights.iter().all(|&w| w > 0.0));
    }

    #[test]
    fn test_node_type_pattern_checkerboard() {
        let g = patterned_graph(NodeTypePattern::Checkerboard);

        for cell in g.nodes.iter() {
            let on_even_diagonal = (cell.row + cell.col) % 2 == 0;
            let wanted = match on_even_diagonal {
                true => NodeType::XStabilizer,
                false => NodeType::ZStabilizer,
            };
            assert_eq!(cell.node_type, wanted);
        }
    }

    #[test]
    fn test_node_type_pattern_all_x() {
        let g = patterned_graph(NodeTypePattern::AllX);

        for cell in g.nodes.iter() {
            assert_eq!(cell.node_type, NodeType::XStabilizer);
        }
    }

    #[test]
    fn test_empty_syndrome() {
        let quiet = [false; 9];
        let g = graph_from_syndrome(3, &quiet);

        assert_eq!(g.num_fired, 0);
        assert!(g.fired_indices().is_empty());
    }

    #[test]
    fn test_all_fired_syndrome() {
        let noisy = [true; 9];
        let g = graph_from_syndrome(3, &noisy);

        assert_eq!(g.num_fired, 9);
        assert_eq!(g.fired_indices().len(), 9);
    }

    #[test]
    fn test_graph_validation() {
        let g = plain_graph(3);
        assert!(g.validate().is_ok());
    }

    #[test]
    fn test_empty_graph_validation() {
        // A graph with no nodes is rejected regardless of its distance.
        let empty = DetectorGraph::new(3);
        assert!(empty.validate().is_err());
    }
}