Skip to main content

scirs2_graph/
link_prediction.rs

1//! Link prediction algorithms for graph analysis
2//!
3//! This module implements various similarity-based link prediction methods
4//! and evaluation metrics. Given a graph, link prediction estimates the
5//! likelihood of a link between two nodes that are not currently connected.
6//!
7//! # Algorithms
8//! - **Common Neighbors**: Count of shared neighbors
9//! - **Jaccard Coefficient**: Common neighbors normalized by union
10//! - **Adamic-Adar Index**: Weighted common neighbors (inverse log degree)
11//! - **Preferential Attachment**: Product of degrees
12//! - **Resource Allocation Index**: Similar to Adamic-Adar with 1/degree
13//! - **Katz Similarity**: Truncated path-based similarity
14//! - **SimRank**: Recursive structural similarity
15//! - **ROC/AUC Evaluation**: Quality assessment for link prediction
16
17use crate::base::{EdgeWeight, Graph, IndexType, Node};
18use crate::error::{GraphError, Result};
19use std::collections::{HashMap, HashSet};
20use std::hash::Hash;
21
22/// A scored node pair for link prediction
23#[derive(Debug, Clone)]
24pub struct LinkScore<N: Node> {
25    /// First node
26    pub node_a: N,
27    /// Second node
28    pub node_b: N,
29    /// Prediction score (higher = more likely to form link)
30    pub score: f64,
31}
32
/// Summary statistics produced by [`evaluate_link_prediction`].
#[derive(Clone, Debug)]
pub struct LinkPredictionEval {
    /// Area under the ROC curve (0.5 corresponds to a random ranking).
    pub auc: f64,
    /// Average precision over the ranked, labelled pairs.
    pub average_precision: f64,
    /// Labelled positive pairs counted during the ranking sweep.
    pub true_positives: usize,
    /// Labelled negative pairs counted during the ranking sweep.
    pub false_positives: usize,
    /// Number of positive examples taken into account.
    pub total_positives: usize,
    /// Number of negative examples taken into account.
    pub total_negatives: usize,
}
49
/// Options shared by the `*_all` link prediction functions.
#[derive(Clone, Debug)]
pub struct LinkPredictionConfig {
    /// Upper bound on the number of scored pairs returned.
    pub max_predictions: usize,
    /// Pairs scoring below this value are discarded.
    pub min_score: f64,
    /// Whether pairs of a node with itself are requested as candidates.
    pub include_self_loops: bool,
}

