// scirs2_graph/embeddings/negative_sampling.rs

//! Negative sampling strategies for graph embeddings
2
3use super::types::NegativeSamplingStrategy;
4use crate::base::{EdgeWeight, Graph, Node};
5use scirs2_core::random::Rng;
6use std::collections::HashSet;
7
8/// Negative sampling configuration
9#[derive(Debug, Clone)]
10pub struct NegativeSampler<N: Node> {
11    /// Vocabulary (all nodes)
12    vocabulary: Vec<N>,
13    /// Frequency distribution for sampling
14    #[allow(dead_code)]
15    frequencies: Vec<f64>,
16    /// Cumulative distribution for fast sampling
17    cumulative: Vec<f64>,
18}
19
20impl<N: Node> NegativeSampler<N> {
21    /// Create a new negative sampler from graph
22    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Self
23    where
24        N: Clone + std::fmt::Debug,
25        E: EdgeWeight,
26        Ix: petgraph::graph::IndexType,
27    {
28        let vocabulary: Vec<N> = graph.nodes().into_iter().cloned().collect();
29        let node_degrees = vocabulary
30            .iter()
31            .map(|node| graph.degree(node) as f64)
32            .collect::<Vec<_>>();
33
34        // Use subsampling with power 0.75 as in Word2Vec
35        let total_degree: f64 = node_degrees.iter().sum();
36        let frequencies: Vec<f64> = node_degrees
37            .iter()
38            .map(|d| (d / total_degree).powf(0.75))
39            .collect();
40
41        let total_freq: f64 = frequencies.iter().sum();
42        let frequencies: Vec<f64> = frequencies.iter().map(|f| f / total_freq).collect();
43
44        // Build cumulative distribution
45        let mut cumulative = vec![0.0; frequencies.len()];
46        cumulative[0] = frequencies[0];
47        for i in 1..frequencies.len() {
48            cumulative[i] = cumulative[i - 1] + frequencies[i];
49        }
50
51        NegativeSampler {
52            vocabulary,
53            frequencies,
54            cumulative,
55        }
56    }
57
58    /// Sample a negative node
59    pub fn sample(&self, rng: &mut impl Rng) -> Option<&N> {
60        if self.vocabulary.is_empty() {
61            return None;
62        }
63
64        let r = rng.random::<f64>();
65        for (i, &cum_freq) in self.cumulative.iter().enumerate() {
66            if r <= cum_freq {
67                return Some(&self.vocabulary[i]);
68            }
69        }
70
71        self.vocabulary.last()
72    }
73
74    /// Sample multiple negative nodes excluding target and context
75    pub fn sample_negatives(
76        &self,
77        count: usize,
78        exclude: &HashSet<&N>,
79        rng: &mut impl Rng,
80    ) -> Vec<N> {
81        let mut negatives = Vec::new();
82        let mut attempts = 0;
83        let max_attempts = count * 10; // Prevent infinite loops
84
85        while negatives.len() < count && attempts < max_attempts {
86            if let Some(candidate) = self.sample(rng) {
87                if !exclude.contains(candidate) {
88                    negatives.push(candidate.clone());
89                }
90            }
91            attempts += 1;
92        }
93
94        negatives
95    }
96}