scirs2_metrics/domains/graph_neural_networks/
node_level.rs

1//! Node-level task evaluation metrics for Graph Neural Networks
2//!
3//! This module provides metrics for evaluating node-level tasks such as
4//! node classification, regression, embedding quality, and fairness.
5
6#![allow(clippy::too_many_arguments)]
7#![allow(dead_code)]
8
9use super::core::{CalibrationMetrics, ClassMetrics, GroupFairnessMetrics, NodeId};
10use crate::error::{MetricsError, Result};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::Float;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Node-level task evaluation metrics for a GNN model.
///
/// Groups the four node-level metric families defined in this module. Each
/// family is populated by calling the `compute_*` method(s) on its own type.
#[derive(Debug, Clone)]
pub struct NodeLevelMetrics {
    /// Standard classification/regression metrics (accuracy, F1, per-class scores)
    pub classification_metrics: NodeClassificationMetrics,
    /// Node embedding quality metrics (neighbourhood preservation, alignment, silhouette)
    pub embedding_metrics: NodeEmbeddingMetrics,
    /// Homophily and heterophily aware metrics
    pub homophily_metrics: HomophilyAwareMetrics,
    /// Fairness metrics for node predictions across sensitive groups
    pub fairness_metrics: NodeFairnessMetrics,
}
28
29/// Node classification specific metrics
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct NodeClassificationMetrics {
32    /// Accuracy considering graph structure
33    pub structure_aware_accuracy: f64,
34    /// Macro F1 score
35    pub macro_f1: f64,
36    /// Micro F1 score
37    pub micro_f1: f64,
38    /// Per-class metrics
39    pub per_class_metrics: HashMap<String, ClassMetrics>,
40    /// Confidence calibration metrics
41    pub calibration_metrics: CalibrationMetrics,
42}
43
44impl Default for NodeClassificationMetrics {
45    fn default() -> Self {
46        Self {
47            structure_aware_accuracy: 0.0,
48            macro_f1: 0.0,
49            micro_f1: 0.0,
50            per_class_metrics: HashMap::new(),
51            calibration_metrics: CalibrationMetrics::default(),
52        }
53    }
54}
55
56impl NodeClassificationMetrics {
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Compute structure-aware accuracy considering graph connectivity
62    pub fn compute_structure_aware_accuracy<F: Float>(
63        &mut self,
64        predictions: &ArrayView1<F>,
65        ground_truth: &ArrayView1<F>,
66        adjacency_matrix: &ArrayView2<F>,
67    ) -> Result<f64>
68    where
69        F: std::iter::Sum + std::fmt::Debug,
70    {
71        let n = predictions.len();
72        if n != ground_truth.len() || adjacency_matrix.nrows() != n || adjacency_matrix.ncols() != n
73        {
74            return Err(MetricsError::DimensionMismatch(
75                "Predictions, ground truth, and adjacency matrix dimensions must match".to_string(),
76            ));
77        }
78
79        let mut correct_predictions = 0.0;
80        let mut total_predictions = 0.0;
81
82        for i in 0..n {
83            // Weight prediction correctness by node degree
84            let degree = adjacency_matrix.row(i).sum().to_f64().unwrap_or(1.0);
85            let weight = (degree + 1.0).ln(); // Log-weighted by degree
86
87            if (predictions[i] - ground_truth[i]).abs()
88                < F::from(0.5).expect("Failed to convert constant to float")
89            {
90                correct_predictions += weight;
91            }
92            total_predictions += weight;
93        }
94
95        let accuracy = if total_predictions > 0.0 {
96            correct_predictions / total_predictions
97        } else {
98            0.0
99        };
100
101        self.structure_aware_accuracy = accuracy;
102        Ok(accuracy)
103    }
104
105    /// Compute macro and micro F1 scores
106    pub fn compute_f1_scores<F: Float>(
107        &mut self,
108        predictions: &ArrayView1<F>,
109        ground_truth: &ArrayView1<F>,
110        class_labels: &[String],
111    ) -> Result<(f64, f64)>
112    where
113        F: std::iter::Sum + std::fmt::Debug,
114    {
115        let n = predictions.len();
116        if n != ground_truth.len() {
117            return Err(MetricsError::DimensionMismatch(
118                "Predictions and ground truth dimensions must match".to_string(),
119            ));
120        }
121
122        let mut class_metrics = HashMap::new();
123
124        for class_label in class_labels {
125            let mut tp = 0;
126            let mut fp = 0;
127            let mut fn_count = 0;
128            let mut tn = 0;
129
130            for i in 0..n {
131                let pred_class = predictions[i].to_usize().unwrap_or(0);
132                let true_class = ground_truth[i].to_usize().unwrap_or(0);
133                let current_class_idx = class_labels
134                    .iter()
135                    .position(|x| x == class_label)
136                    .unwrap_or(0);
137
138                match (
139                    pred_class == current_class_idx,
140                    true_class == current_class_idx,
141                ) {
142                    (true, true) => tp += 1,
143                    (true, false) => fp += 1,
144                    (false, true) => fn_count += 1,
145                    (false, false) => {
146                        #[allow(unused_assignments)]
147                        {
148                            tn += 1
149                        }
150                    }
151                }
152            }
153
154            let precision = if tp + fp > 0 {
155                tp as f64 / (tp + fp) as f64
156            } else {
157                0.0
158            };
159            let recall = if tp + fn_count > 0 {
160                tp as f64 / (tp + fn_count) as f64
161            } else {
162                0.0
163            };
164            let f1 = if precision + recall > 0.0 {
165                2.0 * precision * recall / (precision + recall)
166            } else {
167                0.0
168            };
169
170            class_metrics.insert(
171                class_label.clone(),
172                ClassMetrics {
173                    precision,
174                    recall,
175                    f1_score: f1,
176                    support: (tp + fn_count),
177                },
178            );
179        }
180
181        // Compute macro F1
182        let macro_f1 = class_metrics
183            .values()
184            .map(|metrics| metrics.f1_score)
185            .sum::<f64>()
186            / class_metrics.len() as f64;
187
188        // Compute micro F1
189        let total_tp: usize = class_metrics
190            .values()
191            .map(|m| (m.recall * m.support as f64) as usize)
192            .sum();
193        let total_support: usize = class_metrics.values().map(|m| m.support).sum();
194        let micro_f1 = if total_support > 0 {
195            total_tp as f64 / total_support as f64
196        } else {
197            0.0
198        };
199
200        self.macro_f1 = macro_f1;
201        self.micro_f1 = micro_f1;
202        self.per_class_metrics = class_metrics;
203
204        Ok((macro_f1, micro_f1))
205    }
206}
207
208/// Node embedding quality metrics
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct NodeEmbeddingMetrics {
211    /// Silhouette score for embeddings
212    pub silhouette_score: f64,
213    /// Intra-cluster cohesion
214    pub intra_cluster_cohesion: f64,
215    /// Inter-cluster separation
216    pub inter_cluster_separation: f64,
217    /// Embedding alignment with graph structure
218    pub structure_alignment: f64,
219    /// Neighborhood preservation score
220    pub neighborhood_preservation: f64,
221}
222
223impl Default for NodeEmbeddingMetrics {
224    fn default() -> Self {
225        Self {
226            silhouette_score: 0.0,
227            intra_cluster_cohesion: 0.0,
228            inter_cluster_separation: 0.0,
229            structure_alignment: 0.0,
230            neighborhood_preservation: 0.0,
231        }
232    }
233}
234
235impl NodeEmbeddingMetrics {
236    pub fn new() -> Self {
237        Self::default()
238    }
239
240    /// Compute embedding quality metrics
241    pub fn compute_embedding_quality<F: Float>(
242        &mut self,
243        embeddings: &ArrayView2<F>,
244        adjacency_matrix: &ArrayView2<F>,
245        node_labels: Option<&ArrayView1<F>>,
246    ) -> Result<()>
247    where
248        F: std::iter::Sum + std::fmt::Debug,
249    {
250        let n_nodes = embeddings.nrows();
251        if adjacency_matrix.nrows() != n_nodes || adjacency_matrix.ncols() != n_nodes {
252            return Err(MetricsError::DimensionMismatch(
253                "Embeddings and adjacency matrix dimensions must match".to_string(),
254            ));
255        }
256
257        // Compute neighborhood preservation
258        self.neighborhood_preservation =
259            self.compute_neighborhood_preservation(embeddings, adjacency_matrix)?;
260
261        // Compute structure alignment
262        self.structure_alignment =
263            self.compute_structure_alignment(embeddings, adjacency_matrix)?;
264
265        // If labels are provided, compute cluster-based metrics
266        if let Some(labels) = node_labels {
267            self.silhouette_score = self.compute_silhouette_score(embeddings, labels)?;
268        }
269
270        Ok(())
271    }
272
273    fn compute_neighborhood_preservation<F: Float>(
274        &self,
275        embeddings: &ArrayView2<F>,
276        adjacency_matrix: &ArrayView2<F>,
277    ) -> Result<f64>
278    where
279        F: std::iter::Sum + std::fmt::Debug,
280    {
281        let n_nodes = embeddings.nrows();
282        let mut preservation_sum = 0.0;
283
284        for i in 0..n_nodes {
285            let neighbors: Vec<usize> = (0..n_nodes)
286                .filter(|&j| i != j && adjacency_matrix[(i, j)] > F::zero())
287                .collect();
288
289            if neighbors.is_empty() {
290                continue;
291            }
292
293            // Compute distances to all other nodes in embedding space
294            let mut distances: Vec<(usize, f64)> = (0..n_nodes)
295                .filter(|&j| i != j)
296                .map(|j| {
297                    let dist = (0..embeddings.ncols())
298                        .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
299                        .sum::<F>()
300                        .sqrt()
301                        .to_f64()
302                        .unwrap_or(0.0);
303                    (j, dist)
304                })
305                .collect();
306
307            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
308
309            // Count how many graph neighbors are among k-nearest neighbors in embedding space
310            let k = neighbors.len().min(10); // Use top-k neighbors
311            let top_k_nodes: HashSet<usize> =
312                distances.iter().take(k).map(|(idx, _)| *idx).collect();
313            let preserved = neighbors
314                .iter()
315                .filter(|&&n| top_k_nodes.contains(&n))
316                .count();
317
318            preservation_sum += preserved as f64 / neighbors.len() as f64;
319        }
320
321        Ok(preservation_sum / n_nodes as f64)
322    }
323
324    fn compute_structure_alignment<F: Float>(
325        &self,
326        embeddings: &ArrayView2<F>,
327        adjacency_matrix: &ArrayView2<F>,
328    ) -> Result<f64>
329    where
330        F: std::iter::Sum + std::fmt::Debug,
331    {
332        let n_nodes = embeddings.nrows();
333        let mut alignment_sum = 0.0;
334        let mut pair_count = 0;
335
336        for i in 0..n_nodes {
337            for j in (i + 1)..n_nodes {
338                let graph_connected = adjacency_matrix[(i, j)] > F::zero();
339
340                // Compute embedding distance
341                let emb_distance = (0..embeddings.ncols())
342                    .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
343                    .sum::<F>()
344                    .sqrt()
345                    .to_f64()
346                    .unwrap_or(0.0);
347
348                // Connected nodes should be closer in embedding space
349                if graph_connected {
350                    alignment_sum += (-emb_distance).exp(); // Higher for smaller distances
351                } else {
352                    alignment_sum += emb_distance.min(1.0); // Higher for larger distances
353                }
354                pair_count += 1;
355            }
356        }
357
358        Ok(alignment_sum / pair_count as f64)
359    }
360
361    fn compute_silhouette_score<F: Float>(
362        &self,
363        embeddings: &ArrayView2<F>,
364        labels: &ArrayView1<F>,
365    ) -> Result<f64>
366    where
367        F: std::iter::Sum + std::fmt::Debug,
368    {
369        let n_nodes = embeddings.nrows();
370        if n_nodes != labels.len() {
371            return Err(MetricsError::DimensionMismatch(
372                "Embeddings and labels dimensions must match".to_string(),
373            ));
374        }
375
376        let mut silhouette_sum = 0.0;
377
378        for i in 0..n_nodes {
379            let node_label = labels[i];
380
381            // Compute average distance to nodes in same cluster
382            let same_cluster_distances: Vec<f64> = (0..n_nodes)
383                .filter(|&j| {
384                    i != j
385                        && (labels[j] - node_label).abs()
386                            < F::from(0.1).expect("Failed to convert constant to float")
387                })
388                .map(|j| {
389                    (0..embeddings.ncols())
390                        .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
391                        .sum::<F>()
392                        .sqrt()
393                        .to_f64()
394                        .unwrap_or(0.0)
395                })
396                .collect();
397
398            let a = if same_cluster_distances.is_empty() {
399                0.0
400            } else {
401                same_cluster_distances.iter().sum::<f64>() / same_cluster_distances.len() as f64
402            };
403
404            // Compute minimum average distance to nodes in other clusters
405            let mut other_labels: Vec<F> = (0..n_nodes)
406                .map(|j| labels[j])
407                .filter(|&label| {
408                    (label - node_label).abs()
409                        >= F::from(0.1).expect("Failed to convert constant to float")
410                })
411                .collect();
412            other_labels.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
413            other_labels.dedup_by(|a, b| {
414                (*a - *b).abs() < F::from(0.1).expect("Failed to convert constant to float")
415            });
416
417            let b = other_labels
418                .iter()
419                .map(|&other_label| {
420                    let distances: Vec<f64> = (0..n_nodes)
421                        .filter(|&j| {
422                            (labels[j] - other_label).abs()
423                                < F::from(0.1).expect("Failed to convert constant to float")
424                        })
425                        .map(|j| {
426                            (0..embeddings.ncols())
427                                .map(|k| (embeddings[(i, k)] - embeddings[(j, k)]).powi(2))
428                                .sum::<F>()
429                                .sqrt()
430                                .to_f64()
431                                .unwrap_or(0.0)
432                        })
433                        .collect();
434
435                    if distances.is_empty() {
436                        f64::INFINITY
437                    } else {
438                        distances.iter().sum::<f64>() / distances.len() as f64
439                    }
440                })
441                .fold(f64::INFINITY, f64::min);
442
443            let silhouette = if a.max(b) > 0.0 {
444                (b - a) / a.max(b)
445            } else {
446                0.0
447            };
448
449            silhouette_sum += silhouette;
450        }
451
452        Ok(silhouette_sum / n_nodes as f64)
453    }
454}
455
456/// Homophily-aware evaluation metrics
457#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct HomophilyAwareMetrics {
459    /// Homophily ratio of the graph
460    pub homophily_ratio: f64,
461    /// Performance on homophilic edges
462    pub homophilic_performance: f64,
463    /// Performance on heterophilic edges
464    pub heterophilic_performance: f64,
465    /// Difference in performance
466    pub performance_gap: f64,
467    /// Local homophily scores
468    pub local_homophily: HashMap<usize, f64>, // node_id -> local homophily
469}
470
471impl Default for HomophilyAwareMetrics {
472    fn default() -> Self {
473        Self {
474            homophily_ratio: 0.0,
475            homophilic_performance: 0.0,
476            heterophilic_performance: 0.0,
477            performance_gap: 0.0,
478            local_homophily: HashMap::new(),
479        }
480    }
481}
482
483impl HomophilyAwareMetrics {
484    pub fn new() -> Self {
485        Self::default()
486    }
487
488    /// Compute homophily-aware metrics
489    pub fn compute_homophily_metrics<F: Float>(
490        &mut self,
491        predictions: &ArrayView1<F>,
492        ground_truth: &ArrayView1<F>,
493        adjacency_matrix: &ArrayView2<F>,
494    ) -> Result<()>
495    where
496        F: std::iter::Sum + std::fmt::Debug,
497    {
498        let n_nodes = predictions.len();
499        if n_nodes != ground_truth.len() || adjacency_matrix.nrows() != n_nodes {
500            return Err(MetricsError::DimensionMismatch(
501                "All inputs must have matching dimensions".to_string(),
502            ));
503        }
504
505        // Compute global homophily ratio
506        self.homophily_ratio = self.compute_global_homophily(ground_truth, adjacency_matrix)?;
507
508        // Compute local homophily for each node
509        self.local_homophily = self.compute_local_homophily(ground_truth, adjacency_matrix)?;
510
511        // Compute performance on homophilic vs heterophilic edges
512        self.compute_edge_performance(predictions, ground_truth, adjacency_matrix)?;
513
514        Ok(())
515    }
516
517    fn compute_global_homophily<F: Float>(
518        &self,
519        labels: &ArrayView1<F>,
520        adjacency_matrix: &ArrayView2<F>,
521    ) -> Result<f64>
522    where
523        F: std::iter::Sum + std::fmt::Debug,
524    {
525        let n_nodes = labels.len();
526        let mut homophilic_edges = 0;
527        let mut total_edges = 0;
528
529        for i in 0..n_nodes {
530            for j in (i + 1)..n_nodes {
531                if adjacency_matrix[(i, j)] > F::zero() {
532                    total_edges += 1;
533                    if (labels[i] - labels[j]).abs()
534                        < F::from(0.1).expect("Failed to convert constant to float")
535                    {
536                        homophilic_edges += 1;
537                    }
538                }
539            }
540        }
541
542        Ok(if total_edges > 0 {
543            homophilic_edges as f64 / total_edges as f64
544        } else {
545            0.0
546        })
547    }
548
549    fn compute_local_homophily<F: Float>(
550        &self,
551        labels: &ArrayView1<F>,
552        adjacency_matrix: &ArrayView2<F>,
553    ) -> Result<HashMap<usize, f64>>
554    where
555        F: std::iter::Sum + std::fmt::Debug,
556    {
557        let n_nodes = labels.len();
558        let mut local_homophily = HashMap::new();
559
560        for i in 0..n_nodes {
561            let neighbors: Vec<usize> = (0..n_nodes)
562                .filter(|&j| i != j && adjacency_matrix[(i, j)] > F::zero())
563                .collect();
564
565            if neighbors.is_empty() {
566                local_homophily.insert(i, 0.0);
567                continue;
568            }
569
570            let same_label_neighbors = neighbors
571                .iter()
572                .filter(|&&j| {
573                    (labels[i] - labels[j]).abs()
574                        < F::from(0.1).expect("Failed to convert constant to float")
575                })
576                .count();
577
578            let homophily = same_label_neighbors as f64 / neighbors.len() as f64;
579            local_homophily.insert(i, homophily);
580        }
581
582        Ok(local_homophily)
583    }
584
585    fn compute_edge_performance<F: Float>(
586        &mut self,
587        predictions: &ArrayView1<F>,
588        ground_truth: &ArrayView1<F>,
589        adjacency_matrix: &ArrayView2<F>,
590    ) -> Result<()>
591    where
592        F: std::iter::Sum + std::fmt::Debug,
593    {
594        let n_nodes = predictions.len();
595        let mut homophilic_correct = 0;
596        let mut homophilic_total = 0;
597        let mut heterophilic_correct = 0;
598        let mut heterophilic_total = 0;
599
600        for i in 0..n_nodes {
601            for j in (i + 1)..n_nodes {
602                if adjacency_matrix[(i, j)] > F::zero() {
603                    let same_true_label = (ground_truth[i] - ground_truth[j]).abs()
604                        < F::from(0.1).expect("Failed to convert constant to float");
605                    let correct_i = (predictions[i] - ground_truth[i]).abs()
606                        < F::from(0.5).expect("Failed to convert constant to float");
607                    let correct_j = (predictions[j] - ground_truth[j]).abs()
608                        < F::from(0.5).expect("Failed to convert constant to float");
609                    let edge_correct = correct_i && correct_j;
610
611                    if same_true_label {
612                        homophilic_total += 1;
613                        if edge_correct {
614                            homophilic_correct += 1;
615                        }
616                    } else {
617                        heterophilic_total += 1;
618                        if edge_correct {
619                            heterophilic_correct += 1;
620                        }
621                    }
622                }
623            }
624        }
625
626        self.homophilic_performance = if homophilic_total > 0 {
627            homophilic_correct as f64 / homophilic_total as f64
628        } else {
629            0.0
630        };
631
632        self.heterophilic_performance = if heterophilic_total > 0 {
633            heterophilic_correct as f64 / heterophilic_total as f64
634        } else {
635            0.0
636        };
637
638        self.performance_gap = (self.homophilic_performance - self.heterophilic_performance).abs();
639
640        Ok(())
641    }
642}
643
644/// Fairness metrics for node-level predictions
645#[derive(Debug, Clone, Serialize, Deserialize)]
646pub struct NodeFairnessMetrics {
647    /// Demographic parity difference
648    pub demographic_parity: f64,
649    /// Equalized odds difference
650    pub equalized_odds: f64,
651    /// Individual fairness score
652    pub individual_fairness: f64,
653    /// Group fairness metrics
654    pub group_fairness: HashMap<String, GroupFairnessMetrics>,
655}
656
657impl Default for NodeFairnessMetrics {
658    fn default() -> Self {
659        Self {
660            demographic_parity: 0.0,
661            equalized_odds: 0.0,
662            individual_fairness: 0.0,
663            group_fairness: HashMap::new(),
664        }
665    }
666}
667
668impl NodeFairnessMetrics {
669    pub fn new() -> Self {
670        Self::default()
671    }
672
673    /// Compute fairness metrics
674    pub fn compute_fairness_metrics<F: Float>(
675        &mut self,
676        predictions: &ArrayView1<F>,
677        ground_truth: &ArrayView1<F>,
678        sensitive_attributes: &HashMap<usize, String>,
679    ) -> Result<()>
680    where
681        F: std::iter::Sum + std::fmt::Debug,
682    {
683        // Compute group-specific metrics
684        let mut group_metrics = HashMap::new();
685        let groups: std::collections::HashSet<_> = sensitive_attributes.values().cloned().collect();
686
687        for group in groups {
688            let group_indices: Vec<usize> = sensitive_attributes
689                .iter()
690                .filter(|(_, g)| *g == &group)
691                .map(|(idx, _)| *idx)
692                .collect();
693
694            if group_indices.is_empty() {
695                continue;
696            }
697
698            let mut tp = 0;
699            let mut fp = 0;
700            let mut tn = 0;
701            let mut fn_count = 0;
702
703            for &idx in &group_indices {
704                let pred = predictions[idx].to_f64().unwrap_or(0.0) > 0.5;
705                let truth = ground_truth[idx].to_f64().unwrap_or(0.0) > 0.5;
706
707                match (pred, truth) {
708                    (true, true) => tp += 1,
709                    (true, false) => fp += 1,
710                    (false, false) => tn += 1,
711                    (false, true) => fn_count += 1,
712                }
713            }
714
715            let tpr = if tp + fn_count > 0 {
716                tp as f64 / (tp + fn_count) as f64
717            } else {
718                0.0
719            };
720            let fpr = if fp + tn > 0 {
721                fp as f64 / (fp + tn) as f64
722            } else {
723                0.0
724            };
725            let precision = if tp + fp > 0 {
726                tp as f64 / (tp + fp) as f64
727            } else {
728                0.0
729            };
730            let selection_rate = if group_indices.len() > 0 {
731                group_indices
732                    .iter()
733                    .map(|&idx| {
734                        if predictions[idx].to_f64().unwrap_or(0.0) > 0.5 {
735                            1.0
736                        } else {
737                            0.0
738                        }
739                    })
740                    .sum::<f64>()
741                    / group_indices.len() as f64
742            } else {
743                0.0
744            };
745
746            group_metrics.insert(
747                group.clone(),
748                GroupFairnessMetrics {
749                    tpr,
750                    fpr,
751                    precision,
752                    selection_rate,
753                },
754            );
755        }
756
757        // Compute overall fairness metrics
758        if group_metrics.len() >= 2 {
759            let group_names: Vec<_> = group_metrics.keys().collect();
760            let group1 = group_metrics.get(group_names[0]).expect("Operation failed");
761            let group2 = group_metrics.get(group_names[1]).expect("Operation failed");
762
763            self.demographic_parity = (group1.selection_rate - group2.selection_rate).abs();
764            self.equalized_odds =
765                ((group1.tpr - group2.tpr).abs() + (group1.fpr - group2.fpr).abs()) / 2.0;
766        }
767
768        self.group_fairness = group_metrics;
769        Ok(())
770    }
771}
772
773impl NodeLevelMetrics {
774    /// Create new node-level metrics
775    pub fn new() -> Self {
776        Self {
777            classification_metrics: NodeClassificationMetrics::new(),
778            embedding_metrics: NodeEmbeddingMetrics::new(),
779            homophily_metrics: HomophilyAwareMetrics::new(),
780            fairness_metrics: NodeFairnessMetrics::new(),
781        }
782    }
783}
784
785impl Default for NodeLevelMetrics {
786    fn default() -> Self {
787        Self::new()
788    }
789}
790
791use std::collections::HashSet;