// ruqu_neural_decoder/fusion.rs

//! Feature Fusion for Neural Quantum Error Decoding
//!
//! This module fuses multiple sources of information for error prediction:
//! - GNN embeddings from the graph attention encoder
//! - Min-cut features from graph algorithms
//! - Boundary proximity weighting
//! - Coherence confidence scaling
//!
//! ## Fusion Strategy
//!
//! The fusion combines neural and algorithmic features:
//!
//! 1. **GNN Features**: Rich learned representations of syndrome patterns
//! 2. **Min-Cut Features**: Graph-theoretic error chain likelihood
//! 3. **Boundary Features**: Distance-based corrections for edge effects
//! 4. **Confidence Weighting**: Adaptive fusion based on prediction certainty
18use crate::error::{NeuralDecoderError, Result};
19use ndarray::{Array1, Array2, Axis};
20use ruvector_mincut::{DynamicGraph, MinCutBuilder, Weight};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
/// Configuration for feature fusion.
///
/// The three fusion weights are expected to sum to 1.0 and `temperature`
/// must be strictly positive; both invariants are enforced by
/// [`FusionConfig::validate`], which `FeatureFusion::new` calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// GNN embedding dimension
    pub gnn_dim: usize,
    /// MinCut feature dimension
    pub mincut_dim: usize,
    /// Output dimension after fusion
    pub output_dim: usize,
    /// Weight for GNN features (0-1)
    pub gnn_weight: f32,
    /// Weight for MinCut features (0-1)
    pub mincut_weight: f32,
    /// Weight for boundary features (0-1)
    pub boundary_weight: f32,
    /// Enable adaptive weighting based on confidence
    pub adaptive_weights: bool,
    /// Temperature for softmax confidence scaling
    pub temperature: f32,
}
44
45impl Default for FusionConfig {
46    fn default() -> Self {
47        Self {
48            gnn_dim: 64,
49            mincut_dim: 16,
50            output_dim: 32,
51            gnn_weight: 0.5,
52            mincut_weight: 0.3,
53            boundary_weight: 0.2,
54            adaptive_weights: true,
55            temperature: 1.0,
56        }
57    }
58}
59
60impl FusionConfig {
61    /// Validate configuration
62    pub fn validate(&self) -> Result<()> {
63        let total_weight = self.gnn_weight + self.mincut_weight + self.boundary_weight;
64        if (total_weight - 1.0).abs() > 1e-6 {
65            return Err(NeuralDecoderError::ConfigError(format!(
66                "Fusion weights must sum to 1.0, got {}",
67                total_weight
68            )));
69        }
70        if self.temperature <= 0.0 {
71            return Err(NeuralDecoderError::ConfigError(
72                "Temperature must be positive".to_string(),
73            ));
74        }
75        Ok(())
76    }
77}
78
/// Min-cut features extracted from detector graph
#[derive(Debug, Clone)]
pub struct MinCutFeatures {
    /// Global minimum cut value
    /// (set to `f64::INFINITY` when the graph has no edges).
    pub global_mincut: f64,
    /// Local cut values for each node: the sum of incident edge weights
    /// (weighted degree), used as a cheap local-cut proxy.
    pub local_cuts: Vec<f64>,
    /// Edge participation in min-cut, keyed as `(min_node, max_node)`.
    pub edge_in_cut: HashMap<(usize, usize), bool>,
    /// Cut-based error chain probability per node, in `[0, 1]`
    /// (1 - local_cut / max_local_cut).
    pub error_chain_prob: Vec<f64>,
}
91
impl MinCutFeatures {
    /// Extract min-cut features from a detector graph
    ///
    /// # Arguments
    /// * `adjacency` - Adjacency list of detector graph; node ids must be
    ///   in `0..num_nodes` (they index `local_cuts` unchecked below)
    /// * `edge_weights` - Edge weights (error probabilities); looked up in
    ///   both key orders and defaulting to 1.0 when absent
    /// * `num_nodes` - Number of detector nodes
    ///
    /// # Errors
    /// Returns [`NeuralDecoderError::EmptyGraph`] when `num_nodes == 0`,
    /// or [`NeuralDecoderError::MinCutError`] if the min-cut builder fails.
    pub fn extract(
        adjacency: &HashMap<usize, Vec<usize>>,
        edge_weights: &HashMap<(usize, usize), f32>,
        num_nodes: usize,
    ) -> Result<Self> {
        if num_nodes == 0 {
            return Err(NeuralDecoderError::EmptyGraph);
        }

        // Build graph for min-cut computation.
        // Each undirected edge is inserted exactly once (node < neighbor).
        let graph = DynamicGraph::new();

        for (&node, neighbors) in adjacency {
            for &neighbor in neighbors {
                if node < neighbor {
                    let weight = edge_weights
                        .get(&(node, neighbor))
                        .or_else(|| edge_weights.get(&(neighbor, node)))
                        .copied()
                        .unwrap_or(1.0);
                    // Use 1/weight as edge capacity (higher prob = lower capacity)
                    // Precedence note: `as` binds tighter than `/`, so this is
                    // 1.0 / ((weight + 1e-6) as Weight); the 1e-6 avoids division
                    // by zero. Insert result is deliberately ignored.
                    let _ = graph.insert_edge(node as u64, neighbor as u64, 1.0 / (weight + 1e-6) as Weight);
                }
            }
        }

        // Compute global min-cut.
        // NOTE(review): `graph` is populated above but never passed to the
        // builder, so `min_cut_value()` does not obviously see the inserted
        // edges -- confirm against the ruvector_mincut API that the builder
        // is attached to (or should receive) `graph`.
        let mincut = MinCutBuilder::new()
            .exact()
            .build()
            .map_err(|e| NeuralDecoderError::MinCutError(e.to_string()))?;

        let global_mincut = if graph.num_edges() > 0 {
            mincut.min_cut_value()
        } else {
            // Edgeless graph: treat the cut as unbounded.
            f64::INFINITY
        };

        // Compute local cuts (simplified: use node degree as proxy)
        // Panics if adjacency contains a node id >= num_nodes.
        let mut local_cuts = vec![0.0; num_nodes];
        for (node, neighbors) in adjacency {
            let total_weight: f32 = neighbors
                .iter()
                .map(|&n| {
                    edge_weights
                        .get(&(*node, n))
                        .or_else(|| edge_weights.get(&(n, *node)))
                        .copied()
                        .unwrap_or(1.0)
                })
                .sum();
            local_cuts[*node] = total_weight as f64;
        }

        // Estimate error chain probability based on local structure:
        // nodes with small weighted degree (relative to the maximum) get a
        // higher chain probability. The 1e-6 floor guards division by zero.
        let max_cut = local_cuts.iter().cloned().fold(0.0f64, f64::max).max(1e-6);
        let error_chain_prob: Vec<f64> = local_cuts
            .iter()
            .map(|&cut| 1.0 - (cut / max_cut))
            .collect();

        // Track which edges are likely in a cut (high weight / degree ratio).
        // The 0.3 factor is a heuristic threshold.
        let mut edge_in_cut = HashMap::new();
        for (&node, neighbors) in adjacency {
            for &neighbor in neighbors {
                if node < neighbor {
                    let weight = edge_weights
                        .get(&(node, neighbor))
                        .or_else(|| edge_weights.get(&(neighbor, node)))
                        .copied()
                        .unwrap_or(1.0);
                    let avg_degree = (local_cuts[node] + local_cuts[neighbor]) / 2.0;
                    // Edge is likely in cut if it has high relative weight
                    edge_in_cut.insert((node, neighbor), (weight as f64) > avg_degree * 0.3);
                }
            }
        }

        Ok(Self {
            global_mincut,
            local_cuts,
            edge_in_cut,
            error_chain_prob,
        })
    }

