1#![allow(clippy::too_many_arguments)]
7#![allow(dead_code)]
8
9use super::core::{CalibrationMetrics, ClassMetrics, GroupFairnessMetrics, NodeId};
10use crate::error::{MetricsError, Result};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::Float;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Bundle of node-level evaluation results for a graph model.
///
/// Each field is an independent metric group that is filled in by that group's own
/// `compute_*` method; this container only aggregates them.
#[derive(Debug, Clone)]
pub struct NodeLevelMetrics {
    /// Accuracy / F1 style metrics for node classification.
    pub classification_metrics: NodeClassificationMetrics,
    /// Quality measures for learned node embeddings.
    pub embedding_metrics: NodeEmbeddingMetrics,
    /// Prediction performance split by homophilic vs. heterophilic edges.
    pub homophily_metrics: HomophilyAwareMetrics,
    /// Group fairness measures over sensitive node attributes.
    pub fairness_metrics: NodeFairnessMetrics,
}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct NodeClassificationMetrics {
32 pub structure_aware_accuracy: f64,
34 pub macro_f1: f64,
36 pub micro_f1: f64,
38 pub per_class_metrics: HashMap<String, ClassMetrics>,
40 pub calibration_metrics: CalibrationMetrics,
42}
43
44impl Default for NodeClassificationMetrics {
45 fn default() -> Self {
46 Self {
47 structure_aware_accuracy: 0.0,
48 macro_f1: 0.0,
49 micro_f1: 0.0,
50 per_class_metrics: HashMap::new(),
51 calibration_metrics: CalibrationMetrics::default(),
52 }
53 }
54}
55
56impl NodeClassificationMetrics {
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn compute_structure_aware_accuracy<F: Float>(
63 &mut self,
64 predictions: &ArrayView1<F>,
65 ground_truth: &ArrayView1<F>,
66 adjacency_matrix: &ArrayView2<F>,
67 ) -> Result<f64>
68 where
69 F: std::iter::Sum + std::fmt::Debug,
70 {
71 let n = predictions.len();
72 if n != ground_truth.len() || adjacency_matrix.nrows() != n || adjacency_matrix.ncols() != n
73 {
74 return Err(MetricsError::DimensionMismatch(
75 "Predictions, ground truth, and adjacency matrix dimensions must match".to_string(),
76 ));
77 }
78
79 let mut correct_predictions = 0.0;
80 let mut total_predictions = 0.0;
81
82 for i in 0..n {
83 let degree = adjacency_matrix.row(i).sum().to_f64().unwrap_or(1.0);
85 let weight = (degree + 1.0).ln(); if (predictions[i] - ground_truth[i]).abs()
88 < F::from(0.5).expect("Failed to convert constant to float")
89 {
90 correct_predictions += weight;
91 }
92 total_predictions += weight;
93 }
94
95 let accuracy = if total_predictions > 0.0 {
96 correct_predictions / total_predictions
97 } else {
98 0.0
99 };
100
101 self.structure_aware_accuracy = accuracy;
102 Ok(accuracy)
103 }
104
105 pub fn compute_f1_scores<F: Float>(
107 &mut self,
108 predictions: &ArrayView1<F>,
109 ground_truth: &ArrayView1<F>,
110 class_labels: &[String],
111 ) -> Result<(f64, f64)>
112 where
113 F: std::iter::Sum + std::fmt::Debug,
114 {
115 let n = predictions.len();
116 if n != ground_truth.len() {
117 return Err(MetricsError::DimensionMismatch(
118 "Predictions and ground truth dimensions must match".to_string(),
119 ));
120 }
121
122 let mut class_metrics = HashMap::new();
123
124 for class_label in class_labels {
125 let mut tp = 0;
126 let mut fp = 0;
127 let mut fn_count = 0;
128 let mut tn = 0;
129
130 for i in 0..n {
131 let pred_class = predictions[i].to_usize().unwrap_or(0);
132 let true_class = ground_truth[i].to_usize().unwrap_or(0);
133 let current_class_idx = class_labels
134 .iter()
135 .position(|x| x == class_label)
136 .unwrap_or(0);
137
138 match (
139 pred_class == current_class_idx,
140 true_class == current_class_idx,
141 ) {
142 (true, true) => tp += 1,
143 (true, false) => fp += 1,
144 (false, true) => fn_count += 1,
145 (false, false) => {
146 #[allow(unused_assignments)]
147 {
148 tn += 1
149 }
150 }
151 }
152 }
153
154 let precision = if tp + fp > 0 {
155 tp as f64 / (tp + fp) as f64
156 } else {
157 0.0
158 };
159 let recall = if tp + fn_count > 0 {
160 tp as f64 / (tp + fn_count) as f64
161 } else {
162 0.0
163 };
164 let f1 = if precision + recall > 0.0 {
165 2.0 * precision * recall / (precision + recall)
166 } else {
167 0.0
168 };
169
170 class_metrics.insert(
171 class_label.clone(),
172 ClassMetrics {
173 precision,
174 recall,
175 f1_score: f1,
176 support: (tp + fn_count),
177 },
178 );
179 }
180
181 let macro_f1 = class_metrics
183 .values()
184 .map(|metrics| metrics.f1_score)
185 .sum::<f64>()
186 / class_metrics.len() as f64;
187
188 let total_tp: usize = class_metrics
190 .values()
191 .map(|m| (m.recall * m.support as f64) as usize)
192 .sum();
193 let total_support: usize = class_metrics.values().map(|m| m.support).sum();
194 let micro_f1 = if total_support > 0 {
195 total_tp as f64 / total_support as f64
196 } else {
197 0.0
198 };
199
200 self.macro_f1 = macro_f1;
201 self.micro_f1 = micro_f1;
202 self.per_class_metrics = class_metrics;
203
204 Ok((macro_f1, micro_f1))
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct NodeEmbeddingMetrics {
211 pub silhouette_score: f64,
213 pub intra_cluster_cohesion: f64,
215 pub inter_cluster_separation: f64,
217 pub structure_alignment: f64,
219 pub neighborhood_preservation: f64,
221}
222
223impl Default for NodeEmbeddingMetrics {
224 fn default() -> Self {
225 Self {
226 silhouette_score: 0.0,
227 intra_cluster_cohesion: 0.0,
228 inter_cluster_separation: 0.0,
229 structure_alignment: 0.0,
230 neighborhood_preservation: 0.0,
231 }
232 }
233}
234
235impl NodeEmbeddingMetrics {
236 pub fn new() -> Self {
237 Self::default()
238 }
239
240 pub fn compute_embedding_quality<F: Float>(
242 &mut self,
243 embeddings: &ArrayView2<F>,
244 adjacency_matrix: &ArrayView2<F>,
245 node_labels: Option<&ArrayView1<F>>,
246 ) -> Result<()>
247 where
248 F: std::iter::Sum + std::fmt::Debug,
249 {
250 let n_nodes = embeddings.nrows();
251 if adjacency_matrix.nrows() != n_nodes || adjacency_matrix.ncols() != n_nodes {
252 return Err(MetricsError::DimensionMismatch(
253 "Embeddings and adjacency matrix dimensions must match".to_string(),
254 ));
255 }
256
257 self.neighborhood_preservation =
259 self.compute_neighborhood_preservation(embeddings, adjacency_matrix)?;
260
261 self.structure_alignment =
263 self.compute_structure_alignment(embeddings, adjacency_matrix)?;
264
265 if let Some(labels) = node_labels {
267 self.silhouette_score = self.compute_silhouette_score(embeddings, labels)?;
268 }
269
270 Ok(())
271 }
272
273 fn compute_neighborhood_preservation<F: Float>(
274 &self,
275 embeddings: &ArrayView2<F>,
276 adjacency_matrix: &ArrayView2<F>,
277 ) -> Result<f64>
278 where
279 F: std::iter::Sum + std::fmt::Debug,
280 {
281 let n_nodes = embeddings.nrows();
282 let mut preservation_sum = 0.0;
283
284 for i in 0..n_nodes {
285 let neighbors: Vec<usize> = (0..n_nodes)
286 .filter(|&j| i != j && adjacency_matrix[(i, j)] > F::zero())
287 .collect();
288
289 if neighbors.is_empty() {
290 continue;
291 }
292
293 let mut distances: Vec<(usize, f64)> = (0..n_nodes)
295 .filter(|&j| i != j)
296 .map(|j| {
297 let dist = (0..embeddings.ncols())
298 .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
299 .sum::<F>()
300 .sqrt()
301 .to_f64()
302 .unwrap_or(0.0);
303 (j, dist)
304 })
305 .collect();
306
307 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
308
309 let k = neighbors.len().min(10); let top_k_nodes: HashSet<usize> =
312 distances.iter().take(k).map(|(idx, _)| *idx).collect();
313 let preserved = neighbors
314 .iter()
315 .filter(|&&n| top_k_nodes.contains(&n))
316 .count();
317
318 preservation_sum += preserved as f64 / neighbors.len() as f64;
319 }
320
321 Ok(preservation_sum / n_nodes as f64)
322 }
323
324 fn compute_structure_alignment<F: Float>(
325 &self,
326 embeddings: &ArrayView2<F>,
327 adjacency_matrix: &ArrayView2<F>,
328 ) -> Result<f64>
329 where
330 F: std::iter::Sum + std::fmt::Debug,
331 {
332 let n_nodes = embeddings.nrows();
333 let mut alignment_sum = 0.0;
334 let mut pair_count = 0;
335
336 for i in 0..n_nodes {
337 for j in (i + 1)..n_nodes {
338 let graph_connected = adjacency_matrix[(i, j)] > F::zero();
339
340 let emb_distance = (0..embeddings.ncols())
342 .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
343 .sum::<F>()
344 .sqrt()
345 .to_f64()
346 .unwrap_or(0.0);
347
348 if graph_connected {
350 alignment_sum += (-emb_distance).exp(); } else {
352 alignment_sum += emb_distance.min(1.0); }
354 pair_count += 1;
355 }
356 }
357
358 Ok(alignment_sum / pair_count as f64)
359 }
360
361 fn compute_silhouette_score<F: Float>(
362 &self,
363 embeddings: &ArrayView2<F>,
364 labels: &ArrayView1<F>,
365 ) -> Result<f64>
366 where
367 F: std::iter::Sum + std::fmt::Debug,
368 {
369 let n_nodes = embeddings.nrows();
370 if n_nodes != labels.len() {
371 return Err(MetricsError::DimensionMismatch(
372 "Embeddings and labels dimensions must match".to_string(),
373 ));
374 }
375
376 let mut silhouette_sum = 0.0;
377
378 for i in 0..n_nodes {
379 let node_label = labels[i];
380
381 let same_cluster_distances: Vec<f64> = (0..n_nodes)
383 .filter(|&j| {
384 i != j
385 && (labels[j] - node_label).abs()
386 < F::from(0.1).expect("Failed to convert constant to float")
387 })
388 .map(|j| {
389 (0..embeddings.ncols())
390 .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
391 .sum::<F>()
392 .sqrt()
393 .to_f64()
394 .unwrap_or(0.0)
395 })
396 .collect();
397
398 let a = if same_cluster_distances.is_empty() {
399 0.0
400 } else {
401 same_cluster_distances.iter().sum::<f64>() / same_cluster_distances.len() as f64
402 };
403
404 let mut other_labels: Vec<F> = (0..n_nodes)
406 .map(|j| labels[j])
407 .filter(|&label| {
408 (label - node_label).abs()
409 >= F::from(0.1).expect("Failed to convert constant to float")
410 })
411 .collect();
412 other_labels.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
413 other_labels.dedup_by(|a, b| {
414 (*a - *b).abs() < F::from(0.1).expect("Failed to convert constant to float")
415 });
416
417 let b = other_labels
418 .iter()
419 .map(|&other_label| {
420 let distances: Vec<f64> = (0..n_nodes)
421 .filter(|&j| {
422 (labels[j] - other_label).abs()
423 < F::from(0.1).expect("Failed to convert constant to float")
424 })
425 .map(|j| {
426 (0..embeddings.ncols())
427 .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
428 .sum::<F>()
429 .sqrt()
430 .to_f64()
431 .unwrap_or(0.0)
432 })
433 .collect();
434
435 if distances.is_empty() {
436 f64::INFINITY
437 } else {
438 distances.iter().sum::<f64>() / distances.len() as f64
439 }
440 })
441 .fold(f64::INFINITY, f64::min);
442
443 let silhouette = if a.max(b) > 0.0 {
444 (b - a) / a.max(b)
445 } else {
446 0.0
447 };
448
449 silhouette_sum += silhouette;
450 }
451
452 Ok(silhouette_sum / n_nodes as f64)
453 }
454}
455
456#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct HomophilyAwareMetrics {
459 pub homophily_ratio: f64,
461 pub homophilic_performance: f64,
463 pub heterophilic_performance: f64,
465 pub performance_gap: f64,
467 pub local_homophily: HashMap<usize, f64>, }
470
471impl Default for HomophilyAwareMetrics {
472 fn default() -> Self {
473 Self {
474 homophily_ratio: 0.0,
475 homophilic_performance: 0.0,
476 heterophilic_performance: 0.0,
477 performance_gap: 0.0,
478 local_homophily: HashMap::new(),
479 }
480 }
481}
482
483impl HomophilyAwareMetrics {
484 pub fn new() -> Self {
485 Self::default()
486 }
487
488 pub fn compute_homophily_metrics<F: Float>(
490 &mut self,
491 predictions: &ArrayView1<F>,
492 ground_truth: &ArrayView1<F>,
493 adjacency_matrix: &ArrayView2<F>,
494 ) -> Result<()>
495 where
496 F: std::iter::Sum + std::fmt::Debug,
497 {
498 let n_nodes = predictions.len();
499 if n_nodes != ground_truth.len() || adjacency_matrix.nrows() != n_nodes {
500 return Err(MetricsError::DimensionMismatch(
501 "All inputs must have matching dimensions".to_string(),
502 ));
503 }
504
505 self.homophily_ratio = self.compute_global_homophily(ground_truth, adjacency_matrix)?;
507
508 self.local_homophily = self.compute_local_homophily(ground_truth, adjacency_matrix)?;
510
511 self.compute_edge_performance(predictions, ground_truth, adjacency_matrix)?;
513
514 Ok(())
515 }
516
517 fn compute_global_homophily<F: Float>(
518 &self,
519 labels: &ArrayView1<F>,
520 adjacency_matrix: &ArrayView2<F>,
521 ) -> Result<f64>
522 where
523 F: std::iter::Sum + std::fmt::Debug,
524 {
525 let n_nodes = labels.len();
526 let mut homophilic_edges = 0;
527 let mut total_edges = 0;
528
529 for i in 0..n_nodes {
530 for j in (i + 1)..n_nodes {
531 if adjacency_matrix[(i, j)] > F::zero() {
532 total_edges += 1;
533 if (labels[i] - labels[j]).abs()
534 < F::from(0.1).expect("Failed to convert constant to float")
535 {
536 homophilic_edges += 1;
537 }
538 }
539 }
540 }
541
542 Ok(if total_edges > 0 {
543 homophilic_edges as f64 / total_edges as f64
544 } else {
545 0.0
546 })
547 }
548
549 fn compute_local_homophily<F: Float>(
550 &self,
551 labels: &ArrayView1<F>,
552 adjacency_matrix: &ArrayView2<F>,
553 ) -> Result<HashMap<usize, f64>>
554 where
555 F: std::iter::Sum + std::fmt::Debug,
556 {
557 let n_nodes = labels.len();
558 let mut local_homophily = HashMap::new();
559
560 for i in 0..n_nodes {
561 let neighbors: Vec<usize> = (0..n_nodes)
562 .filter(|&j| i != j && adjacency_matrix[(i, j)] > F::zero())
563 .collect();
564
565 if neighbors.is_empty() {
566 local_homophily.insert(i, 0.0);
567 continue;
568 }
569
570 let same_label_neighbors = neighbors
571 .iter()
572 .filter(|&&j| {
573 (labels[i] - labels[j]).abs()
574 < F::from(0.1).expect("Failed to convert constant to float")
575 })
576 .count();
577
578 let homophily = same_label_neighbors as f64 / neighbors.len() as f64;
579 local_homophily.insert(i, homophily);
580 }
581
582 Ok(local_homophily)
583 }
584
585 fn compute_edge_performance<F: Float>(
586 &mut self,
587 predictions: &ArrayView1<F>,
588 ground_truth: &ArrayView1<F>,
589 adjacency_matrix: &ArrayView2<F>,
590 ) -> Result<()>
591 where
592 F: std::iter::Sum + std::fmt::Debug,
593 {
594 let n_nodes = predictions.len();
595 let mut homophilic_correct = 0;
596 let mut homophilic_total = 0;
597 let mut heterophilic_correct = 0;
598 let mut heterophilic_total = 0;
599
600 for i in 0..n_nodes {
601 for j in (i + 1)..n_nodes {
602 if adjacency_matrix[(i, j)] > F::zero() {
603 let same_true_label = (ground_truth[i] - ground_truth[j]).abs()
604 < F::from(0.1).expect("Failed to convert constant to float");
605 let correct_i = (predictions[i] - ground_truth[i]).abs()
606 < F::from(0.5).expect("Failed to convert constant to float");
607 let correct_j = (predictions[j] - ground_truth[j]).abs()
608 < F::from(0.5).expect("Failed to convert constant to float");
609 let edge_correct = correct_i && correct_j;
610
611 if same_true_label {
612 homophilic_total += 1;
613 if edge_correct {
614 homophilic_correct += 1;
615 }
616 } else {
617 heterophilic_total += 1;
618 if edge_correct {
619 heterophilic_correct += 1;
620 }
621 }
622 }
623 }
624 }
625
626 self.homophilic_performance = if homophilic_total > 0 {
627 homophilic_correct as f64 / homophilic_total as f64
628 } else {
629 0.0
630 };
631
632 self.heterophilic_performance = if heterophilic_total > 0 {
633 heterophilic_correct as f64 / heterophilic_total as f64
634 } else {
635 0.0
636 };
637
638 self.performance_gap = (self.homophilic_performance - self.heterophilic_performance).abs();
639
640 Ok(())
641 }
642}
643
644#[derive(Debug, Clone, Serialize, Deserialize)]
646pub struct NodeFairnessMetrics {
647 pub demographic_parity: f64,
649 pub equalized_odds: f64,
651 pub individual_fairness: f64,
653 pub group_fairness: HashMap<String, GroupFairnessMetrics>,
655}
656
657impl Default for NodeFairnessMetrics {
658 fn default() -> Self {
659 Self {
660 demographic_parity: 0.0,
661 equalized_odds: 0.0,
662 individual_fairness: 0.0,
663 group_fairness: HashMap::new(),
664 }
665 }
666}
667
668impl NodeFairnessMetrics {
669 pub fn new() -> Self {
670 Self::default()
671 }
672
673 pub fn compute_fairness_metrics<F: Float>(
675 &mut self,
676 predictions: &ArrayView1<F>,
677 ground_truth: &ArrayView1<F>,
678 sensitive_attributes: &HashMap<usize, String>,
679 ) -> Result<()>
680 where
681 F: std::iter::Sum + std::fmt::Debug,
682 {
683 let mut group_metrics = HashMap::new();
685 let groups: std::collections::HashSet<_> = sensitive_attributes.values().cloned().collect();
686
687 for group in groups {
688 let group_indices: Vec<usize> = sensitive_attributes
689 .iter()
690 .filter(|(_, g)| *g == &group)
691 .map(|(idx, _)| *idx)
692 .collect();
693
694 if group_indices.is_empty() {
695 continue;
696 }
697
698 let mut tp = 0;
699 let mut fp = 0;
700 let mut tn = 0;
701 let mut fn_count = 0;
702
703 for &idx in &group_indices {
704 let pred = predictions[idx].to_f64().unwrap_or(0.0) > 0.5;
705 let truth = ground_truth[idx].to_f64().unwrap_or(0.0) > 0.5;
706
707 match (pred, truth) {
708 (true, true) => tp += 1,
709 (true, false) => fp += 1,
710 (false, false) => tn += 1,
711 (false, true) => fn_count += 1,
712 }
713 }
714
715 let tpr = if tp + fn_count > 0 {
716 tp as f64 / (tp + fn_count) as f64
717 } else {
718 0.0
719 };
720 let fpr = if fp + tn > 0 {
721 fp as f64 / (fp + tn) as f64
722 } else {
723 0.0
724 };
725 let precision = if tp + fp > 0 {
726 tp as f64 / (tp + fp) as f64
727 } else {
728 0.0
729 };
730 let selection_rate = if group_indices.len() > 0 {
731 group_indices
732 .iter()
733 .map(|&idx| {
734 if predictions[idx].to_f64().unwrap_or(0.0) > 0.5 {
735 1.0
736 } else {
737 0.0
738 }
739 })
740 .sum::<f64>()
741 / group_indices.len() as f64
742 } else {
743 0.0
744 };
745
746 group_metrics.insert(
747 group.clone(),
748 GroupFairnessMetrics {
749 tpr,
750 fpr,
751 precision,
752 selection_rate,
753 },
754 );
755 }
756
757 if group_metrics.len() >= 2 {
759 let group_names: Vec<_> = group_metrics.keys().collect();
760 let group1 = group_metrics.get(group_names[0]).expect("Operation failed");
761 let group2 = group_metrics.get(group_names[1]).expect("Operation failed");
762
763 self.demographic_parity = (group1.selection_rate - group2.selection_rate).abs();
764 self.equalized_odds =
765 ((group1.tpr - group2.tpr).abs() + (group1.fpr - group2.fpr).abs()) / 2.0;
766 }
767
768 self.group_fairness = group_metrics;
769 Ok(())
770 }
771}
772
773impl NodeLevelMetrics {
774 pub fn new() -> Self {
776 Self {
777 classification_metrics: NodeClassificationMetrics::new(),
778 embedding_metrics: NodeEmbeddingMetrics::new(),
779 homophily_metrics: HomophilyAwareMetrics::new(),
780 fairness_metrics: NodeFairnessMetrics::new(),
781 }
782 }
783}
784
785impl Default for NodeLevelMetrics {
786 fn default() -> Self {
787 Self::new()
788 }
789}
790
791use std::collections::HashSet;