// scirs2_graph/hypergraph/edge_prediction.rs
1//! Hyperedge Prediction module.
2//!
3//! Implements scoring and evaluation utilities for the hyperedge prediction task:
4//! given a set of nodes, predict the probability that they form a hyperedge.
5//!
6//! ## Approach
7//!
8//! 1. **Pooling**: aggregate node features of candidate hyperedge members
9//! 2. **Scoring MLP**: map pooled features through a small MLP to a scalar score
10//! 3. **Sigmoid**: convert score to probability
11//!
12//! ## Evaluation
13//!
14//! - ROC-AUC computed using the trapezoidal rule
15//! - Negative sampling: random k-subsets of the node set
16
17use crate::error::{GraphError, Result};
18use scirs2_core::ndarray::Array2;
19use scirs2_core::random::{Rng, RngExt, SeedableRng};
20
21// ============================================================================
22// Linear layer helper (same pattern as other modules)
23// ============================================================================
24
/// Minimal dense (fully-connected) layer: `y = W·x + b`.
#[derive(Debug, Clone)]
struct Linear {
    // Row-major weights: `weight[i][j]` maps input feature j to output i
    // ([out_dim × in_dim]).
    weight: Vec<Vec<f64>>,
    // Per-output bias ([out_dim]); also fixes the output length in `forward`.
    bias: Vec<f64>,
    // Output dimension recorded at construction; not read by `forward`
    // (which sizes its output from `bias`) — kept for reference.
    out_dim: usize,
}
31
32impl Linear {
33    fn new(in_dim: usize, out_dim: usize) -> Self {
34        let scale = (2.0 / in_dim as f64).sqrt();
35        let mut rng = scirs2_core::random::rng();
36        let weight: Vec<Vec<f64>> = (0..out_dim)
37            .map(|_| {
38                (0..in_dim)
39                    .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
40                    .collect()
41            })
42            .collect();
43        Linear {
44            weight,
45            bias: vec![0.0; out_dim],
46            out_dim,
47        }
48    }
49
50    fn forward(&self, x: &[f64]) -> Vec<f64> {
51        let mut out = self.bias.clone();
52        for (i, row) in self.weight.iter().enumerate() {
53            for (j, &w) in row.iter().enumerate() {
54                out[i] += w * x[j];
55            }
56        }
57        out
58    }
59}
60
61// ============================================================================
62// Pooling Types
63// ============================================================================
64
/// Pooling method for aggregating node features within a hyperedge candidate.
///
/// All variants map a set of per-node feature rows to a single vector of the
/// same dimensionality; `Mean` is the default.
#[derive(Debug, Clone, PartialEq, Default)]
#[non_exhaustive]
pub enum PoolingType {
    /// Sum of node feature vectors.
    Sum,
    /// Element-wise mean.
    #[default]
    Mean,
    /// Element-wise maximum.
    Max,
}
77
78impl PoolingType {
79    /// Aggregate node features using this pooling method.
80    fn pool(&self, node_feats: &Array2<f64>, nodes: &[usize]) -> Vec<f64> {
81        if nodes.is_empty() {
82            return vec![0.0; node_feats.ncols()];
83        }
84        let d = node_feats.ncols();
85        match self {
86            PoolingType::Sum => {
87                let mut out = vec![0.0_f64; d];
88                for &i in nodes {
89                    for k in 0..d {
90                        out[k] += node_feats[[i, k]];
91                    }
92                }
93                out
94            }
95            PoolingType::Mean => {
96                let mut out = vec![0.0_f64; d];
97                let inv_n = 1.0 / nodes.len() as f64;
98                for &i in nodes {
99                    for k in 0..d {
100                        out[k] += node_feats[[i, k]] * inv_n;
101                    }
102                }
103                out
104            }
105            PoolingType::Max => {
106                let mut out = vec![f64::NEG_INFINITY; d];
107                for &i in nodes {
108                    for k in 0..d {
109                        if node_feats[[i, k]] > out[k] {
110                            out[k] = node_feats[[i, k]];
111                        }
112                    }
113                }
114                // Replace -inf with 0 for nodes with no features
115                for v in out.iter_mut() {
116                    if *v == f64::NEG_INFINITY {
117                        *v = 0.0;
118                    }
119                }
120                out
121            }
122        }
123    }
124}
125
126// ============================================================================
127// Configuration
128// ============================================================================
129
/// Configuration for the HyperedgePredictor.
///
/// See [`HyperedgePredictorConfig::default`] for the default values
/// (hidden_dim = 64, mean pooling, 2 hidden layers).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HyperedgePredictorConfig {
    /// Hidden dimension of the scoring MLP.
    pub hidden_dim: usize,
    /// Pooling method for aggregating node features.
    pub pooling: PoolingType,
    /// Number of hidden layers in the scoring MLP (not counting the output layer).
    /// Values 0 and 1 both produce a single hidden layer in the current builder.
    pub n_hidden_layers: usize,
}
141
142impl Default for HyperedgePredictorConfig {
143    fn default() -> Self {
144        HyperedgePredictorConfig {
145            hidden_dim: 64,
146            pooling: PoolingType::Mean,
147            n_hidden_layers: 2,
148        }
149    }
150}
151
152// ============================================================================
153// HyperedgePredictor
154// ============================================================================
155
/// Hyperedge predictor: scores candidate node sets as potential hyperedges.
///
/// Architecture:
/// ```text
/// pool(node_feats[candidate]) → MLP → sigmoid → probability
/// ```
#[derive(Debug, Clone)]
pub struct HyperedgePredictor {
    /// MLP layers: first maps `in_dim → hidden_dim`, the last maps
    /// `hidden_dim → 1` (scalar logit).
    layers: Vec<Linear>,
    /// Input feature dimension (per-node); validated against `node_feats.ncols()`
    /// in `score`.
    in_dim: usize,
    /// Configuration (pooling type is read on every `score` call).
    config: HyperedgePredictorConfig,
}
171
172impl HyperedgePredictor {
173    /// Create a new HyperedgePredictor.
174    ///
175    /// # Arguments
176    /// - `in_dim`: node feature dimension
177    /// - `config`: predictor configuration
178    pub fn new(in_dim: usize, config: HyperedgePredictorConfig) -> Self {
179        let h = config.hidden_dim;
180        let mut layers = Vec::new();
181        // Input layer: in_dim → hidden
182        layers.push(Linear::new(in_dim, h));
183        // Hidden layers
184        for _ in 1..config.n_hidden_layers {
185            layers.push(Linear::new(h, h));
186        }
187        // Output layer: hidden → 1
188        layers.push(Linear::new(h, 1));
189        HyperedgePredictor {
190            layers,
191            in_dim,
192            config,
193        }
194    }
195
196    /// Score a single candidate hyperedge (set of node indices).
197    ///
198    /// # Arguments
199    /// - `node_feats`: all node features [N × in_dim]
200    /// - `candidate`: indices of nodes in the candidate hyperedge
201    ///
202    /// # Returns
203    /// Probability in [0, 1] that the candidate is a real hyperedge.
204    pub fn score(&self, node_feats: &Array2<f64>, candidate: &[usize]) -> Result<f64> {
205        if candidate.is_empty() {
206            return Err(GraphError::InvalidParameter {
207                param: "candidate".to_string(),
208                value: "empty".to_string(),
209                expected: "non-empty set of node indices".to_string(),
210                context: "HyperedgePredictor::score".to_string(),
211            });
212        }
213        if node_feats.ncols() != self.in_dim {
214            return Err(GraphError::InvalidParameter {
215                param: "node_feats".to_string(),
216                value: format!("ncols={}", node_feats.ncols()),
217                expected: format!("ncols={}", self.in_dim),
218                context: "HyperedgePredictor::score".to_string(),
219            });
220        }
221        for &i in candidate {
222            if i >= node_feats.nrows() {
223                return Err(GraphError::InvalidParameter {
224                    param: "candidate".to_string(),
225                    value: format!("node {i}"),
226                    expected: format!("< {}", node_feats.nrows()),
227                    context: "HyperedgePredictor::score".to_string(),
228                });
229            }
230        }
231
232        // Pool node features
233        let pooled = self.config.pooling.pool(node_feats, candidate);
234
235        // MLP forward pass
236        let mut h = pooled;
237        for (i, layer) in self.layers.iter().enumerate() {
238            h = layer.forward(&h);
239            if i < self.layers.len() - 1 {
240                // SiLU activation
241                for v in h.iter_mut() {
242                    *v = *v / (1.0 + (-*v).exp());
243                }
244            }
245        }
246
247        // Sigmoid output
248        let logit = h[0];
249        let prob = 1.0 / (1.0 + (-logit).exp());
250        Ok(prob)
251    }
252
253    /// Score a batch of candidate hyperedges.
254    ///
255    /// # Arguments
256    /// - `node_feats`: all node features [N × in_dim]
257    /// - `candidates`: list of candidate hyperedges (each is a list of node indices)
258    ///
259    /// # Returns
260    /// Vector of probabilities in [0, 1].
261    pub fn predict_batch(
262        &self,
263        node_feats: &Array2<f64>,
264        candidates: &[Vec<usize>],
265    ) -> Result<Vec<f64>> {
266        candidates
267            .iter()
268            .map(|c| self.score(node_feats, c))
269            .collect()
270    }
271}
272
273// ============================================================================
274// Negative sampling
275// ============================================================================
276
277/// Generate random negative hyperedge samples.
278///
279/// For each positive hyperedge, generates `n_neg_per_pos` random k-subsets of
280/// the node set, ensuring the generated set differs from the positive hyperedge.
281///
282/// # Arguments
283/// - `positives`: known positive hyperedges
284/// - `n_nodes`: total number of nodes
285/// - `n_neg_per_pos`: number of negatives per positive
286///
287/// # Returns
288/// List of negative hyperedge candidates.
289pub fn generate_negatives(
290    positives: &[Vec<usize>],
291    n_nodes: usize,
292    n_neg_per_pos: usize,
293) -> Vec<Vec<usize>> {
294    if positives.is_empty() || n_nodes == 0 {
295        return Vec::new();
296    }
297
298    let mut rng = scirs2_core::random::seeded_rng(42u64);
299    let mut negatives = Vec::new();
300
301    // Build a set of all positive hyperedges for fast lookup
302    use std::collections::HashSet;
303    let pos_set: HashSet<Vec<usize>> = positives
304        .iter()
305        .map(|p| {
306            let mut sorted = p.clone();
307            sorted.sort();
308            sorted
309        })
310        .collect();
311
312    for pos in positives {
313        let k = pos.len();
314        if k == 0 || k > n_nodes {
315            continue;
316        }
317
318        let mut generated = 0;
319        let mut attempts = 0;
320        while generated < n_neg_per_pos && attempts < 1000 {
321            attempts += 1;
322            // Sample k unique nodes without replacement
323            let mut candidate: Vec<usize> = (0..n_nodes).collect();
324            // Fisher-Yates partial shuffle for k elements
325            for i in 0..k {
326                let j = i + (rng.random::<f64>() * (n_nodes - i) as f64) as usize;
327                let j = j.min(n_nodes - 1);
328                candidate.swap(i, j);
329            }
330            let mut neg: Vec<usize> = candidate[..k].to_vec();
331            neg.sort();
332
333            // Check it's not a known positive
334            if !pos_set.contains(&neg) {
335                negatives.push(neg);
336                generated += 1;
337            }
338        }
339    }
340
341    negatives
342}
343
344// ============================================================================
345// ROC-AUC computation
346// ============================================================================
347
/// Compute the ROC-AUC (Area Under the ROC Curve) using the trapezoidal rule.
///
/// Tied scores are handled correctly: all samples sharing a score are
/// processed as one group before a ROC point is emitted, so the result does
/// not depend on the (arbitrary) sort order among equal scores. In
/// particular, a constant-score predictor yields exactly 0.5.
///
/// # Arguments
/// - `labels`: ground truth labels (true = positive, false = negative)
/// - `scores`: predicted scores / probabilities (higher = more likely positive)
///
/// # Returns
/// AUC value in [0, 1]. Returns 0.5 for degenerate inputs (empty input, or
/// only one class present).
///
/// # Panics
/// Panics if `labels` and `scores` have different lengths.
pub fn roc_auc(labels: &[bool], scores: &[f64]) -> f64 {
    assert_eq!(
        labels.len(),
        scores.len(),
        "labels and scores must have equal length"
    );
    if labels.is_empty() {
        return 0.5;
    }

    let n_pos = labels.iter().filter(|&&l| l).count();
    let n_neg = labels.len() - n_pos;
    if n_pos == 0 || n_neg == 0 {
        return 0.5;
    }

    // Sort by score descending. total_cmp gives a well-defined total order
    // (NaN sorts consistently instead of silently comparing "equal").
    let mut indices: Vec<usize> = (0..labels.len()).collect();
    indices.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));

    // Walk the sorted samples, emitting one ROC point per *distinct* score.
    // The trapezoid over each tied block credits exactly half of the
    // positive/negative crossings within the tie — the standard correction.
    let mut auc = 0.0_f64;
    let mut tp = 0usize;
    let mut fp = 0usize;
    let mut prev_tpr = 0.0_f64;
    let mut prev_fpr = 0.0_f64;

    let mut i = 0;
    while i < indices.len() {
        // Consume the entire block of samples with this score.
        let threshold = scores[indices[i]];
        while i < indices.len() && scores[indices[i]] == threshold {
            if labels[indices[i]] {
                tp += 1;
            } else {
                fp += 1;
            }
            i += 1;
        }
        let tpr = tp as f64 / n_pos as f64;
        let fpr = fp as f64 / n_neg as f64;
        // Trapezoidal rule between the previous and current ROC points.
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_tpr = tpr;
        prev_fpr = fpr;
    }

    auc.clamp(0.0, 1.0)
}
408
409// ============================================================================
410// Tests
411// ============================================================================
412
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic n×d feature matrix: entry (i·d + j) holds (i·d + j + 1)·0.1.
    fn make_feats(n: usize, d: usize) -> Array2<f64> {
        let data: Vec<f64> = (0..n * d).map(|i| (i as f64 + 1.0) * 0.1).collect();
        Array2::from_shape_vec((n, d), data).expect("feats")
    }

    #[test]
    fn test_predictor_score_in_unit_interval() {
        let config = HyperedgePredictorConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let predictor = HyperedgePredictor::new(4, config);
        let feats = make_feats(5, 4);
        let candidate = vec![0, 1, 2];
        let score = predictor.score(&feats, &candidate).expect("score");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0,1], got {score}"
        );
    }

    #[test]
    fn test_predictor_batch_all_in_unit_interval() {
        let config = HyperedgePredictorConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        let predictor = HyperedgePredictor::new(4, config);
        let feats = make_feats(6, 4);
        let candidates = vec![vec![0, 1], vec![1, 2, 3], vec![3, 4, 5], vec![0, 2, 4]];
        let scores = predictor.predict_batch(&feats, &candidates).expect("batch");
        for s in &scores {
            assert!(*s >= 0.0 && *s <= 1.0, "score {s} not in [0,1]");
        }
        assert_eq!(scores.len(), 4);
    }

    #[test]
    fn test_generate_negatives_differ_from_positives() {
        let positives = vec![vec![0, 1, 2], vec![3, 4, 5]];
        let negatives = generate_negatives(&positives, 8, 3);
        // None of the generated negatives may equal a positive (as a set).
        use std::collections::HashSet;
        let pos_set: HashSet<Vec<usize>> = positives.iter().cloned().collect();
        for neg in &negatives {
            let mut sorted = neg.clone();
            sorted.sort();
            assert!(
                !pos_set.contains(&sorted),
                "negative {:?} should not match a positive",
                neg
            );
        }
    }

    #[test]
    fn test_generate_negatives_count() {
        let positives = vec![vec![0, 1, 2], vec![3, 4, 5]];
        let negatives = generate_negatives(&positives, 20, 5);
        // 2 positives × 5 negatives each = at most 10 negatives
        // (may be fewer if sampling repeatedly hits positives).
        assert!(negatives.len() <= 10, "too many negatives generated");
        assert!(!negatives.is_empty(), "some negatives should be generated");
    }

    #[test]
    fn test_roc_auc_perfect() {
        // Perfect predictor: positive scores all higher than negative scores.
        let labels = vec![true, true, true, false, false, false];
        let scores = vec![0.9, 0.8, 0.7, 0.3, 0.2, 0.1];
        let auc = roc_auc(&labels, &scores);
        assert!(
            (auc - 1.0).abs() < 1e-10,
            "perfect AUC should be 1.0, got {auc}"
        );
    }

    #[test]
    fn test_roc_auc_worst() {
        // Worst predictor: all negative scores higher than positive scores.
        let labels = vec![true, true, true, false, false, false];
        let scores = vec![0.1, 0.2, 0.3, 0.7, 0.8, 0.9];
        let auc = roc_auc(&labels, &scores);
        assert!(auc < 0.1, "worst AUC should be ~0.0, got {auc}");
    }

    #[test]
    fn test_roc_auc_random_approx_half() {
        // Uninformative predictor: constant scores carry no information.
        let labels = vec![
            true, false, true, false, true, false, true, false, true, false,
        ];
        let scores = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let auc = roc_auc(&labels, &scores);
        // With all equal scores, AUC depends on tie-breaking → ≈ 0.5.
        // Kept loose so it passes both with and without tie correction.
        assert!(
            (0.0..=1.0).contains(&auc),
            "AUC must be in [0,1], got {auc}"
        );
    }

    #[test]
    fn test_pooling_mean() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Mean.pool(&feats, &[0, 1, 2]);
        assert_eq!(pooled.len(), 3);
        // Mean of rows 0,1,2 for column 0: (0.1+0.4+0.7)/3
        let expected_col0 = (feats[[0, 0]] + feats[[1, 0]] + feats[[2, 0]]) / 3.0;
        assert!((pooled[0] - expected_col0).abs() < 1e-12);
    }

    #[test]
    fn test_pooling_sum() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Sum.pool(&feats, &[0, 1]);
        let expected = feats[[0, 0]] + feats[[1, 0]];
        assert!((pooled[0] - expected).abs() < 1e-12);
    }

    #[test]
    fn test_pooling_max() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Max.pool(&feats, &[0, 1, 2]);
        // Max of rows 0,1,2 for column 0: max(0.1, 0.4, 0.7) = 0.7
        assert!((pooled[0] - feats[[2, 0]]).abs() < 1e-12);
    }

    #[test]
    fn test_predictor_empty_candidate_error() {
        let config = HyperedgePredictorConfig::default();
        let predictor = HyperedgePredictor::new(4, config);
        let feats = make_feats(5, 4);
        let result = predictor.score(&feats, &[]);
        assert!(result.is_err(), "empty candidate should return error");
    }
}
553}