impl Default for LinkPredictionConfig {
    /// At most 100 predictions, a minimum score of 0.0, and no self-pairs.
    fn default() -> Self {
        let max_predictions = 100;
        let min_score = 0.0;
        let include_self_loops = false;
        LinkPredictionConfig {
            max_predictions,
            min_score,
            include_self_loops,
        }
    }
}
70
71// ============================================================================
72// Common Neighbors
73// ============================================================================
74
75/// Common neighbors score between two nodes.
76///
77/// The score is the number of shared neighbors between nodes u and v.
78/// Higher score suggests higher likelihood of a future link.
79///
80/// Score(u, v) = |N(u) intersection N(v)|
81pub fn common_neighbors_score<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
82where
83    N: Node + Clone + Hash + Eq + std::fmt::Debug,
84    E: EdgeWeight,
85    Ix: IndexType,
86{
87    validate_nodes(graph, u, v)?;
88
89    let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
90    let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
91
92    let common = neighbors_u.intersection(&neighbors_v).count();
93    Ok(common as f64)
94}
95
96/// Compute common neighbors scores for all non-connected node pairs.
97pub fn common_neighbors_all<N, E, Ix>(
98    graph: &Graph<N, E, Ix>,
99    config: &LinkPredictionConfig,
100) -> Vec<LinkScore<N>>
101where
102    N: Node + Clone + Hash + Eq + std::fmt::Debug,
103    E: EdgeWeight,
104    Ix: IndexType,
105{
106    compute_all_scores(graph, config, |g, u, v| {
107        common_neighbors_score(g, u, v).unwrap_or(0.0)
108    })
109}
110
111// ============================================================================
112// Jaccard Coefficient
113// ============================================================================
114
115/// Jaccard coefficient between two nodes.
116///
117/// Score(u, v) = |N(u) intersection N(v)| / |N(u) union N(v)|
118///
119/// Returns 0.0 if both nodes have no neighbors.
120pub fn jaccard_coefficient<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
121where
122    N: Node + Clone + Hash + Eq + std::fmt::Debug,
123    E: EdgeWeight,
124    Ix: IndexType,
125{
126    validate_nodes(graph, u, v)?;
127
128    let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
129    let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
130
131    let intersection = neighbors_u.intersection(&neighbors_v).count();
132    let union = neighbors_u.union(&neighbors_v).count();
133
134    if union == 0 {
135        Ok(0.0)
136    } else {
137        Ok(intersection as f64 / union as f64)
138    }
139}
140
141/// Compute Jaccard coefficient for all non-connected node pairs.
142pub fn jaccard_coefficient_all<N, E, Ix>(
143    graph: &Graph<N, E, Ix>,
144    config: &LinkPredictionConfig,
145) -> Vec<LinkScore<N>>
146where
147    N: Node + Clone + Hash + Eq + std::fmt::Debug,
148    E: EdgeWeight,
149    Ix: IndexType,
150{
151    compute_all_scores(graph, config, |g, u, v| {
152        jaccard_coefficient(g, u, v).unwrap_or(0.0)
153    })
154}
155
156// ============================================================================
157// Adamic-Adar Index
158// ============================================================================
159
160/// Adamic-Adar index between two nodes.
161///
162/// Sums 1/log(degree) for each common neighbor. Nodes with fewer
163/// connections contribute more to the score.
164///
165/// Score(u, v) = sum_{w in N(u) intersection N(v)} 1 / log(|N(w)|)
166pub fn adamic_adar_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
167where
168    N: Node + Clone + Hash + Eq + std::fmt::Debug,
169    E: EdgeWeight,
170    Ix: IndexType,
171{
172    validate_nodes(graph, u, v)?;
173
174    let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
175    let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
176
177    let mut score = 0.0;
178    for common in neighbors_u.intersection(&neighbors_v) {
179        let degree = graph.degree(common);
180        if degree > 1 {
181            score += 1.0 / (degree as f64).ln();
182        }
183    }
184
185    Ok(score)
186}
187
188/// Compute Adamic-Adar index for all non-connected node pairs.
189pub fn adamic_adar_all<N, E, Ix>(
190    graph: &Graph<N, E, Ix>,
191    config: &LinkPredictionConfig,
192) -> Vec<LinkScore<N>>
193where
194    N: Node + Clone + Hash + Eq + std::fmt::Debug,
195    E: EdgeWeight,
196    Ix: IndexType,
197{
198    compute_all_scores(graph, config, |g, u, v| {
199        adamic_adar_index(g, u, v).unwrap_or(0.0)
200    })
201}
202
203// ============================================================================
204// Preferential Attachment
205// ============================================================================
206
207/// Preferential attachment score between two nodes.
208///
209/// Based on the Barabasi-Albert model: nodes with more connections
210/// are more likely to form new connections.
211///
212/// Score(u, v) = |N(u)| * |N(v)|
213pub fn preferential_attachment<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
214where
215    N: Node + Clone + Hash + Eq + std::fmt::Debug,
216    E: EdgeWeight,
217    Ix: IndexType,
218{
219    validate_nodes(graph, u, v)?;
220
221    let deg_u = graph.degree(u);
222    let deg_v = graph.degree(v);
223
224    Ok((deg_u * deg_v) as f64)
225}
226
227/// Compute preferential attachment for all non-connected node pairs.
228pub fn preferential_attachment_all<N, E, Ix>(
229    graph: &Graph<N, E, Ix>,
230    config: &LinkPredictionConfig,
231) -> Vec<LinkScore<N>>
232where
233    N: Node + Clone + Hash + Eq + std::fmt::Debug,
234    E: EdgeWeight,
235    Ix: IndexType,
236{
237    compute_all_scores(graph, config, |g, u, v| {
238        preferential_attachment(g, u, v).unwrap_or(0.0)
239    })
240}
241
242// ============================================================================
243// Resource Allocation Index
244// ============================================================================
245
246/// Resource allocation index between two nodes.
247///
248/// Similar to Adamic-Adar but uses 1/degree instead of 1/log(degree).
249/// Proposed by Zhou, Lu, and Zhang (2009).
250///
251/// Score(u, v) = sum_{w in N(u) intersection N(v)} 1 / |N(w)|
252pub fn resource_allocation_index<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<f64>
253where
254    N: Node + Clone + Hash + Eq + std::fmt::Debug,
255    E: EdgeWeight,
256    Ix: IndexType,
257{
258    validate_nodes(graph, u, v)?;
259
260    let neighbors_u: HashSet<N> = graph.neighbors(u)?.into_iter().collect();
261    let neighbors_v: HashSet<N> = graph.neighbors(v)?.into_iter().collect();
262
263    let mut score = 0.0;
264    for common in neighbors_u.intersection(&neighbors_v) {
265        let degree = graph.degree(common);
266        if degree > 0 {
267            score += 1.0 / degree as f64;
268        }
269    }
270
271    Ok(score)
272}
273
274/// Compute resource allocation for all non-connected node pairs.
275pub fn resource_allocation_all<N, E, Ix>(
276    graph: &Graph<N, E, Ix>,
277    config: &LinkPredictionConfig,
278) -> Vec<LinkScore<N>>
279where
280    N: Node + Clone + Hash + Eq + std::fmt::Debug,
281    E: EdgeWeight,
282    Ix: IndexType,
283{
284    compute_all_scores(graph, config, |g, u, v| {
285        resource_allocation_index(g, u, v).unwrap_or(0.0)
286    })
287}
288
289// ============================================================================
290// Katz Similarity (Truncated)
291// ============================================================================
292
293/// Truncated Katz similarity between two nodes.
294///
295/// Considers paths of different lengths weighted by a damping factor beta.
296/// Truncated at `max_path_length` for efficiency.
297///
298/// Score(u, v) = sum_{l=1}^{L} beta^l * |paths_l(u, v)|
299///
300/// # Arguments
301/// * `graph` - The graph
302/// * `u`, `v` - The node pair
303/// * `beta` - Damping factor (typically 0.001 to 0.1)
304/// * `max_path_length` - Maximum path length to consider
305pub fn katz_similarity<N, E, Ix>(
306    graph: &Graph<N, E, Ix>,
307    u: &N,
308    v: &N,
309    beta: f64,
310    max_path_length: usize,
311) -> Result<f64>
312where
313    N: Node + Clone + Hash + Eq + std::fmt::Debug,
314    E: EdgeWeight,
315    Ix: IndexType,
316{
317    validate_nodes(graph, u, v)?;
318
319    if beta <= 0.0 || beta >= 1.0 {
320        return Err(GraphError::InvalidGraph(
321            "Beta must be in (0, 1)".to_string(),
322        ));
323    }
324
325    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
326    let n = nodes.len();
327    let node_to_idx: HashMap<N, usize> = nodes
328        .iter()
329        .enumerate()
330        .map(|(i, n)| (n.clone(), i))
331        .collect();
332
333    let u_idx = node_to_idx
334        .get(u)
335        .ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))?;
336    let v_idx = node_to_idx
337        .get(v)
338        .ok_or_else(|| GraphError::node_not_found(format!("{v:?}")))?;
339
340    // Build adjacency matrix as sparse representation
341    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
342    for (i, node) in nodes.iter().enumerate() {
343        if let Ok(neighbors) = graph.neighbors(node) {
344            for neighbor in &neighbors {
345                if let Some(&j) = node_to_idx.get(neighbor) {
346                    adj[i].push(j);
347                }
348            }
349        }
350    }
351
352    // Count paths of each length using matrix power approach
353    // paths_l[i] = number of paths of length l from u to node i
354    let mut score = 0.0;
355    let mut current = vec![0.0f64; n];
356    current[*u_idx] = 1.0;
357
358    for l in 1..=max_path_length {
359        let mut next = vec![0.0f64; n];
360        for (i, &count) in current.iter().enumerate() {
361            if count > 0.0 {
362                for &j in &adj[i] {
363                    next[j] += count;
364                }
365            }
366        }
367
368        let beta_l = beta.powi(l as i32);
369        score += beta_l * next[*v_idx];
370        current = next;
371    }
372
373    Ok(score)
374}
375
376/// Compute Katz similarity for all non-connected node pairs.
377pub fn katz_similarity_all<N, E, Ix>(
378    graph: &Graph<N, E, Ix>,
379    beta: f64,
380    max_path_length: usize,
381    config: &LinkPredictionConfig,
382) -> Vec<LinkScore<N>>
383where
384    N: Node + Clone + Hash + Eq + std::fmt::Debug,
385    E: EdgeWeight,
386    Ix: IndexType,
387{
388    compute_all_scores(graph, config, |g, u, v| {
389        katz_similarity(g, u, v, beta, max_path_length).unwrap_or(0.0)
390    })
391}
392
393// ============================================================================
394// SimRank
395// ============================================================================
396
397/// SimRank: structural similarity based on the idea that two nodes are similar
398/// if they are referenced by similar nodes.
399///
400/// Uses iterative computation with damping factor C.
401/// SimRank(u, u) = 1, SimRank(u, v) = C / (|N(u)| * |N(v)|) * sum SimRank(a, b)
402///
403/// # Arguments
404/// * `graph` - The graph
405/// * `decay` - Decay/damping factor (typically 0.8)
406/// * `max_iterations` - Maximum iterations for convergence
407/// * `tolerance` - Convergence tolerance
408pub fn simrank<N, E, Ix>(
409    graph: &Graph<N, E, Ix>,
410    decay: f64,
411    max_iterations: usize,
412    tolerance: f64,
413) -> Result<HashMap<(N, N), f64>>
414where
415    N: Node + Clone + Hash + Eq + std::fmt::Debug,
416    E: EdgeWeight,
417    Ix: IndexType,
418{
419    if decay <= 0.0 || decay > 1.0 {
420        return Err(GraphError::InvalidGraph(
421            "Decay must be in (0, 1]".to_string(),
422        ));
423    }
424
425    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
426    let n = nodes.len();
427    let node_to_idx: HashMap<N, usize> = nodes
428        .iter()
429        .enumerate()
430        .map(|(i, n)| (n.clone(), i))
431        .collect();
432
433    // Build adjacency
434    let mut adj: Vec<Vec<usize>> = vec![vec![]; n];
435    for (i, node) in nodes.iter().enumerate() {
436        if let Ok(neighbors) = graph.neighbors(node) {
437            for neighbor in &neighbors {
438                if let Some(&j) = node_to_idx.get(neighbor) {
439                    adj[i].push(j);
440                }
441            }
442        }
443    }
444
445    // Initialize SimRank: S(a, a) = 1, S(a, b) = 0 for a != b
446    let mut sim = vec![vec![0.0f64; n]; n];
447    for i in 0..n {
448        sim[i][i] = 1.0;
449    }
450
451    // Iterate
452    for _ in 0..max_iterations {
453        let mut new_sim = vec![vec![0.0f64; n]; n];
454        let mut max_diff = 0.0f64;
455
456        for i in 0..n {
457            new_sim[i][i] = 1.0;
458            for j in (i + 1)..n {
459                let deg_i = adj[i].len();
460                let deg_j = adj[j].len();
461
462                if deg_i == 0 || deg_j == 0 {
463                    new_sim[i][j] = 0.0;
464                    new_sim[j][i] = 0.0;
465                    continue;
466                }
467
468                let mut sum = 0.0;
469                for &ni in &adj[i] {
470                    for &nj in &adj[j] {
471                        sum += sim[ni][nj];
472                    }
473                }
474
475                let new_val = decay * sum / (deg_i * deg_j) as f64;
476                new_sim[i][j] = new_val;
477                new_sim[j][i] = new_val;
478
479                let diff = (new_val - sim[i][j]).abs();
480                if diff > max_diff {
481                    max_diff = diff;
482                }
483            }
484        }
485
486        sim = new_sim;
487
488        if max_diff < tolerance {
489            break;
490        }
491    }
492
493    // Convert to HashMap
494    let mut result = HashMap::new();
495    for i in 0..n {
496        for j in i..n {
497            result.insert((nodes[i].clone(), nodes[j].clone()), sim[i][j]);
498            if i != j {
499                result.insert((nodes[j].clone(), nodes[i].clone()), sim[i][j]);
500            }
501        }
502    }
503
504    Ok(result)
505}
506
507/// SimRank score between a specific pair of nodes.
508pub fn simrank_score<N, E, Ix>(
509    graph: &Graph<N, E, Ix>,
510    u: &N,
511    v: &N,
512    decay: f64,
513    max_iterations: usize,
514) -> Result<f64>
515where
516    N: Node + Clone + Hash + Eq + std::fmt::Debug,
517    E: EdgeWeight,
518    Ix: IndexType,
519{
520    let all_scores = simrank(graph, decay, max_iterations, 1e-6)?;
521    all_scores
522        .get(&(u.clone(), v.clone()))
523        .copied()
524        .ok_or_else(|| GraphError::node_not_found(format!("{u:?}")))
525}
526
527// ============================================================================
528// ROC/AUC Evaluation
529// ============================================================================
530
531/// Evaluate link prediction quality using ROC AUC.
532///
533/// Compares predicted scores against known positive (existing) and
534/// negative (non-existing) edges.
535///
536/// # Arguments
537/// * `scores` - Predicted link scores for node pairs
538/// * `positive_edges` - Known positive edges (existing links)
539/// * `negative_edges` - Known negative edges (non-existing links)
540pub fn evaluate_link_prediction<N>(
541    scores: &[LinkScore<N>],
542    positive_edges: &HashSet<(N, N)>,
543    negative_edges: &HashSet<(N, N)>,
544) -> LinkPredictionEval
545where
546    N: Node + Clone + Hash + Eq + std::fmt::Debug,
547{
548    if positive_edges.is_empty() || negative_edges.is_empty() {
549        return LinkPredictionEval {
550            auc: 0.5,
551            average_precision: 0.0,
552            true_positives: 0,
553            false_positives: 0,
554            total_positives: positive_edges.len(),
555            total_negatives: negative_edges.len(),
556        };
557    }
558
559    // Build scored list with labels
560    let mut scored_labels: Vec<(f64, bool)> = Vec::new();
561
562    for score in scores {
563        let pair = (score.node_a.clone(), score.node_b.clone());
564        let reverse_pair = (score.node_b.clone(), score.node_a.clone());
565
566        let is_positive = positive_edges.contains(&pair) || positive_edges.contains(&reverse_pair);
567        let is_negative = negative_edges.contains(&pair) || negative_edges.contains(&reverse_pair);
568
569        if is_positive || is_negative {
570            scored_labels.push((score.score, is_positive));
571        }
572    }
573
574    // Sort by score descending
575    scored_labels.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
576
577    // Compute AUC using the trapezoidal rule
578    let total_positives = scored_labels.iter().filter(|(_, label)| *label).count();
579    let total_negatives = scored_labels.iter().filter(|(_, label)| !*label).count();
580
581    if total_positives == 0 || total_negatives == 0 {
582        return LinkPredictionEval {
583            auc: 0.5,
584            average_precision: 0.0,
585            true_positives: 0,
586            false_positives: 0,
587            total_positives,
588            total_negatives,
589        };
590    }
591
592    let mut auc = 0.0;
593    let mut tp = 0usize;
594    let mut fp = 0usize;
595    let mut prev_fpr = 0.0;
596    let mut prev_tpr = 0.0;
597
598    // Compute average precision
599    let mut ap = 0.0;
600    let mut running_tp = 0;
601
602    for (i, &(_, is_positive)) in scored_labels.iter().enumerate() {
603        if is_positive {
604            tp += 1;
605            running_tp += 1;
606            ap += running_tp as f64 / (i + 1) as f64;
607        } else {
608            fp += 1;
609        }
610
611        let tpr = tp as f64 / total_positives as f64;
612        let fpr = fp as f64 / total_negatives as f64;
613
614        // Trapezoidal rule
615        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
616        prev_fpr = fpr;
617        prev_tpr = tpr;
618    }
619
620    // Complete the curve to (1, 1)
621    auc += (1.0 - prev_fpr) * (1.0 + prev_tpr) / 2.0;
622
623    let average_precision = if total_positives > 0 {
624        ap / total_positives as f64
625    } else {
626        0.0
627    };
628
629    LinkPredictionEval {
630        auc,
631        average_precision,
632        true_positives: tp,
633        false_positives: fp,
634        total_positives,
635        total_negatives,
636    }
637}
638
639/// Simplified AUC computation: randomly sample positive and negative pairs,
640/// score them, and estimate AUC.
641pub fn compute_auc<N, E, Ix, F>(
642    graph: &Graph<N, E, Ix>,
643    test_edges: &[(N, N)],
644    non_edges: &[(N, N)],
645    score_fn: F,
646) -> f64
647where
648    N: Node + Clone + Hash + Eq + std::fmt::Debug,
649    E: EdgeWeight,
650    Ix: IndexType,
651    F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
652{
653    if test_edges.is_empty() || non_edges.is_empty() {
654        return 0.5;
655    }
656
657    let mut n_correct = 0usize;
658    let mut n_tie = 0usize;
659    let mut n_total = 0usize;
660
661    for (pu, pv) in test_edges {
662        let pos_score = score_fn(graph, pu, pv);
663        for (nu, nv) in non_edges {
664            let neg_score = score_fn(graph, nu, nv);
665            n_total += 1;
666            if pos_score > neg_score + 1e-12 {
667                n_correct += 1;
668            } else if (pos_score - neg_score).abs() <= 1e-12 {
669                n_tie += 1;
670            }
671        }
672    }
673
674    if n_total == 0 {
675        return 0.5;
676    }
677
678    (n_correct as f64 + 0.5 * n_tie as f64) / n_total as f64
679}
680
681// ============================================================================
682// Internal helpers
683// ============================================================================
684
685fn validate_nodes<N, E, Ix>(graph: &Graph<N, E, Ix>, u: &N, v: &N) -> Result<()>
686where
687    N: Node + std::fmt::Debug,
688    E: EdgeWeight,
689    Ix: IndexType,
690{
691    if !graph.has_node(u) {
692        return Err(GraphError::node_not_found(format!("{u:?}")));
693    }
694    if !graph.has_node(v) {
695        return Err(GraphError::node_not_found(format!("{v:?}")));
696    }
697    Ok(())
698}
699
700fn compute_all_scores<N, E, Ix, F>(
701    graph: &Graph<N, E, Ix>,
702    config: &LinkPredictionConfig,
703    score_fn: F,
704) -> Vec<LinkScore<N>>
705where
706    N: Node + Clone + Hash + Eq + std::fmt::Debug,
707    E: EdgeWeight,
708    Ix: IndexType,
709    F: Fn(&Graph<N, E, Ix>, &N, &N) -> f64,
710{
711    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
712    let mut scores = Vec::new();
713
714    for (i, u) in nodes.iter().enumerate() {
715        for v in nodes.iter().skip(i + 1) {
716            if !config.include_self_loops && u == v {
717                continue;
718            }
719            // Only predict for non-connected pairs
720            if graph.has_edge(u, v) {
721                continue;
722            }
723
724            let score = score_fn(graph, u, v);
725            if score >= config.min_score {
726                scores.push(LinkScore {
727                    node_a: u.clone(),
728                    node_b: v.clone(),
729                    score,
730                });
731            }
732        }
733    }
734
735    // Sort by score descending
736    scores.sort_by(|a, b| {
737        b.score
738            .partial_cmp(&a.score)
739            .unwrap_or(std::cmp::Ordering::Equal)
740    });
741
742    // Truncate to max
743    scores.truncate(config.max_predictions);
744    scores
745}
746
747// ============================================================================
748// Tests
749// ============================================================================
750
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    // Small social network shared by most tests (a 2x3 grid):
    //
    //   0 -- 1 -- 2
    //   |    |    |
    //   3 -- 4 -- 5
    //
    // Neighborhoods: N(0)={1,3}, N(1)={0,2,4}, N(2)={1,5},
    //                N(3)={0,4}, N(4)={1,3,5}, N(5)={2,4}.
    fn build_test_graph() -> Graph<i32, ()> {
        let mut g = create_graph::<i32, ()>();
        let _ = g.add_edge(0, 1, ());
        let _ = g.add_edge(1, 2, ());
        let _ = g.add_edge(0, 3, ());
        let _ = g.add_edge(1, 4, ());
        let _ = g.add_edge(2, 5, ());
        let _ = g.add_edge(3, 4, ());
        let _ = g.add_edge(4, 5, ());
        g
    }

    #[test]
    fn test_common_neighbors() -> GraphResult<()> {
        let g = build_test_graph();

        // Nodes 0 and 2 share neighbor 1
        let score = common_neighbors_score(&g, &0, &2)?;
        assert!((score - 1.0).abs() < 1e-6);

        // Nodes 0 and 4 share neighbors 1 and 3
        let score = common_neighbors_score(&g, &0, &4)?;
        assert!((score - 2.0).abs() < 1e-6);

        // Nodes 0 and 5 share no neighbors
        let score = common_neighbors_score(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_jaccard_coefficient() -> GraphResult<()> {
        let g = build_test_graph();

        // Nodes 0 and 4: intersection = {1, 3}, union = {1, 3, 5}.
        let score = jaccard_coefficient(&g, &0, &4)?;
        assert!((score - 2.0 / 3.0).abs() < 1e-12);

        // A node compared with itself has identical neighborhoods.
        let score = jaccard_coefficient(&g, &0, &0)?;
        assert!((score - 1.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_adamic_adar() -> GraphResult<()> {
        let g = build_test_graph();

        // Common neighbors of 0 and 4 are 1 (degree 3) and 3 (degree 2).
        let expected = 1.0 / 3.0f64.ln() + 1.0 / 2.0f64.ln();
        let score = adamic_adar_index(&g, &0, &4)?;
        assert!((score - expected).abs() < 1e-12);

        // Nodes with no common neighbors should have 0 score
        let score = adamic_adar_index(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_preferential_attachment() -> GraphResult<()> {
        let g = build_test_graph();

        // Node 0 has degree 2, node 4 has degree 3
        let score = preferential_attachment(&g, &0, &4)?;
        assert!((score - 6.0).abs() < 1e-6);

        // Node 1 has degree 3, node 4 has degree 3
        let score = preferential_attachment(&g, &1, &4)?;
        assert!((score - 9.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_resource_allocation() -> GraphResult<()> {
        let g = build_test_graph();

        // Common neighbors of 0 and 4 are 1 (degree 3) and 3 (degree 2).
        let score = resource_allocation_index(&g, &0, &4)?;
        assert!((score - (1.0 / 3.0 + 1.0 / 2.0)).abs() < 1e-12);

        let score = resource_allocation_index(&g, &0, &5)?;
        assert!((score - 0.0).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_katz_similarity() -> GraphResult<()> {
        let g = build_test_graph();

        // The grid is bipartite and 0, 2 lie on the same side, so the only
        // walk of length <= 3 between them is 0-1-2 (length 2).
        let score = katz_similarity(&g, &0, &2, 0.05, 3)?;
        assert!((score - 0.05f64.powi(2)).abs() < 1e-12);

        // Closer nodes should have higher Katz similarity
        let score_near = katz_similarity(&g, &0, &1, 0.05, 3)?;
        let score_far = katz_similarity(&g, &0, &5, 0.05, 3)?;
        assert!(score_near > score_far);
        Ok(())
    }

    #[test]
    fn test_katz_invalid_beta() {
        let g = build_test_graph();
        assert!(katz_similarity(&g, &0, &1, 0.0, 3).is_err());
        assert!(katz_similarity(&g, &0, &1, 1.0, 3).is_err());
    }

    #[test]
    fn test_simrank() -> GraphResult<()> {
        let g = build_test_graph();

        let scores = simrank(&g, 0.8, 10, 1e-4)?;

        // Self-similarity should be 1.0
        let self_score = scores.get(&(0, 0)).copied().unwrap_or(0.0);
        assert!((self_score - 1.0).abs() < 1e-6);

        // Structural similarity should be non-negative
        for &score in scores.values() {
            assert!(score >= -1e-6);
        }
        Ok(())
    }

    #[test]
    fn test_simrank_score() -> GraphResult<()> {
        let g = build_test_graph();
        let score = simrank_score(&g, &0, &2, 0.8, 10)?;
        assert!(score >= 0.0);
        assert!(score <= 1.0);
        Ok(())
    }

    #[test]
    fn test_evaluate_link_prediction() {
        let scores = vec![
            LinkScore {
                node_a: 0,
                node_b: 1,
                score: 0.9,
            },
            LinkScore {
                node_a: 0,
                node_b: 2,
                score: 0.8,
            },
            LinkScore {
                node_a: 0,
                node_b: 3,
                score: 0.3,
            },
            LinkScore {
                node_a: 1,
                node_b: 3,
                score: 0.2,
            },
        ];

        let mut positives = HashSet::new();
        positives.insert((0, 1));
        positives.insert((0, 2));

        let mut negatives = HashSet::new();
        negatives.insert((0, 3));
        negatives.insert((1, 3));

        // Both positives outrank both negatives: a perfect ranking.
        let eval = evaluate_link_prediction(&scores, &positives, &negatives);
        assert!((eval.auc - 1.0).abs() < 1e-12);
        assert!((eval.average_precision - 1.0).abs() < 1e-12);
        assert!(eval.true_positives > 0);
    }

    #[test]
    fn test_compute_auc() -> GraphResult<()> {
        let g = build_test_graph();

        // Treat the non-edge (0, 4), whose endpoints share neighbors 1 and 3,
        // as a held-out positive, and (0, 5), which shares none, as a negative.
        let test_edges = vec![(0, 4)];
        let non_edges = vec![(0, 5)];

        let auc = compute_auc(&g, &test_edges, &non_edges, |g, u, v| {
            common_neighbors_score(g, u, v).unwrap_or(0.0)
        });

        // 2 common neighbors vs 0: the single comparison is ranked correctly.
        assert!((auc - 1.0).abs() < 1e-12);
        Ok(())
    }

    #[test]
    fn test_common_neighbors_all() {
        let g = build_test_graph();
        let config = LinkPredictionConfig {
            max_predictions: 10,
            min_score: 0.0,
            include_self_loops: false,
        };

        let scores = common_neighbors_all(&g, &config);
        // 15 unordered pairs minus 7 edges leaves 8 candidates.
        assert_eq!(scores.len(), 8);
        // Should only include non-connected pairs
        for score in &scores {
            assert!(!g.has_edge(&score.node_a, &score.node_b));
        }
        // Should be sorted by score descending
        for window in scores.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_invalid_nodes() {
        let g = build_test_graph();
        assert!(common_neighbors_score(&g, &0, &99).is_err());
        assert!(jaccard_coefficient(&g, &99, &0).is_err());
        assert!(adamic_adar_index(&g, &0, &99).is_err());
    }

    #[test]
    fn test_empty_graph_link_prediction() -> GraphResult<()> {
        // A single isolated node leaves no pair to score.
        let mut g = create_graph::<i32, ()>();
        let _ = g.add_node(0);

        let config = LinkPredictionConfig::default();
        let scores = common_neighbors_all(&g, &config);
        assert!(scores.is_empty());
        Ok(())
    }

    #[test]
    fn test_all_methods_consistency() -> GraphResult<()> {
        let g = build_test_graph();

        // All methods should return non-negative scores for same pair
        let cn = common_neighbors_score(&g, &0, &4)?;
        let jc = jaccard_coefficient(&g, &0, &4)?;
        let aa = adamic_adar_index(&g, &0, &4)?;
        let pa = preferential_attachment(&g, &0, &4)?;
        let ra = resource_allocation_index(&g, &0, &4)?;
        let kz = katz_similarity(&g, &0, &4, 0.05, 3)?;

        assert!(cn >= 0.0);
        assert!(jc >= 0.0);
        assert!(aa >= 0.0);
        assert!(pa >= 0.0);
        assert!(ra >= 0.0);
        assert!(kz >= 0.0);

        // Pairs with common neighbors should score > 0 for relevant methods
        assert!(cn > 0.0);
        assert!(jc > 0.0);
        assert!(aa > 0.0);
        assert!(ra > 0.0);
        Ok(())
    }
}