scirs2_graph/embeddings/
negative_sampling.rs1use super::types::NegativeSamplingStrategy;
4use crate::base::{EdgeWeight, Graph, Node};
5use scirs2_core::random::Rng;
6use std::collections::HashSet;
7
8#[derive(Debug, Clone)]
10pub struct NegativeSampler<N: Node> {
11 vocabulary: Vec<N>,
13 #[allow(dead_code)]
15 frequencies: Vec<f64>,
16 cumulative: Vec<f64>,
18}
19
20impl<N: Node> NegativeSampler<N> {
21 pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Self
23 where
24 N: Clone + std::fmt::Debug,
25 E: EdgeWeight,
26 Ix: petgraph::graph::IndexType,
27 {
28 let vocabulary: Vec<N> = graph.nodes().into_iter().cloned().collect();
29 let node_degrees = vocabulary
30 .iter()
31 .map(|node| graph.degree(node) as f64)
32 .collect::<Vec<_>>();
33
34 let total_degree: f64 = node_degrees.iter().sum();
36 let frequencies: Vec<f64> = node_degrees
37 .iter()
38 .map(|d| (d / total_degree).powf(0.75))
39 .collect();
40
41 let total_freq: f64 = frequencies.iter().sum();
42 let frequencies: Vec<f64> = frequencies.iter().map(|f| f / total_freq).collect();
43
44 let mut cumulative = vec![0.0; frequencies.len()];
46 cumulative[0] = frequencies[0];
47 for i in 1..frequencies.len() {
48 cumulative[i] = cumulative[i - 1] + frequencies[i];
49 }
50
51 NegativeSampler {
52 vocabulary,
53 frequencies,
54 cumulative,
55 }
56 }
57
58 pub fn sample(&self, rng: &mut impl Rng) -> Option<&N> {
60 if self.vocabulary.is_empty() {
61 return None;
62 }
63
64 let r = rng.random::<f64>();
65 for (i, &cum_freq) in self.cumulative.iter().enumerate() {
66 if r <= cum_freq {
67 return Some(&self.vocabulary[i]);
68 }
69 }
70
71 self.vocabulary.last()
72 }
73
74 pub fn sample_negatives(
76 &self,
77 count: usize,
78 exclude: &HashSet<&N>,
79 rng: &mut impl Rng,
80 ) -> Vec<N> {
81 let mut negatives = Vec::new();
82 let mut attempts = 0;
83 let max_attempts = count * 10; while negatives.len() < count && attempts < max_attempts {
86 if let Some(candidate) = self.sample(rng) {
87 if !exclude.contains(candidate) {
88 negatives.push(candidate.clone());
89 }
90 }
91 attempts += 1;
92 }
93
94 negatives
95 }
96}