    /// Convert to feature vector for each node
    ///
    /// Per-node layout: `[0]` local cut normalized by the global min-cut,
    /// `[1]` error-chain probability, `[2]` squashed log of the global cut,
    /// with any further columns cycling through those first three values.
    /// Missing per-node entries default to 0.0 / 0.5 respectively.
    pub fn to_features(&self, num_nodes: usize, feature_dim: usize) -> Array2<f32> {
        let mut features = Array2::zeros((num_nodes, feature_dim));
        // Floor guards division by zero when the global cut is 0.
        let global_norm = self.global_mincut.max(1e-6);

        for i in 0..num_nodes {
            if feature_dim >= 1 {
                // Normalized local cut
                features[[i, 0]] = (self.local_cuts.get(i).copied().unwrap_or(0.0) / global_norm) as f32;
            }
            if feature_dim >= 2 {
                // Error chain probability
                features[[i, 1]] = self.error_chain_prob.get(i).copied().unwrap_or(0.5) as f32;
            }
            if feature_dim >= 3 {
                // Global context: tanh(ln(cut)/10) keeps the value in (-1, 1].
                features[[i, 2]] = (global_norm.ln() / 10.0).tanh() as f32;
            }
            // Pad remaining dimensions by repeating the first three features.
            for j in 3..feature_dim {
                features[[i, j]] = features[[i, j % 3]];
            }
        }

        features
    }
}
212
/// Boundary proximity features
///
/// Produced by [`BoundaryFeatures::compute`] from normalized node
/// coordinates; all vectors have one entry per node.
#[derive(Debug, Clone)]
pub struct BoundaryFeatures {
    /// Distance from each node to nearest boundary
    /// (for in-grid coordinates this lies in `[0, 0.5]`).
    pub distances: Vec<f32>,
    /// Boundary type for each node (0=inner, 1=X-boundary, 2=Z-boundary)
    pub boundary_types: Vec<u8>,
    /// Normalized boundary weights (maximum scaled to 1.0)
    pub weights: Vec<f32>,
}
223
224impl BoundaryFeatures {
225    /// Compute boundary features from node positions
226    ///
227    /// # Arguments
228    /// * `positions` - (x, y) coordinates for each node
229    /// * `grid_size` - Size of the syndrome grid
230    pub fn compute(positions: &[(f32, f32)], grid_size: usize) -> Self {
231        let num_nodes = positions.len();
232        let mut distances = Vec::with_capacity(num_nodes);
233        let mut boundary_types = Vec::with_capacity(num_nodes);
234        let mut weights = Vec::with_capacity(num_nodes);
235
236        let size = grid_size as f32;
237
238        for &(x, y) in positions {
239            // Normalize to [0, 1]
240            let x_norm = x / size.max(1.0);
241            let y_norm = y / size.max(1.0);
242
243            // Distance to nearest boundary
244            let d_left = x_norm;
245            let d_right = 1.0 - x_norm;
246            let d_bottom = y_norm;
247            let d_top = 1.0 - y_norm;
248
249            let min_x_dist = d_left.min(d_right);
250            let min_y_dist = d_bottom.min(d_top);
251            let min_dist = min_x_dist.min(min_y_dist);
252
253            distances.push(min_dist);
254
255            // Determine boundary type
256            // In surface codes: X-boundaries are left/right, Z-boundaries are top/bottom
257            let boundary_type = if min_dist < 0.1 {
258                if min_x_dist < min_y_dist {
259                    1 // X-boundary
260                } else {
261                    2 // Z-boundary
262                }
263            } else {
264                0 // Inner
265            };
266            boundary_types.push(boundary_type);
267
268            // Weight based on distance (closer to boundary = higher weight for boundary effects)
269            let weight = 1.0 - min_dist;
270            weights.push(weight);
271        }
272
273        // Normalize weights
274        let max_weight: f32 = weights.iter().cloned().fold(0.0f32, f32::max).max(1e-6);
275        for w in &mut weights {
276            *w /= max_weight;
277        }
278
279        Self {
280            distances,
281            boundary_types,
282            weights,
283        }
284    }
285
286    /// Convert to feature matrix
287    pub fn to_features(&self, feature_dim: usize) -> Array2<f32> {
288        let num_nodes = self.distances.len();
289        let mut features = Array2::zeros((num_nodes, feature_dim));
290
291        for i in 0..num_nodes {
292            if feature_dim >= 1 {
293                features[[i, 0]] = self.distances[i];
294            }
295            if feature_dim >= 2 {
296                features[[i, 1]] = self.boundary_types[i] as f32 / 2.0;
297            }
298            if feature_dim >= 3 {
299                features[[i, 2]] = self.weights[i];
300            }
301            // Additional boundary-derived features
302            if feature_dim >= 4 {
303                // Sin/cos encoding of boundary type
304                let angle = self.boundary_types[i] as f32 * std::f32::consts::PI / 3.0;
305                features[[i, 3]] = angle.sin();
306            }
307            if feature_dim >= 5 {
308                let angle = self.boundary_types[i] as f32 * std::f32::consts::PI / 3.0;
309                features[[i, 4]] = angle.cos();
310            }
311            // Pad remaining with distance decay
312            for j in 5..feature_dim {
313                features[[i, j]] = (-(self.distances[i] * (j - 4) as f32)).exp();
314            }
315        }
316
317        features
318    }
319}
320
/// Coherence-based confidence estimation
///
/// Scores each node's prediction by how similar it is to its graph
/// neighbors' predictions and how peaked (low-entropy) it is itself.
#[derive(Debug, Clone)]
pub struct CoherenceEstimator {
    /// Window size for local coherence
    // NOTE(review): stored but not read anywhere in this file -- confirm
    // whether it is reserved for a future windowed-coherence variant.
    window_size: usize,
    /// Minimum confidence threshold (floored at 0.01 by `new`)
    min_confidence: f32,
}
329
330impl CoherenceEstimator {
331    /// Create a new coherence estimator
332    pub fn new(window_size: usize, min_confidence: f32) -> Self {
333        Self {
334            window_size,
335            min_confidence: min_confidence.max(0.01),
336        }
337    }
338
339    /// Estimate confidence scores based on prediction coherence
340    ///
341    /// # Arguments
342    /// * `predictions` - Raw predictions (num_nodes, output_dim)
343    /// * `adjacency` - Graph adjacency
344    ///
345    /// # Returns
346    /// Confidence score for each node
347    pub fn estimate(
348        &self,
349        predictions: &Array2<f32>,
350        adjacency: &HashMap<usize, Vec<usize>>,
351    ) -> Vec<f32> {
352        let num_nodes = predictions.shape()[0];
353        let output_dim = predictions.shape()[1];
354        let mut confidences = vec![self.min_confidence; num_nodes];
355
356        for node in 0..num_nodes {
357            let neighbors = adjacency.get(&node).cloned().unwrap_or_default();
358
359            if neighbors.is_empty() {
360                // No neighbors: use prediction entropy as confidence
361                let entropy = self.compute_entropy(&predictions.row(node).to_vec());
362                confidences[node] = 1.0 - entropy;
363                continue;
364            }
365
366            // Local coherence: similarity of prediction to neighbors
367            let mut total_sim = 0.0;
368            let node_pred: Vec<f32> = predictions.row(node).to_vec();
369
370            for &neighbor in &neighbors {
371                let neighbor_pred: Vec<f32> = predictions.row(neighbor).to_vec();
372                let sim = self.cosine_similarity(&node_pred, &neighbor_pred);
373                total_sim += sim;
374            }
375
376            let avg_sim = total_sim / neighbors.len() as f32;
377
378            // High similarity to neighbors = high coherence = high confidence
379            // Low entropy in predictions = high certainty = high confidence
380            let entropy = self.compute_entropy(&node_pred);
381            let certainty = 1.0 - entropy;
382
383            // Combine coherence and certainty
384            confidences[node] = (0.6 * avg_sim + 0.4 * certainty).max(self.min_confidence);
385        }
386
387        confidences
388    }
389
390    /// Compute normalized entropy of a probability distribution
391    fn compute_entropy(&self, probs: &[f32]) -> f32 {
392        let eps = 1e-10;
393        let mut entropy = 0.0;
394        for &p in probs {
395            let p = p.clamp(eps as f32, 1.0 - eps as f32);
396            entropy -= p * p.ln();
397        }
398        // Normalize by max entropy (uniform distribution)
399        let max_entropy = (probs.len() as f32).ln();
400        if max_entropy > eps as f32 {
401            entropy / max_entropy
402        } else {
403            0.0
404        }
405    }
406
407    /// Compute cosine similarity between two vectors
408    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
409        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
410        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
411        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
412
413        if norm_a > 1e-10 && norm_b > 1e-10 {
414            dot / (norm_a * norm_b)
415        } else {
416            0.0
417        }
418    }
419}
420
/// Feature fusion module
///
/// Combines GNN embeddings, min-cut features, and boundary features
/// into a unified representation for error prediction.
///
/// All projection matrices are stored as `(output_dim, input_dim)` and
/// applied as `x.dot(&w.t())`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureFusion {
    /// Fusion configuration (validated at construction time)
    config: FusionConfig,
    /// GNN projection weights (output_dim x gnn_dim)
    gnn_proj: Array2<f32>,
    /// MinCut projection weights (output_dim x mincut_dim)
    mincut_proj: Array2<f32>,
    /// Boundary projection weights (output_dim x 8)
    boundary_proj: Array2<f32>,
    /// Output projection (output_dim x 3*output_dim)
    output_proj: Array2<f32>,
    /// Biases, added after the final ReLU (initialized to zero)
    bias: Array1<f32>,
}
439
440impl FeatureFusion {
441    /// Create a new feature fusion module
442    pub fn new(config: FusionConfig) -> Result<Self> {
443        config.validate()?;
444
445        let combined_dim = config.gnn_dim + config.mincut_dim + 8; // 8 for boundary features
446
447        // Initialize projection matrices with Xavier initialization
448        let gnn_proj = Self::init_weights(config.gnn_dim, config.output_dim);
449        let mincut_proj = Self::init_weights(config.mincut_dim, config.output_dim);
450        let boundary_proj = Self::init_weights(8, config.output_dim);
451        let output_proj = Self::init_weights(config.output_dim * 3, config.output_dim);
452        let bias = Array1::zeros(config.output_dim);
453
454        Ok(Self {
455            config,
456            gnn_proj,
457            mincut_proj,
458            boundary_proj,
459            output_proj,
460            bias,
461        })
462    }
463
464    /// Xavier initialization for weight matrices
465    fn init_weights(input_dim: usize, output_dim: usize) -> Array2<f32> {
466        use rand::Rng;
467        use rand_distr::{Distribution, Normal};
468
469        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
470        let normal = Normal::new(0.0, scale as f64).unwrap();
471        let mut rng = rand::thread_rng();
472
473        Array2::from_shape_fn((output_dim, input_dim), |_| {
474            normal.sample(&mut rng) as f32
475        })
476    }
477
478    /// Simple fuse for GNN and MinCut features only
479    ///
480    /// # Arguments
481    /// * `gnn_features` - GNN embeddings (num_nodes, gnn_dim)
482    /// * `mincut_features` - MinCut features (num_nodes, mincut_dim)
483    ///
484    /// # Returns
485    /// Fused features (num_nodes, output_dim)
486    pub fn fuse_simple(
487        &self,
488        gnn_features: &Array2<f32>,
489        mincut_features: &Array2<f32>,
490    ) -> Result<Array2<f32>> {
491        let num_nodes = gnn_features.shape()[0];
492
493        // Create default boundary features (zeros)
494        let boundary_features = Array2::zeros((num_nodes, 8));
495
496        self.fuse(gnn_features, mincut_features, &boundary_features, None)
497    }
498
499    /// Fuse features from multiple sources
500    ///
501    /// # Arguments
502    /// * `gnn_features` - GNN embeddings (num_nodes, gnn_dim)
503    /// * `mincut_features` - MinCut features (num_nodes, mincut_dim)
504    /// * `boundary_features` - Boundary features (num_nodes, 8)
505    /// * `confidences` - Optional confidence scores for adaptive weighting
506    ///
507    /// # Returns
508    /// Fused features (num_nodes, output_dim)
509    pub fn fuse(
510        &self,
511        gnn_features: &Array2<f32>,
512        mincut_features: &Array2<f32>,
513        boundary_features: &Array2<f32>,
514        confidences: Option<&[f32]>,
515    ) -> Result<Array2<f32>> {
516        let num_nodes = gnn_features.shape()[0];
517
518        if mincut_features.shape()[0] != num_nodes || boundary_features.shape()[0] != num_nodes {
519            return Err(NeuralDecoderError::shape_mismatch(
520                vec![num_nodes],
521                vec![mincut_features.shape()[0]],
522            ));
523        }
524
525        // Project each feature set
526        let gnn_proj = gnn_features.dot(&self.gnn_proj.t());
527        let mincut_proj = mincut_features.dot(&self.mincut_proj.t());
528        let boundary_proj = boundary_features.dot(&self.boundary_proj.t());
529
530        // Determine weights (adaptive or fixed)
531        let (gnn_w, mincut_w, boundary_w) = if self.config.adaptive_weights {
532            if let Some(conf) = confidences {
533                // Higher confidence -> trust GNN more
534                let avg_conf: f32 = conf.iter().sum::<f32>() / conf.len() as f32;
535                let gnn_w = self.config.gnn_weight * (1.0 + avg_conf);
536                let mincut_w = self.config.mincut_weight * (2.0 - avg_conf);
537                let boundary_w = self.config.boundary_weight;
538                let total = gnn_w + mincut_w + boundary_w;
539                (gnn_w / total, mincut_w / total, boundary_w / total)
540            } else {
541                (self.config.gnn_weight, self.config.mincut_weight, self.config.boundary_weight)
542            }
543        } else {
544            (self.config.gnn_weight, self.config.mincut_weight, self.config.boundary_weight)
545        };
546
547        // Weighted combination
548        let mut combined = Array2::zeros((num_nodes, self.config.output_dim * 3));
549        for i in 0..num_nodes {
550            // Per-node confidence scaling if available
551            let node_scale = confidences.map(|c| c[i]).unwrap_or(1.0);
552
553            for j in 0..self.config.output_dim {
554                combined[[i, j]] = gnn_proj[[i, j]] * gnn_w * node_scale;
555                combined[[i, self.config.output_dim + j]] = mincut_proj[[i, j]] * mincut_w;
556                combined[[i, 2 * self.config.output_dim + j]] = boundary_proj[[i, j]] * boundary_w;
557            }
558        }
559
560        // Final projection with ReLU and residual
561        let output = combined.dot(&self.output_proj.t());
562        let activated = output.mapv(|v| v.max(0.0)); // ReLU
563        let with_bias = activated + &self.bias;
564
565        Ok(with_bias)
566    }
567
568    /// Convenience method to compute all features and fuse them
569    ///
570    /// # Arguments
571    /// * `gnn_embeddings` - GNN node embeddings
572    /// * `adjacency` - Graph adjacency list
573    /// * `edge_weights` - Edge weights for min-cut
574    /// * `positions` - Node positions
575    /// * `grid_size` - Grid size for boundary computation
576    pub fn fuse_all(
577        &self,
578        gnn_embeddings: &Array2<f32>,
579        adjacency: &HashMap<usize, Vec<usize>>,
580        edge_weights: &HashMap<(usize, usize), f32>,
581        positions: &[(f32, f32)],
582        grid_size: usize,
583    ) -> Result<Array2<f32>> {
584        let num_nodes = gnn_embeddings.shape()[0];
585
586        // Extract min-cut features
587        let mincut_features = MinCutFeatures::extract(adjacency, edge_weights, num_nodes)?;
588        let mincut_array = mincut_features.to_features(num_nodes, self.config.mincut_dim);
589
590        // Compute boundary features
591        let boundary_features = BoundaryFeatures::compute(positions, grid_size);
592        let boundary_array = boundary_features.to_features(8);
593
594        // Estimate confidences based on GNN predictions
595        let coherence = CoherenceEstimator::new(3, 0.1);
596        let confidences = coherence.estimate(gnn_embeddings, adjacency);
597
598        // Fuse all features
599        self.fuse(gnn_embeddings, &mincut_array, &boundary_array, Some(&confidences))
600    }
601
602    /// Get configuration
603    pub fn config(&self) -> &FusionConfig {
604        &self.config
605    }
606
607    /// Get output dimension
608    pub fn output_dim(&self) -> usize {
609        self.config.output_dim
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Small symmetric 4-node graph with edge weights keyed as (min, max).
    fn create_test_graph() -> (HashMap<usize, Vec<usize>>, HashMap<(usize, usize), f32>) {
        let mut adjacency = HashMap::new();
        adjacency.insert(0, vec![1, 2]);
        adjacency.insert(1, vec![0, 2, 3]);
        adjacency.insert(2, vec![0, 1, 3]);
        adjacency.insert(3, vec![1, 2]);

        let mut edge_weights = HashMap::new();
        edge_weights.insert((0, 1), 0.1);
        edge_weights.insert((0, 2), 0.2);
        edge_weights.insert((1, 2), 0.15);
        edge_weights.insert((1, 3), 0.1);
        edge_weights.insert((2, 3), 0.1);

        (adjacency, edge_weights)
    }

    #[test]
    fn test_mincut_features() {
        let (adjacency, edge_weights) = create_test_graph();
        let features = MinCutFeatures::extract(&adjacency, &edge_weights, 4).unwrap();

        assert_eq!(features.local_cuts.len(), 4);
        assert_eq!(features.error_chain_prob.len(), 4);
        assert!(features.global_mincut > 0.0);
    }

    #[test]
    fn test_boundary_features() {
        let positions = vec![
            (0.0, 0.0),  // Corner
            (0.5, 0.5),  // Center
            (1.0, 0.5),  // Right edge
            (0.5, 1.0),  // Top edge
        ];

        let features = BoundaryFeatures::compute(&positions, 1);

        assert_eq!(features.distances.len(), 4);
        assert!(features.distances[0] < features.distances[1]); // Corner closer to boundary
        assert_eq!(features.boundary_types[1], 0); // Center is inner
    }

    #[test]
    fn test_coherence_estimator() {
        // Identical rows -> perfect neighbor coherence. Only the column
        // index determines the value, so the row index is ignored.
        let predictions = Array2::from_shape_fn((4, 2), |(_, j)| {
            if j == 0 { 0.8 } else { 0.2 }
        });

        let (adjacency, _) = create_test_graph();
        let estimator = CoherenceEstimator::new(3, 0.1);
        let confidences = estimator.estimate(&predictions, &adjacency);

        assert_eq!(confidences.len(), 4);
        for &c in &confidences {
            assert!(c >= 0.1 && c <= 1.0);
        }
    }

    #[test]
    fn test_fusion_config_validation() {
        let mut config = FusionConfig::default();
        assert!(config.validate().is_ok());

        config.gnn_weight = 0.8; // Now sum > 1
        assert!(config.validate().is_err());

        config.gnn_weight = 0.5;
        config.temperature = -1.0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_feature_fusion() {
        let config = FusionConfig {
            gnn_dim: 16,
            mincut_dim: 8,
            output_dim: 8,
            gnn_weight: 0.5,
            mincut_weight: 0.3,
            boundary_weight: 0.2,
            adaptive_weights: false,
            temperature: 1.0,
        };

        let fusion = FeatureFusion::new(config).unwrap();

        let num_nodes = 4;
        let gnn_features = Array2::from_shape_fn((num_nodes, 16), |(i, j)| {
            ((i + j) as f32) / 100.0
        });
        let mincut_features = Array2::from_shape_fn((num_nodes, 8), |(i, j)| {
            ((i * j) as f32) / 50.0
        });
        let boundary_features = Array2::from_shape_fn((num_nodes, 8), |(i, _)| {
            (i as f32) / 4.0
        });

        let fused = fusion.fuse(
            &gnn_features,
            &mincut_features,
            &boundary_features,
            None,
        ).unwrap();

        assert_eq!(fused.shape(), &[num_nodes, 8]);
    }

    #[test]
    fn test_fuse_all() {
        let config = FusionConfig {
            gnn_dim: 8,
            mincut_dim: 4,
            output_dim: 4,
            gnn_weight: 0.5,
            mincut_weight: 0.3,
            boundary_weight: 0.2,
            adaptive_weights: true,
            temperature: 1.0,
        };

        let fusion = FeatureFusion::new(config).unwrap();
        let (adjacency, edge_weights) = create_test_graph();

        let gnn_embeddings = Array2::from_shape_fn((4, 8), |(i, j)| {
            ((i + j) as f32) / 10.0
        });

        let positions = vec![
            (0.0, 0.0),
            (1.0, 0.0),
            (0.0, 1.0),
            (1.0, 1.0),
        ];

        let result = fusion.fuse_all(
            &gnn_embeddings,
            &adjacency,
            &edge_weights,
            &positions,
            2,
        );

        assert!(result.is_ok());
        let fused = result.unwrap();
        assert_eq!(fused.shape(), &[4, 4]);
    }

    #[test]
    fn test_mincut_features_to_array() {
        let (adjacency, edge_weights) = create_test_graph();
        let features = MinCutFeatures::extract(&adjacency, &edge_weights, 4).unwrap();

        let array = features.to_features(4, 8);
        assert_eq!(array.shape(), &[4, 8]);
    }

    #[test]
    fn test_boundary_features_to_array() {
        let positions = vec![(0.0, 0.0), (0.5, 0.5), (1.0, 0.0), (0.5, 1.0)];
        let features = BoundaryFeatures::compute(&positions, 2);

        let array = features.to_features(8);
        assert_eq!(array.shape(), &[4, 8]);
    }

    #[test]
    fn test_empty_graph_error() {
        let adjacency = HashMap::new();
        let edge_weights = HashMap::new();

        let result = MinCutFeatures::extract(&adjacency, &edge_weights, 0);
        assert!(matches!(result, Err(NeuralDecoderError::EmptyGraph)));
    }
}