1use crate::error::{NeuralDecoderError, Result};
19use ndarray::{Array1, Array2, Axis};
20use ruvector_mincut::{DynamicGraph, MinCutBuilder, Weight};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23
/// Configuration for fusing GNN, min-cut, and boundary features.
///
/// The three branch weights must sum to 1.0 and the temperature must be
/// strictly positive; both are checked by [`FusionConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Dimensionality of the incoming GNN embeddings.
    pub gnn_dim: usize,
    /// Dimensionality of the min-cut feature vectors.
    pub mincut_dim: usize,
    /// Dimensionality of the fused output embeddings.
    pub output_dim: usize,
    /// Mixing weight of the GNN branch.
    pub gnn_weight: f32,
    /// Mixing weight of the min-cut branch.
    pub mincut_weight: f32,
    /// Mixing weight of the boundary branch.
    pub boundary_weight: f32,
    /// When true, `FeatureFusion::fuse` rescales the branch weights using the
    /// average per-node confidence before mixing.
    pub adaptive_weights: bool,
    /// Softmax-style temperature; must be strictly positive (see `validate`).
    /// NOTE(review): not read anywhere in this file — confirm it is consumed
    /// elsewhere in the crate.
    pub temperature: f32,
}
44
45impl Default for FusionConfig {
46 fn default() -> Self {
47 Self {
48 gnn_dim: 64,
49 mincut_dim: 16,
50 output_dim: 32,
51 gnn_weight: 0.5,
52 mincut_weight: 0.3,
53 boundary_weight: 0.2,
54 adaptive_weights: true,
55 temperature: 1.0,
56 }
57 }
58}
59
60impl FusionConfig {
61 pub fn validate(&self) -> Result<()> {
63 let total_weight = self.gnn_weight + self.mincut_weight + self.boundary_weight;
64 if (total_weight - 1.0).abs() > 1e-6 {
65 return Err(NeuralDecoderError::ConfigError(format!(
66 "Fusion weights must sum to 1.0, got {}",
67 total_weight
68 )));
69 }
70 if self.temperature <= 0.0 {
71 return Err(NeuralDecoderError::ConfigError(
72 "Temperature must be positive".to_string(),
73 ));
74 }
75 Ok(())
76 }
77}
78
/// Graph-cut derived features for a decoding graph.
#[derive(Debug, Clone)]
pub struct MinCutFeatures {
    /// Value of the global minimum cut; `f64::INFINITY` for an edgeless graph.
    pub global_mincut: f64,
    /// Per-node sum of incident edge weights (weighted degree, used as a
    /// local "cut" proxy), indexed by node id.
    pub local_cuts: Vec<f64>,
    /// For each undirected edge `(u, v)` with `u < v`: whether its weight
    /// exceeds 30% of the endpoints' average weighted degree.
    pub edge_in_cut: HashMap<(usize, usize), bool>,
    /// Per-node score in [0, 1]; higher for weakly connected nodes.
    pub error_chain_prob: Vec<f64>,
}
91
impl MinCutFeatures {
    /// Extracts min-cut based features from an adjacency list and edge weights.
    ///
    /// `adjacency` maps node ids to neighbor lists; `edge_weights` may store
    /// an edge under either orientation, and missing edges default to weight
    /// 1.0. Node ids must be `< num_nodes` — out-of-range ids panic on the
    /// `local_cuts` index below.
    ///
    /// # Errors
    /// Returns `NeuralDecoderError::EmptyGraph` when `num_nodes == 0`, or
    /// `NeuralDecoderError::MinCutError` if the solver fails to build.
    pub fn extract(
        adjacency: &HashMap<usize, Vec<usize>>,
        edge_weights: &HashMap<(usize, usize), f32>,
        num_nodes: usize,
    ) -> Result<Self> {
        if num_nodes == 0 {
            return Err(NeuralDecoderError::EmptyGraph);
        }

        // Build the cut graph; edge capacity is the inverse of the error
        // weight (plus epsilon to avoid division by zero), so likely-error
        // edges are cheap to cut.
        let graph = DynamicGraph::new();

        for (&node, neighbors) in adjacency {
            for &neighbor in neighbors {
                // Insert each undirected edge exactly once (node < neighbor).
                if node < neighbor {
                    let weight = edge_weights
                        .get(&(node, neighbor))
                        .or_else(|| edge_weights.get(&(neighbor, node)))
                        .copied()
                        .unwrap_or(1.0);
                    let _ = graph.insert_edge(node as u64, neighbor as u64, 1.0 / (weight + 1e-6) as Weight);
                }
            }
        }

        // NOTE(review): `graph` is populated above but never handed to the
        // builder or to `min_cut_value()` — verify against the ruvector_mincut
        // API; as written the solver may not be operating on this graph.
        let mincut = MinCutBuilder::new()
            .exact()
            .build()
            .map_err(|e| NeuralDecoderError::MinCutError(e.to_string()))?;

        let global_mincut = if graph.num_edges() > 0 {
            mincut.min_cut_value()
        } else {
            // Edgeless graph: treat the cut as unbounded.
            f64::INFINITY
        };

        // Local "cut" per node = total incident edge weight (weighted degree).
        let mut local_cuts = vec![0.0; num_nodes];
        for (node, neighbors) in adjacency {
            let total_weight: f32 = neighbors
                .iter()
                .map(|&n| {
                    edge_weights
                        .get(&(*node, n))
                        .or_else(|| edge_weights.get(&(n, *node)))
                        .copied()
                        .unwrap_or(1.0)
                })
                .sum();
            local_cuts[*node] = total_weight as f64;
        }

        // Weakly connected nodes (small local cut relative to the maximum)
        // receive a high error-chain score.
        let max_cut = local_cuts.iter().cloned().fold(0.0f64, f64::max).max(1e-6);
        let error_chain_prob: Vec<f64> = local_cuts
            .iter()
            .map(|&cut| 1.0 - (cut / max_cut))
            .collect();

        // Heuristic cut membership: an edge is "in the cut" when its weight
        // exceeds 30% of its endpoints' average weighted degree.
        let mut edge_in_cut = HashMap::new();
        for (&node, neighbors) in adjacency {
            for &neighbor in neighbors {
                if node < neighbor {
                    let weight = edge_weights
                        .get(&(node, neighbor))
                        .or_else(|| edge_weights.get(&(neighbor, node)))
                        .copied()
                        .unwrap_or(1.0);
                    let avg_degree = (local_cuts[node] + local_cuts[neighbor]) / 2.0;
                    edge_in_cut.insert((node, neighbor), (weight as f64) > avg_degree * 0.3);
                }
            }
        }

        Ok(Self {
            global_mincut,
            local_cuts,
            edge_in_cut,
            error_chain_prob,
        })
    }

