//! Node embeddings and skip-gram training with negative sampling.

use super::negative_sampling::NegativeSampler;
use super::types::ContextPair;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
use scirs2_core::random::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// A dense vector embedding for a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The embedding components.
    pub vector: Vec<f64>,
}

impl Embedding {
    /// Creates a zero-initialized embedding with the given dimensionality.
    pub fn new(dimensions: usize) -> Self {
        Embedding {
            vector: vec![0.0; dimensions],
        }
    }

    /// Creates an embedding with components drawn uniformly from [-0.5, 0.5).
    pub fn random(dimensions: usize, rng: &mut impl Rng) -> Self {
        let vector: Vec<f64> = (0..dimensions).map(|_| rng.gen_range(-0.5..0.5)).collect();
        Embedding { vector }
    }

    /// Returns the number of dimensions.
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }

    /// Computes the cosine similarity with another embedding.
    ///
    /// Returns 0.0 when either vector has zero norm, and an error when the
    /// dimensions differ.
    pub fn cosine_similarity(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot_product: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();

        let norm_a = self.norm();
        let norm_b = other.norm();

        if norm_a == 0.0 || norm_b == 0.0 {
            Ok(0.0)
        } else {
            Ok(dot_product / (norm_a * norm_b))
        }
    }

    /// Returns the Euclidean (L2) norm of the vector.
    pub fn norm(&self) -> f64 {
        self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Scales the vector to unit length; zero vectors are left unchanged.
    pub fn normalize(&mut self) {
        let norm = self.norm();
        if norm > 0.0 {
            for x in &mut self.vector {
                *x /= norm;
            }
        }
    }

    /// Adds another embedding element-wise; errors when the dimensions differ.
    pub fn add(&mut self, other: &Embedding) -> Result<()> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        for (a, b) in self.vector.iter_mut().zip(other.vector.iter()) {
            *a += b;
        }
        Ok(())
    }

    /// Multiplies every component by `factor`.
    pub fn scale(&mut self, factor: f64) {
        for x in &mut self.vector {
            *x *= factor;
        }
    }

    /// Computes the dot product with another embedding; errors when the
    /// dimensions differ.
    pub fn dot_product(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();
        Ok(dot)
    }

    /// The logistic sigmoid, `1 / (1 + e^(-x))`.
    pub fn sigmoid(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Applies one gradient-descent step: `v <- v - learning_rate * gradient`.
    pub fn update_gradient(&mut self, gradient: &[f64], learning_rate: f64) {
        for (emb, &grad) in self.vector.iter_mut().zip(gradient.iter()) {
            *emb -= learning_rate * grad;
        }
    }
}

/// Skip-gram embedding model with separate target and context vectors, as in
/// word2vec's two-matrix formulation.
#[derive(Debug)]
pub struct EmbeddingModel<N: Node> {
    /// Target (input) embedding for each node; these are the vectors
    /// typically used downstream.
    pub embeddings: HashMap<N, Embedding>,
    /// Context (output) embeddings, used only during training.
    pub context_embeddings: HashMap<N, Embedding>,
    /// Dimensionality shared by all embeddings in the model.
    pub dimensions: usize,
}

impl<N: Node> EmbeddingModel<N> {
    /// Creates an empty model whose embeddings all have `dimensions` components.
    pub fn new(dimensions: usize) -> Self {
        EmbeddingModel {
            embeddings: HashMap::new(),
            context_embeddings: HashMap::new(),
            dimensions,
        }
    }

    /// Returns the target embedding for `node`, if one exists.
    pub fn get_embedding(&self, node: &N) -> Option<&Embedding> {
        self.embeddings.get(node)
    }

    /// Inserts a target embedding for `node`; errors when its dimensionality
    /// does not match the model's.
    pub fn set_embedding(&mut self, node: N, embedding: Embedding) -> Result<()> {
        if embedding.dimensions() != self.dimensions {
            return Err(GraphError::InvalidGraph(
                "Embedding dimensions don't match model".to_string(),
            ));
        }
        self.embeddings.insert(node, embedding);
        Ok(())
    }

    /// Randomly initializes target and context embeddings for every node of
    /// an undirected graph.
    pub fn initialize_random<E, Ix>(&mut self, graph: &Graph<N, E, Ix>, rng: &mut impl Rng)
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Randomly initializes target and context embeddings for every node of
    /// a directed graph.
    pub fn initialize_random_digraph<E, Ix>(
        &mut self,
        graph: &DiGraph<N, E, Ix>,
        rng: &mut impl Rng,
    ) where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Returns the `k` nodes most similar to `node` by cosine similarity of
    /// their target embeddings, in descending order of similarity.
    pub fn most_similar(&self, node: &N, k: usize) -> Result<Vec<(N, f64)>>
    where
        N: Clone,
    {
        let target_embedding = self
            .embeddings
            .get(node)
            .ok_or_else(|| GraphError::node_not_found("node"))?;

        let mut similarities = Vec::new();

        for (other_node, other_embedding) in &self.embeddings {
            if other_node != node {
                let similarity = target_embedding.cosine_similarity(other_embedding)?;
                similarities.push((other_node.clone(), similarity));
            }
        }

        // `total_cmp` gives a total order over f64, so a NaN similarity
        // cannot panic the sort (as `partial_cmp(..).unwrap()` could).
        similarities.sort_by(|a, b| b.1.total_cmp(&a.1));
        similarities.truncate(k);

        Ok(similarities)
    }

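    /// Builds skip-gram `(target, context)` training pairs from random walks
    /// using a symmetric window.
    ///
    /// Worked example (derived directly from the loop below): for a walk
    /// `[a, b, c]` and `window_size = 1`, the pairs are `(a, b)`, `(b, a)`,
    /// `(b, c)`, and `(c, b)`; every node is paired with each neighbor at
    /// distance <= `window_size` inside the walk, in both directions.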
    pub fn generate_context_pairs(
        walks: &[super::types::RandomWalk<N>],
        window_size: usize,
    ) -> Vec<ContextPair<N>>
    where
        N: Clone,
    {
        let mut pairs = Vec::new();

        for walk in walks {
            for (i, target) in walk.nodes.iter().enumerate() {
                let start = i.saturating_sub(window_size);
                let end = (i + window_size + 1).min(walk.nodes.len());

                for j in start..end {
                    if i != j {
                        pairs.push(ContextPair {
                            target: target.clone(),
                            context: walk.nodes[j].clone(),
                        });
                    }
                }
            }
        }

        pairs
    }

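    /// Trains the model with skip-gram and negative sampling (SGNS).
    ///
    /// For each positive pair `(t, c)` with `negative_samples` sampled
    /// negatives `n`, one stochastic gradient-descent step is taken on the
    /// per-pair loss
    ///
    /// `L = -ln(sigmoid(t . c)) - sum over n of ln(sigmoid(-(t . n)))`
    ///
    /// whose score derivatives are `dL/d(t.c) = sigmoid(t.c) - 1` and
    /// `dL/d(t.n) = sigmoid(t.n)`; those expressions are exactly what the
    /// body accumulates before calling `update_gradient`.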
    pub fn train_skip_gram(
        &mut self,
        pairs: &[ContextPair<N>],
        negative_sampler: &NegativeSampler<N>,
        learning_rate: f64,
        negative_samples: usize,
        rng: &mut impl Rng,
    ) -> Result<()> {
        for pair in pairs {
            let target_emb = self
                .embeddings
                .get(&pair.target)
                .ok_or_else(|| GraphError::node_not_found("node"))?
                .clone();
            let context_emb = self
                .context_embeddings
                .get(&pair.context)
                .ok_or_else(|| GraphError::node_not_found("node"))?
                .clone();

            let positive_score = target_emb.dot_product(&context_emb)?;
            let positive_prob = Embedding::sigmoid(positive_score);

            // dL/d(score) for the positive pair is sigmoid(score) - 1, which
            // is negative, so the descent step in `update_gradient` pulls
            // target and context together. (The previous sign, 1 - sigmoid,
            // combined with a subtractive update pushed them apart.)
            let positive_error = positive_prob - 1.0;
            let mut target_gradient = vec![0.0; self.dimensions];
            let mut context_gradient = vec![0.0; self.dimensions];

            #[allow(clippy::needless_range_loop)]
            for i in 0..self.dimensions {
                target_gradient[i] += positive_error * context_emb.vector[i];
                context_gradient[i] += positive_error * target_emb.vector[i];
            }

            let exclude_set: HashSet<&N> = [&pair.target, &pair.context].iter().cloned().collect();
            let negatives = negative_sampler.sample_negatives(negative_samples, &exclude_set, rng);

            // First pass: accumulate the target gradient from the negatives
            // while the context embeddings are only borrowed immutably.
            for negative in &negatives {
                if let Some(neg_context_emb) = self.context_embeddings.get(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb)?;
                    let negative_prob = Embedding::sigmoid(negative_score);

                    // dL/d(score) for a negative sample is sigmoid(score),
                    // which is positive, so descent pushes them apart.
                    let negative_error = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        target_gradient[i] += negative_error * neg_context_emb.vector[i];
                    }
                }
            }

            // Second pass: update the negatives' context embeddings in place.
            // The scores are recomputed only because this pass needs the
            // mutable borrow the first pass could not hold; nothing has
            // changed in between, so the values match.
            for negative in &negatives {
                if let Some(neg_context_emb_mut) = self.context_embeddings.get_mut(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb_mut)?;
                    let negative_prob = Embedding::sigmoid(negative_score);
                    let negative_error = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        let neg_context_grad = negative_error * target_emb.vector[i];
                        neg_context_emb_mut.vector[i] -= learning_rate * neg_context_grad;
                    }
                }
            }

            if let Some(target_emb_mut) = self.embeddings.get_mut(&pair.target) {
                target_emb_mut.update_gradient(&target_gradient, learning_rate);
            }
            if let Some(context_emb_mut) = self.context_embeddings.get_mut(&pair.context) {
                context_emb_mut.update_gradient(&context_gradient, learning_rate);
            }
        }

        Ok(())
    }
}
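
// A minimal test sketch for the embedding math above. It deliberately avoids
// the graph, walk, and sampler types, so it assumes nothing beyond what this
// file defines; the tolerance constant is an arbitrary choice.
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-9;

    #[test]
    fn cosine_similarity_handles_parallel_and_orthogonal_vectors() {
        let a = Embedding { vector: vec![1.0, 0.0] };
        let b = Embedding { vector: vec![2.0, 0.0] };
        let c = Embedding { vector: vec![0.0, 3.0] };
        assert!((a.cosine_similarity(&b).unwrap() - 1.0).abs() < EPS);
        assert!(a.cosine_similarity(&c).unwrap().abs() < EPS);
    }

    #[test]
    fn normalize_yields_unit_norm_and_leaves_zero_vectors_alone() {
        let mut e = Embedding { vector: vec![3.0, 4.0] };
        e.normalize();
        assert!((e.norm() - 1.0).abs() < EPS);

        let mut z = Embedding::new(2);
        z.normalize();
        assert!(z.norm() == 0.0);
    }

    #[test]
    fn sigmoid_is_half_at_zero() {
        assert!((Embedding::sigmoid(0.0) - 0.5).abs() < EPS);
    }

    #[test]
    fn update_gradient_moves_against_the_gradient() {
        let mut e = Embedding { vector: vec![1.0, 1.0] };
        e.update_gradient(&[0.5, -0.5], 0.1);
        assert!((e.vector[0] - 0.95).abs() < EPS);
        assert!((e.vector[1] - 1.05).abs() < EPS);
    }

    #[test]
    fn dimension_mismatch_is_rejected() {
        let a = Embedding::new(2);
        let b = Embedding::new(3);
        assert!(a.cosine_similarity(&b).is_err());
        assert!(a.dot_product(&b).is_err());
    }
}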