1use crate::error::{GraphError, Result};
18use scirs2_core::ndarray::Array2;
19use scirs2_core::random::{Rng, RngExt, SeedableRng};
20
21#[derive(Debug, Clone)]
26struct Linear {
27 weight: Vec<Vec<f64>>,
28 bias: Vec<f64>,
29 out_dim: usize,
30}
31
32impl Linear {
33 fn new(in_dim: usize, out_dim: usize) -> Self {
34 let scale = (2.0 / in_dim as f64).sqrt();
35 let mut rng = scirs2_core::random::rng();
36 let weight: Vec<Vec<f64>> = (0..out_dim)
37 .map(|_| {
38 (0..in_dim)
39 .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
40 .collect()
41 })
42 .collect();
43 Linear {
44 weight,
45 bias: vec![0.0; out_dim],
46 out_dim,
47 }
48 }
49
50 fn forward(&self, x: &[f64]) -> Vec<f64> {
51 let mut out = self.bias.clone();
52 for (i, row) in self.weight.iter().enumerate() {
53 for (j, &w) in row.iter().enumerate() {
54 out[i] += w * x[j];
55 }
56 }
57 out
58 }
59}
60
/// Strategy for aggregating the feature rows of a candidate hyperedge's member
/// nodes into a single fixed-length vector before scoring.
#[derive(Debug, Clone, PartialEq, Default)]
#[non_exhaustive]
pub enum PoolingType {
    /// Element-wise sum of the member feature rows.
    Sum,
    /// Element-wise arithmetic mean of the member feature rows (the default).
    #[default]
    Mean,
    /// Element-wise maximum of the member feature rows.
    Max,
}
77
78impl PoolingType {
79 fn pool(&self, node_feats: &Array2<f64>, nodes: &[usize]) -> Vec<f64> {
81 if nodes.is_empty() {
82 return vec![0.0; node_feats.ncols()];
83 }
84 let d = node_feats.ncols();
85 match self {
86 PoolingType::Sum => {
87 let mut out = vec![0.0_f64; d];
88 for &i in nodes {
89 for k in 0..d {
90 out[k] += node_feats[[i, k]];
91 }
92 }
93 out
94 }
95 PoolingType::Mean => {
96 let mut out = vec![0.0_f64; d];
97 let inv_n = 1.0 / nodes.len() as f64;
98 for &i in nodes {
99 for k in 0..d {
100 out[k] += node_feats[[i, k]] * inv_n;
101 }
102 }
103 out
104 }
105 PoolingType::Max => {
106 let mut out = vec![f64::NEG_INFINITY; d];
107 for &i in nodes {
108 for k in 0..d {
109 if node_feats[[i, k]] > out[k] {
110 out[k] = node_feats[[i, k]];
111 }
112 }
113 }
114 for v in out.iter_mut() {
116 if *v == f64::NEG_INFINITY {
117 *v = 0.0;
118 }
119 }
120 out
121 }
122 }
123 }
124}
125
/// Hyper-parameters for [`HyperedgePredictor`].
///
/// Marked `#[non_exhaustive]`: outside this crate, start from
/// `HyperedgePredictorConfig::default()` and assign the public fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HyperedgePredictorConfig {
    /// Width of every hidden layer of the scoring MLP.
    pub hidden_dim: usize,
    /// How member-node features are aggregated before the MLP is applied.
    pub pooling: PoolingType,
    /// Number of `hidden_dim`-wide hidden layers. The predictor always builds the first
    /// `in_dim -> hidden_dim` layer, so `0` behaves the same as `1`.
    pub n_hidden_layers: usize,
}
141
142impl Default for HyperedgePredictorConfig {
143 fn default() -> Self {
144 HyperedgePredictorConfig {
145 hidden_dim: 64,
146 pooling: PoolingType::Mean,
147 n_hidden_layers: 2,
148 }
149 }
150}
151
/// Scores candidate hyperedges (sets of node indices) with a small MLP applied to the
/// pooled features of the member nodes, returning a probability via a final sigmoid.
#[derive(Debug, Clone)]
pub struct HyperedgePredictor {
    /// Layers applied in order. Every layer except the last is followed by a SiLU
    /// activation; the last one produces the single output logit.
    layers: Vec<Linear>,
    /// Number of feature columns `score` requires in `node_feats`.
    in_dim: usize,
    /// Configuration the predictor was built with; its `pooling` is used when scoring.
    config: HyperedgePredictorConfig,
}
171
172impl HyperedgePredictor {
173 pub fn new(in_dim: usize, config: HyperedgePredictorConfig) -> Self {
179 let h = config.hidden_dim;
180 let mut layers = Vec::new();
181 layers.push(Linear::new(in_dim, h));
183 for _ in 1..config.n_hidden_layers {
185 layers.push(Linear::new(h, h));
186 }
187 layers.push(Linear::new(h, 1));
189 HyperedgePredictor {
190 layers,
191 in_dim,
192 config,
193 }
194 }
195
196 pub fn score(&self, node_feats: &Array2<f64>, candidate: &[usize]) -> Result<f64> {
205 if candidate.is_empty() {
206 return Err(GraphError::InvalidParameter {
207 param: "candidate".to_string(),
208 value: "empty".to_string(),
209 expected: "non-empty set of node indices".to_string(),
210 context: "HyperedgePredictor::score".to_string(),
211 });
212 }
213 if node_feats.ncols() != self.in_dim {
214 return Err(GraphError::InvalidParameter {
215 param: "node_feats".to_string(),
216 value: format!("ncols={}", node_feats.ncols()),
217 expected: format!("ncols={}", self.in_dim),
218 context: "HyperedgePredictor::score".to_string(),
219 });
220 }
221 for &i in candidate {
222 if i >= node_feats.nrows() {
223 return Err(GraphError::InvalidParameter {
224 param: "candidate".to_string(),
225 value: format!("node {i}"),
226 expected: format!("< {}", node_feats.nrows()),
227 context: "HyperedgePredictor::score".to_string(),
228 });
229 }
230 }
231
232 let pooled = self.config.pooling.pool(node_feats, candidate);
234
235 let mut h = pooled;
237 for (i, layer) in self.layers.iter().enumerate() {
238 h = layer.forward(&h);
239 if i < self.layers.len() - 1 {
240 for v in h.iter_mut() {
242 *v = *v / (1.0 + (-*v).exp());
243 }
244 }
245 }
246
247 let logit = h[0];
249 let prob = 1.0 / (1.0 + (-logit).exp());
250 Ok(prob)
251 }
252
253 pub fn predict_batch(
262 &self,
263 node_feats: &Array2<f64>,
264 candidates: &[Vec<usize>],
265 ) -> Result<Vec<f64>> {
266 candidates
267 .iter()
268 .map(|c| self.score(node_feats, c))
269 .collect()
270 }
271}
272
273pub fn generate_negatives(
290 positives: &[Vec<usize>],
291 n_nodes: usize,
292 n_neg_per_pos: usize,
293) -> Vec<Vec<usize>> {
294 if positives.is_empty() || n_nodes == 0 {
295 return Vec::new();
296 }
297
298 let mut rng = scirs2_core::random::seeded_rng(42u64);
299 let mut negatives = Vec::new();
300
301 use std::collections::HashSet;
303 let pos_set: HashSet<Vec<usize>> = positives
304 .iter()
305 .map(|p| {
306 let mut sorted = p.clone();
307 sorted.sort();
308 sorted
309 })
310 .collect();
311
312 for pos in positives {
313 let k = pos.len();
314 if k == 0 || k > n_nodes {
315 continue;
316 }
317
318 let mut generated = 0;
319 let mut attempts = 0;
320 while generated < n_neg_per_pos && attempts < 1000 {
321 attempts += 1;
322 let mut candidate: Vec<usize> = (0..n_nodes).collect();
324 for i in 0..k {
326 let j = i + (rng.random::<f64>() * (n_nodes - i) as f64) as usize;
327 let j = j.min(n_nodes - 1);
328 candidate.swap(i, j);
329 }
330 let mut neg: Vec<usize> = candidate[..k].to_vec();
331 neg.sort();
332
333 if !pos_set.contains(&neg) {
335 negatives.push(neg);
336 generated += 1;
337 }
338 }
339 }
340
341 negatives
342}
343
/// Area under the ROC curve for binary `labels` and real-valued `scores`.
///
/// A higher score means "more likely positive". All samples that share a score are
/// consumed as one group, which adds a single diagonal segment to the curve. As a
/// result, a tied positive/negative pair contributes 1/2 (the Mann–Whitney convention),
/// and the result does not depend on the order of the inputs.
///
/// Scores are ordered with `f64::total_cmp`, so NaN never causes a panic. A positive
/// NaN ranks above `+inf`.
///
/// Returns `0.5` when the input is empty or contains only one class.
///
/// # Panics
///
/// Panics if `labels` and `scores` have different lengths.
pub fn roc_auc(labels: &[bool], scores: &[f64]) -> f64 {
    assert_eq!(
        labels.len(),
        scores.len(),
        "labels and scores must have equal length"
    );
    if labels.is_empty() {
        return 0.5;
    }

    let n_pos = labels.iter().filter(|&&l| l).count();
    let n_neg = labels.len() - n_pos;
    if n_pos == 0 || n_neg == 0 {
        return 0.5;
    }

    // Highest score first; `total_cmp` is a total order, as `sort_by` requires.
    let mut order: Vec<usize> = (0..labels.len()).collect();
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));

    let (mut tp, mut fp) = (0usize, 0usize);
    let (mut prev_tpr, mut prev_fpr) = (0.0_f64, 0.0_f64);
    let mut auc = 0.0_f64;

    let mut start = 0;
    while start < order.len() {
        // Consume the whole run of equal scores before emitting a curve point.
        let group_score = scores[order[start]];
        let mut end = start;
        while end < order.len() && scores[order[end]].total_cmp(&group_score).is_eq() {
            if labels[order[end]] {
                tp += 1;
            } else {
                fp += 1;
            }
            end += 1;
        }

        let tpr = tp as f64 / n_pos as f64;
        let fpr = fp as f64 / n_neg as f64;
        // Trapezoid between the previous point and this one.
        let avg_tpr = (tpr + prev_tpr) / 2.0;
        auc += (fpr - prev_fpr) * avg_tpr;

        prev_tpr = tpr;
        prev_fpr = fpr;
        start = end;
    }

    auc.clamp(0.0, 1.0)
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic `n x d` feature matrix holding 0.1, 0.2, ... in row-major order,
    /// so every column strictly increases down the rows.
    fn make_feats(n: usize, d: usize) -> Array2<f64> {
        let data: Vec<f64> = (0..n * d).map(|i| (i as f64 + 1.0) * 0.1).collect();
        Array2::from_shape_vec((n, d), data).expect("feats")
    }

    /// Predictor over 4 input features with 8-wide hidden layers.
    fn small_predictor() -> HyperedgePredictor {
        let config = HyperedgePredictorConfig {
            hidden_dim: 8,
            ..Default::default()
        };
        HyperedgePredictor::new(4, config)
    }

    #[test]
    fn test_predictor_score_in_unit_interval() {
        let predictor = small_predictor();
        let feats = make_feats(5, 4);
        let candidate = vec![0, 1, 2];
        let score = predictor.score(&feats, &candidate).expect("score");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0,1], got {score}"
        );
    }

    #[test]
    fn test_predictor_batch_all_in_unit_interval() {
        let predictor = small_predictor();
        let feats = make_feats(6, 4);
        let candidates = vec![vec![0, 1], vec![1, 2, 3], vec![3, 4, 5], vec![0, 2, 4]];
        let scores = predictor.predict_batch(&feats, &candidates).expect("batch");
        assert_eq!(scores.len(), candidates.len());
        for s in &scores {
            assert!((0.0..=1.0).contains(s), "score {s} not in [0,1]");
        }
    }

    #[test]
    fn test_predictor_empty_candidate_error() {
        let predictor = HyperedgePredictor::new(4, HyperedgePredictorConfig::default());
        let feats = make_feats(5, 4);
        let result = predictor.score(&feats, &[]);
        assert!(result.is_err(), "empty candidate should return error");
    }

    #[test]
    fn test_predictor_wrong_feature_width_error() {
        let predictor = small_predictor();
        let feats = make_feats(5, 3);
        assert!(
            predictor.score(&feats, &[0, 1]).is_err(),
            "ncols != in_dim should return error"
        );
    }

    #[test]
    fn test_predictor_out_of_range_node_error() {
        let predictor = small_predictor();
        let feats = make_feats(5, 4);
        assert!(
            predictor.score(&feats, &[0, 5]).is_err(),
            "node index >= nrows should return error"
        );
        assert!(
            predictor
                .predict_batch(&feats, &[vec![0, 1], vec![9]])
                .is_err(),
            "batch should propagate the first error"
        );
    }

    #[test]
    fn test_generate_negatives_differ_from_positives() {
        use std::collections::HashSet;
        // The first positive is deliberately unsorted: matching must ignore node order.
        let positives = vec![vec![2, 0, 1], vec![3, 4, 5]];
        let negatives = generate_negatives(&positives, 8, 3);
        let pos_set: HashSet<Vec<usize>> = positives
            .iter()
            .map(|p| {
                let mut sorted = p.clone();
                sorted.sort_unstable();
                sorted
            })
            .collect();
        for neg in &negatives {
            let mut sorted = neg.clone();
            sorted.sort_unstable();
            assert!(
                !pos_set.contains(&sorted),
                "negative {:?} should not match a positive",
                neg
            );
        }
    }

    #[test]
    fn test_generate_negatives_count_and_shape() {
        let positives = vec![vec![0, 1, 2], vec![3, 4, 5]];
        let negatives = generate_negatives(&positives, 20, 5);
        // At most n_neg_per_pos per positive; with 20 nodes rejections are vanishingly rare.
        assert_eq!(negatives.len(), 10, "expected 5 negatives per positive");
        for neg in &negatives {
            assert_eq!(neg.len(), 3, "negative {neg:?} must match positive size");
            assert!(
                neg.windows(2).all(|w| w[0] < w[1]),
                "negative {neg:?} must be sorted with distinct nodes"
            );
            assert!(neg.iter().all(|&v| v < 20), "negative {neg:?} out of range");
        }
    }

    #[test]
    fn test_generate_negatives_deterministic() {
        let positives = vec![vec![0, 1, 2], vec![3, 4]];
        assert_eq!(
            generate_negatives(&positives, 10, 4),
            generate_negatives(&positives, 10, 4)
        );
    }

    #[test]
    fn test_generate_negatives_degenerate_inputs() {
        assert!(generate_negatives(&[], 10, 3).is_empty());
        assert!(generate_negatives(&[vec![0, 1]], 0, 3).is_empty());
        // A positive larger than the node set is skipped.
        assert!(generate_negatives(&[vec![0, 1, 2]], 2, 3).is_empty());
    }

    #[test]
    fn test_roc_auc_perfect() {
        let labels = vec![true, true, true, false, false, false];
        let scores = vec![0.9, 0.8, 0.7, 0.3, 0.2, 0.1];
        let auc = roc_auc(&labels, &scores);
        assert!(
            (auc - 1.0).abs() < 1e-10,
            "perfect AUC should be 1.0, got {auc}"
        );
    }

    #[test]
    fn test_roc_auc_worst() {
        let labels = vec![true, true, true, false, false, false];
        let scores = vec![0.1, 0.2, 0.3, 0.7, 0.8, 0.9];
        let auc = roc_auc(&labels, &scores);
        assert!(auc < 0.1, "worst AUC should be ~0.0, got {auc}");
    }

    #[test]
    fn test_roc_auc_random_approx_half() {
        let labels = vec![
            true, false, true, false, true, false, true, false, true, false,
        ];
        let scores = vec![0.5; 10];
        let auc = roc_auc(&labels, &scores);
        assert!(
            (0.0..=1.0).contains(&auc),
            "AUC must be in [0,1], got {auc}"
        );
    }

    #[test]
    fn test_roc_auc_degenerate_returns_half() {
        assert_eq!(roc_auc(&[], &[]), 0.5);
        assert_eq!(roc_auc(&[true, true], &[0.2, 0.8]), 0.5);
        assert_eq!(roc_auc(&[false, false], &[0.2, 0.8]), 0.5);
    }

    #[test]
    fn test_pooling_mean() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Mean.pool(&feats, &[0, 1, 2]);
        assert_eq!(pooled.len(), 3);
        for k in 0..3 {
            let expected = (feats[[0, k]] + feats[[1, k]] + feats[[2, k]]) / 3.0;
            assert!((pooled[k] - expected).abs() < 1e-12, "column {k}");
        }
    }

    #[test]
    fn test_pooling_sum() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Sum.pool(&feats, &[0, 1]);
        assert_eq!(pooled.len(), 3);
        for k in 0..3 {
            let expected = feats[[0, k]] + feats[[1, k]];
            assert!((pooled[k] - expected).abs() < 1e-12, "column {k}");
        }
    }

    #[test]
    fn test_pooling_max() {
        let feats = make_feats(4, 3);
        let pooled = PoolingType::Max.pool(&feats, &[0, 1, 2]);
        // Columns increase down the rows, so row 2 holds every column maximum.
        for k in 0..3 {
            assert!((pooled[k] - feats[[2, k]]).abs() < 1e-12, "column {k}");
        }
    }

    #[test]
    fn test_pooling_empty_nodes_returns_zeros() {
        let feats = make_feats(4, 3);
        for pooling in [PoolingType::Sum, PoolingType::Mean, PoolingType::Max] {
            assert_eq!(pooling.pool(&feats, &[]), vec![0.0; 3]);
        }
    }
}