    /// Expands the features into a dense `(num_nodes, feature_dim)` matrix:
    /// column 0 = local cut normalized by the global cut, column 1 =
    /// error-chain score, column 2 = squashed log of the global cut, and
    /// columns >= 3 repeat the first three cyclically.
    pub fn to_features(&self, num_nodes: usize, feature_dim: usize) -> Array2<f32> {
        let mut features = Array2::zeros((num_nodes, feature_dim));
        // Guard against a zero (or negative) global cut in the divisor.
        let global_norm = self.global_mincut.max(1e-6);

        for i in 0..num_nodes {
            if feature_dim >= 1 {
                features[[i, 0]] = (self.local_cuts.get(i).copied().unwrap_or(0.0) / global_norm) as f32;
            }
            if feature_dim >= 2 {
                // Missing entries default to a neutral 0.5 probability.
                features[[i, 1]] = self.error_chain_prob.get(i).copied().unwrap_or(0.5) as f32;
            }
            if feature_dim >= 3 {
                features[[i, 2]] = (global_norm.ln() / 10.0).tanh() as f32;
            }
            // Pad remaining columns by cycling through the first three.
            for j in 3..feature_dim {
                features[[i, j]] = features[[i, j % 3]];
            }
        }

        features
    }
}
212
/// Per-node lattice-boundary features computed from 2D positions.
#[derive(Debug, Clone)]
pub struct BoundaryFeatures {
    /// Normalized distance of each node to its nearest grid boundary.
    pub distances: Vec<f32>,
    /// Nearest-boundary class per node: 0 = interior (distance >= 0.1),
    /// 1 = nearest to a left/right edge, 2 = nearest to a bottom/top edge.
    pub boundary_types: Vec<u8>,
    /// Boundary-proximity weights (1 - distance), rescaled so the maximum is 1.0.
    pub weights: Vec<f32>,
}
223
224impl BoundaryFeatures {
225 pub fn compute(positions: &[(f32, f32)], grid_size: usize) -> Self {
231 let num_nodes = positions.len();
232 let mut distances = Vec::with_capacity(num_nodes);
233 let mut boundary_types = Vec::with_capacity(num_nodes);
234 let mut weights = Vec::with_capacity(num_nodes);
235
236 let size = grid_size as f32;
237
238 for &(x, y) in positions {
239 let x_norm = x / size.max(1.0);
241 let y_norm = y / size.max(1.0);
242
243 let d_left = x_norm;
245 let d_right = 1.0 - x_norm;
246 let d_bottom = y_norm;
247 let d_top = 1.0 - y_norm;
248
249 let min_x_dist = d_left.min(d_right);
250 let min_y_dist = d_bottom.min(d_top);
251 let min_dist = min_x_dist.min(min_y_dist);
252
253 distances.push(min_dist);
254
255 let boundary_type = if min_dist < 0.1 {
258 if min_x_dist < min_y_dist {
259 1 } else {
261 2 }
263 } else {
264 0 };
266 boundary_types.push(boundary_type);
267
268 let weight = 1.0 - min_dist;
270 weights.push(weight);
271 }
272
273 let max_weight: f32 = weights.iter().cloned().fold(0.0f32, f32::max).max(1e-6);
275 for w in &mut weights {
276 *w /= max_weight;
277 }
278
279 Self {
280 distances,
281 boundary_types,
282 weights,
283 }
284 }
285
286 pub fn to_features(&self, feature_dim: usize) -> Array2<f32> {
288 let num_nodes = self.distances.len();
289 let mut features = Array2::zeros((num_nodes, feature_dim));
290
291 for i in 0..num_nodes {
292 if feature_dim >= 1 {
293 features[[i, 0]] = self.distances[i];
294 }
295 if feature_dim >= 2 {
296 features[[i, 1]] = self.boundary_types[i] as f32 / 2.0;
297 }
298 if feature_dim >= 3 {
299 features[[i, 2]] = self.weights[i];
300 }
301 if feature_dim >= 4 {
303 let angle = self.boundary_types[i] as f32 * std::f32::consts::PI / 3.0;
305 features[[i, 3]] = angle.sin();
306 }
307 if feature_dim >= 5 {
308 let angle = self.boundary_types[i] as f32 * std::f32::consts::PI / 3.0;
309 features[[i, 4]] = angle.cos();
310 }
311 for j in 5..feature_dim {
313 features[[i, j]] = (-(self.distances[i] * (j - 4) as f32)).exp();
314 }
315 }
316
317 features
318 }
319}
320
/// Estimates per-node prediction confidence from neighborhood coherence.
#[derive(Debug, Clone)]
pub struct CoherenceEstimator {
    /// Smoothing window size.
    /// NOTE(review): stored but never read in this file — confirm intended use.
    window_size: usize,
    /// Lower bound for estimated confidences; clamped to >= 0.01 in `new`.
    min_confidence: f32,
}
329
330impl CoherenceEstimator {
331 pub fn new(window_size: usize, min_confidence: f32) -> Self {
333 Self {
334 window_size,
335 min_confidence: min_confidence.max(0.01),
336 }
337 }
338
339 pub fn estimate(
348 &self,
349 predictions: &Array2<f32>,
350 adjacency: &HashMap<usize, Vec<usize>>,
351 ) -> Vec<f32> {
352 let num_nodes = predictions.shape()[0];
353 let output_dim = predictions.shape()[1];
354 let mut confidences = vec![self.min_confidence; num_nodes];
355
356 for node in 0..num_nodes {
357 let neighbors = adjacency.get(&node).cloned().unwrap_or_default();
358
359 if neighbors.is_empty() {
360 let entropy = self.compute_entropy(&predictions.row(node).to_vec());
362 confidences[node] = 1.0 - entropy;
363 continue;
364 }
365
366 let mut total_sim = 0.0;
368 let node_pred: Vec<f32> = predictions.row(node).to_vec();
369
370 for &neighbor in &neighbors {
371 let neighbor_pred: Vec<f32> = predictions.row(neighbor).to_vec();
372 let sim = self.cosine_similarity(&node_pred, &neighbor_pred);
373 total_sim += sim;
374 }
375
376 let avg_sim = total_sim / neighbors.len() as f32;
377
378 let entropy = self.compute_entropy(&node_pred);
381 let certainty = 1.0 - entropy;
382
383 confidences[node] = (0.6 * avg_sim + 0.4 * certainty).max(self.min_confidence);
385 }
386
387 confidences
388 }
389
390 fn compute_entropy(&self, probs: &[f32]) -> f32 {
392 let eps = 1e-10;
393 let mut entropy = 0.0;
394 for &p in probs {
395 let p = p.clamp(eps as f32, 1.0 - eps as f32);
396 entropy -= p * p.ln();
397 }
398 let max_entropy = (probs.len() as f32).ln();
400 if max_entropy > eps as f32 {
401 entropy / max_entropy
402 } else {
403 0.0
404 }
405 }
406
407 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
409 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
410 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
411 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
412
413 if norm_a > 1e-10 && norm_b > 1e-10 {
414 dot / (norm_a * norm_b)
415 } else {
416 0.0
417 }
418 }
419}
420
/// Fuses GNN, min-cut, and boundary feature matrices into a single embedding.
///
/// Each projection is stored as an `(output_dim, input_dim)` matrix and
/// applied as `x · Wᵀ`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureFusion {
    /// Validated fusion configuration.
    config: FusionConfig,
    /// Projection for the GNN branch: `(output_dim, gnn_dim)`.
    gnn_proj: Array2<f32>,
    /// Projection for the min-cut branch: `(output_dim, mincut_dim)`.
    mincut_proj: Array2<f32>,
    /// Projection for the boundary branch: `(output_dim, 8)` — boundary
    /// features are always 8 columns wide (see `fuse_all`).
    boundary_proj: Array2<f32>,
    /// Final mixing projection over the concatenation: `(output_dim, 3 * output_dim)`.
    output_proj: Array2<f32>,
    /// Output bias, added after activation; initialized to zeros.
    bias: Array1<f32>,
}
439
440impl FeatureFusion {
441 pub fn new(config: FusionConfig) -> Result<Self> {
443 config.validate()?;
444
445 let combined_dim = config.gnn_dim + config.mincut_dim + 8; let gnn_proj = Self::init_weights(config.gnn_dim, config.output_dim);
449 let mincut_proj = Self::init_weights(config.mincut_dim, config.output_dim);
450 let boundary_proj = Self::init_weights(8, config.output_dim);
451 let output_proj = Self::init_weights(config.output_dim * 3, config.output_dim);
452 let bias = Array1::zeros(config.output_dim);
453
454 Ok(Self {
455 config,
456 gnn_proj,
457 mincut_proj,
458 boundary_proj,
459 output_proj,
460 bias,
461 })
462 }
463
464 fn init_weights(input_dim: usize, output_dim: usize) -> Array2<f32> {
466 use rand::Rng;
467 use rand_distr::{Distribution, Normal};
468
469 let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
470 let normal = Normal::new(0.0, scale as f64).unwrap();
471 let mut rng = rand::thread_rng();
472
473 Array2::from_shape_fn((output_dim, input_dim), |_| {
474 normal.sample(&mut rng) as f32
475 })
476 }
477
478 pub fn fuse_simple(
487 &self,
488 gnn_features: &Array2<f32>,
489 mincut_features: &Array2<f32>,
490 ) -> Result<Array2<f32>> {
491 let num_nodes = gnn_features.shape()[0];
492
493 let boundary_features = Array2::zeros((num_nodes, 8));
495
496 self.fuse(gnn_features, mincut_features, &boundary_features, None)
497 }
498
499 pub fn fuse(
510 &self,
511 gnn_features: &Array2<f32>,
512 mincut_features: &Array2<f32>,
513 boundary_features: &Array2<f32>,
514 confidences: Option<&[f32]>,
515 ) -> Result<Array2<f32>> {
516 let num_nodes = gnn_features.shape()[0];
517
518 if mincut_features.shape()[0] != num_nodes || boundary_features.shape()[0] != num_nodes {
519 return Err(NeuralDecoderError::shape_mismatch(
520 vec![num_nodes],
521 vec![mincut_features.shape()[0]],
522 ));
523 }
524
525 let gnn_proj = gnn_features.dot(&self.gnn_proj.t());
527 let mincut_proj = mincut_features.dot(&self.mincut_proj.t());
528 let boundary_proj = boundary_features.dot(&self.boundary_proj.t());
529
530 let (gnn_w, mincut_w, boundary_w) = if self.config.adaptive_weights {
532 if let Some(conf) = confidences {
533 let avg_conf: f32 = conf.iter().sum::<f32>() / conf.len() as f32;
535 let gnn_w = self.config.gnn_weight * (1.0 + avg_conf);
536 let mincut_w = self.config.mincut_weight * (2.0 - avg_conf);
537 let boundary_w = self.config.boundary_weight;
538 let total = gnn_w + mincut_w + boundary_w;
539 (gnn_w / total, mincut_w / total, boundary_w / total)
540 } else {
541 (self.config.gnn_weight, self.config.mincut_weight, self.config.boundary_weight)
542 }
543 } else {
544 (self.config.gnn_weight, self.config.mincut_weight, self.config.boundary_weight)
545 };
546
547 let mut combined = Array2::zeros((num_nodes, self.config.output_dim * 3));
549 for i in 0..num_nodes {
550 let node_scale = confidences.map(|c| c[i]).unwrap_or(1.0);
552
553 for j in 0..self.config.output_dim {
554 combined[[i, j]] = gnn_proj[[i, j]] * gnn_w * node_scale;
555 combined[[i, self.config.output_dim + j]] = mincut_proj[[i, j]] * mincut_w;
556 combined[[i, 2 * self.config.output_dim + j]] = boundary_proj[[i, j]] * boundary_w;
557 }
558 }
559
560 let output = combined.dot(&self.output_proj.t());
562 let activated = output.mapv(|v| v.max(0.0)); let with_bias = activated + &self.bias;
564
565 Ok(with_bias)
566 }
567
568 pub fn fuse_all(
577 &self,
578 gnn_embeddings: &Array2<f32>,
579 adjacency: &HashMap<usize, Vec<usize>>,
580 edge_weights: &HashMap<(usize, usize), f32>,
581 positions: &[(f32, f32)],
582 grid_size: usize,
583 ) -> Result<Array2<f32>> {
584 let num_nodes = gnn_embeddings.shape()[0];
585
586 let mincut_features = MinCutFeatures::extract(adjacency, edge_weights, num_nodes)?;
588 let mincut_array = mincut_features.to_features(num_nodes, self.config.mincut_dim);
589
590 let boundary_features = BoundaryFeatures::compute(positions, grid_size);
592 let boundary_array = boundary_features.to_features(8);
593
594 let coherence = CoherenceEstimator::new(3, 0.1);
596 let confidences = coherence.estimate(gnn_embeddings, adjacency);
597
598 self.fuse(gnn_embeddings, &mincut_array, &boundary_array, Some(&confidences))
600 }
601
602 pub fn config(&self) -> &FusionConfig {
604 &self.config
605 }
606
607 pub fn output_dim(&self) -> usize {
609 self.config.output_dim
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// 4-node diamond graph with symmetric adjacency and per-edge weights.
    fn create_test_graph() -> (HashMap<usize, Vec<usize>>, HashMap<(usize, usize), f32>) {
        let mut adjacency = HashMap::new();
        adjacency.insert(0, vec![1, 2]);
        adjacency.insert(1, vec![0, 2, 3]);
        adjacency.insert(2, vec![0, 1, 3]);
        adjacency.insert(3, vec![1, 2]);

        let mut edge_weights = HashMap::new();
        edge_weights.insert((0, 1), 0.1);
        edge_weights.insert((0, 2), 0.2);
        edge_weights.insert((1, 2), 0.15);
        edge_weights.insert((1, 3), 0.1);
        edge_weights.insert((2, 3), 0.1);

        (adjacency, edge_weights)
    }

    #[test]
    fn test_mincut_features() {
        let (adjacency, edge_weights) = create_test_graph();
        let features = MinCutFeatures::extract(&adjacency, &edge_weights, 4).unwrap();

        assert_eq!(features.local_cuts.len(), 4);
        assert_eq!(features.error_chain_prob.len(), 4);
        assert!(features.global_mincut > 0.0);
    }

    #[test]
    fn test_boundary_features() {
        // Corner, center, right edge, and top edge of the unit grid.
        let positions = vec![
            (0.0, 0.0), (0.5, 0.5), (1.0, 0.5), (0.5, 1.0),
        ];

        let features = BoundaryFeatures::compute(&positions, 1);

        assert_eq!(features.distances.len(), 4);
        // The corner node is closer to a boundary than the center node.
        assert!(features.distances[0] < features.distances[1]);
        // The center node is classified as interior.
        assert_eq!(features.boundary_types[1], 0);
    }

    #[test]
    fn test_coherence_estimator() {
        // Every node shares the same confident (0.8, 0.2) distribution.
        // BUGFIX: the row index was an unused closure binding (`i`),
        // triggering an unused_variables warning.
        let predictions = Array2::from_shape_fn((4, 2), |(_, j)| {
            if j == 0 { 0.8 } else { 0.2 }
        });

        let (adjacency, _) = create_test_graph();
        let estimator = CoherenceEstimator::new(3, 0.1);
        let confidences = estimator.estimate(&predictions, &adjacency);

        assert_eq!(confidences.len(), 4);
        for &c in &confidences {
            assert!(c >= 0.1 && c <= 1.0);
        }
    }

    #[test]
    fn test_fusion_config_validation() {
        let mut config = FusionConfig::default();
        assert!(config.validate().is_ok());

        // Weights no longer sum to 1.0 (0.8 + 0.3 + 0.2).
        config.gnn_weight = 0.8;
        assert!(config.validate().is_err());

        config.gnn_weight = 0.5;
        config.temperature = -1.0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_feature_fusion() {
        let config = FusionConfig {
            gnn_dim: 16,
            mincut_dim: 8,
            output_dim: 8,
            gnn_weight: 0.5,
            mincut_weight: 0.3,
            boundary_weight: 0.2,
            adaptive_weights: false,
            temperature: 1.0,
        };

        let fusion = FeatureFusion::new(config).unwrap();

        let num_nodes = 4;
        let gnn_features = Array2::from_shape_fn((num_nodes, 16), |(i, j)| {
            ((i + j) as f32) / 100.0
        });
        let mincut_features = Array2::from_shape_fn((num_nodes, 8), |(i, j)| {
            ((i * j) as f32) / 50.0
        });
        let boundary_features = Array2::from_shape_fn((num_nodes, 8), |(i, _)| {
            (i as f32) / 4.0
        });

        let fused = fusion.fuse(
            &gnn_features,
            &mincut_features,
            &boundary_features,
            None,
        ).unwrap();

        assert_eq!(fused.shape(), &[num_nodes, 8]);
    }

    #[test]
    fn test_fuse_all() {
        let config = FusionConfig {
            gnn_dim: 8,
            mincut_dim: 4,
            output_dim: 4,
            gnn_weight: 0.5,
            mincut_weight: 0.3,
            boundary_weight: 0.2,
            adaptive_weights: true,
            temperature: 1.0,
        };

        let fusion = FeatureFusion::new(config).unwrap();
        let (adjacency, edge_weights) = create_test_graph();

        let gnn_embeddings = Array2::from_shape_fn((4, 8), |(i, j)| {
            ((i + j) as f32) / 10.0
        });

        let positions = vec![
            (0.0, 0.0),
            (1.0, 0.0),
            (0.0, 1.0),
            (1.0, 1.0),
        ];

        let result = fusion.fuse_all(
            &gnn_embeddings,
            &adjacency,
            &edge_weights,
            &positions,
            2,
        );

        assert!(result.is_ok());
        let fused = result.unwrap();
        assert_eq!(fused.shape(), &[4, 4]);
    }

    #[test]
    fn test_mincut_features_to_array() {
        let (adjacency, edge_weights) = create_test_graph();
        let features = MinCutFeatures::extract(&adjacency, &edge_weights, 4).unwrap();

        let array = features.to_features(4, 8);
        assert_eq!(array.shape(), &[4, 8]);
    }

    #[test]
    fn test_boundary_features_to_array() {
        let positions = vec![(0.0, 0.0), (0.5, 0.5), (1.0, 0.0), (0.5, 1.0)];
        let features = BoundaryFeatures::compute(&positions, 2);

        let array = features.to_features(8);
        assert_eq!(array.shape(), &[4, 8]);
    }

    #[test]
    fn test_empty_graph_error() {
        let adjacency = HashMap::new();
        let edge_weights = HashMap::new();

        let result = MinCutFeatures::extract(&adjacency, &edge_weights, 0);
        assert!(matches!(result, Err(NeuralDecoderError::EmptyGraph)));
    }
}