1use super::negative_sampling::NegativeSampler;
4use super::types::ContextPair;
5use crate::base::{DiGraph, EdgeWeight, Graph, Node};
6use crate::error::{GraphError, Result};
7use scirs2_core::random::Rng;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
/// A dense, fixed-length vector representation of a single graph node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The raw coordinates; the length is the embedding dimensionality.
    pub vector: Vec<f64>,
}
17
18impl Embedding {
19 pub fn new(dimensions: usize) -> Self {
21 Embedding {
22 vector: vec![0.0; dimensions],
23 }
24 }
25
26 pub fn random(dimensions: usize, rng: &mut impl Rng) -> Self {
28 let vector: Vec<f64> = (0..dimensions)
29 .map(|_| rng.random_range(-0.5..0.5))
30 .collect();
31 Embedding { vector }
32 }
33
34 pub fn dimensions(&self) -> usize {
36 self.vector.len()
37 }
38
39 pub fn cosine_similarity(&self, other: &Embedding) -> Result<f64> {
41 if self.vector.len() != other.vector.len() {
42 return Err(GraphError::InvalidGraph(
43 "Embeddings must have same dimensions".to_string(),
44 ));
45 }
46
47 let dot_product: f64 = self
48 .vector
49 .iter()
50 .zip(other.vector.iter())
51 .map(|(a, b)| a * b)
52 .sum();
53
54 let norm_a = self.norm();
55 let norm_b = other.norm();
56
57 if norm_a == 0.0 || norm_b == 0.0 {
58 Ok(0.0)
59 } else {
60 Ok(dot_product / (norm_a * norm_b))
61 }
62 }
63
64 pub fn norm(&self) -> f64 {
66 self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
67 }
68
69 pub fn normalize(&mut self) {
71 let norm = self.norm();
72 if norm > 0.0 {
73 for x in &mut self.vector {
74 *x /= norm;
75 }
76 }
77 }
78
79 pub fn add(&mut self, other: &Embedding) -> Result<()> {
81 if self.vector.len() != other.vector.len() {
82 return Err(GraphError::InvalidGraph(
83 "Embeddings must have same dimensions".to_string(),
84 ));
85 }
86
87 for (a, b) in self.vector.iter_mut().zip(other.vector.iter()) {
88 *a += b;
89 }
90 Ok(())
91 }
92
93 pub fn scale(&mut self, factor: f64) {
95 for x in &mut self.vector {
96 *x *= factor;
97 }
98 }
99
100 pub fn dot_product(&self, other: &Embedding) -> Result<f64> {
102 if self.vector.len() != other.vector.len() {
103 return Err(GraphError::InvalidGraph(
104 "Embeddings must have same dimensions".to_string(),
105 ));
106 }
107
108 let dot: f64 = self
109 .vector
110 .iter()
111 .zip(other.vector.iter())
112 .map(|(a, b)| a * b)
113 .sum();
114 Ok(dot)
115 }
116
117 pub fn sigmoid(x: f64) -> f64 {
119 1.0 / (1.0 + (-x).exp())
120 }
121
122 pub fn update_gradient(&mut self, gradient: &[f64], learning_rate: f64) {
124 for (emb, &grad) in self.vector.iter_mut().zip(gradient.iter()) {
125 *emb -= learning_rate * grad;
126 }
127 }
128}
129
/// Node embeddings learned with a skip-gram / negative-sampling model.
#[derive(Debug)]
pub struct EmbeddingModel<N: Node> {
    /// "Input" embeddings: the representation reported for each node.
    pub embeddings: HashMap<N, Embedding>,
    /// "Output" embeddings used when a node appears as context during training.
    pub context_embeddings: HashMap<N, Embedding>,
    /// Dimensionality shared by every embedding in the model.
    pub dimensions: usize,
}
140
141impl<N: Node> EmbeddingModel<N> {
142 pub fn new(dimensions: usize) -> Self {
144 EmbeddingModel {
145 embeddings: HashMap::new(),
146 context_embeddings: HashMap::new(),
147 dimensions,
148 }
149 }
150
151 pub fn get_embedding(&self, node: &N) -> Option<&Embedding> {
153 self.embeddings.get(node)
154 }
155
156 pub fn set_embedding(&mut self, node: N, embedding: Embedding) -> Result<()> {
158 if embedding.dimensions() != self.dimensions {
159 return Err(GraphError::InvalidGraph(
160 "Embedding dimensions don't match model".to_string(),
161 ));
162 }
163 self.embeddings.insert(node, embedding);
164 Ok(())
165 }
166
167 pub fn initialize_random<E, Ix>(&mut self, graph: &Graph<N, E, Ix>, rng: &mut impl Rng)
169 where
170 N: Clone + std::fmt::Debug,
171 E: EdgeWeight,
172 Ix: petgraph::graph::IndexType,
173 {
174 for node in graph.nodes() {
175 let embedding = Embedding::random(self.dimensions, rng);
176 let context_embedding = Embedding::random(self.dimensions, rng);
177 self.embeddings.insert(node.clone(), embedding);
178 self.context_embeddings
179 .insert(node.clone(), context_embedding);
180 }
181 }
182
183 pub fn initialize_random_digraph<E, Ix>(
185 &mut self,
186 graph: &DiGraph<N, E, Ix>,
187 rng: &mut impl Rng,
188 ) where
189 N: Clone + std::fmt::Debug,
190 E: EdgeWeight,
191 Ix: petgraph::graph::IndexType,
192 {
193 for node in graph.nodes() {
194 let embedding = Embedding::random(self.dimensions, rng);
195 let context_embedding = Embedding::random(self.dimensions, rng);
196 self.embeddings.insert(node.clone(), embedding);
197 self.context_embeddings
198 .insert(node.clone(), context_embedding);
199 }
200 }
201
202 pub fn most_similar(&self, node: &N, k: usize) -> Result<Vec<(N, f64)>>
204 where
205 N: Clone,
206 {
207 let target_embedding = self
208 .embeddings
209 .get(node)
210 .ok_or(GraphError::node_not_found("node"))?;
211
212 let mut similarities = Vec::new();
213
214 for (other_node, other_embedding) in &self.embeddings {
215 if other_node != node {
216 let similarity = target_embedding.cosine_similarity(other_embedding)?;
217 similarities.push((other_node.clone(), similarity));
218 }
219 }
220
221 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));
222 similarities.truncate(k);
223
224 Ok(similarities)
225 }
226
227 pub fn generate_context_pairs(
229 walks: &[super::types::RandomWalk<N>],
230 window_size: usize,
231 ) -> Vec<ContextPair<N>>
232 where
233 N: Clone,
234 {
235 let mut pairs = Vec::new();
236
237 for walk in walks {
238 for (i, target) in walk.nodes.iter().enumerate() {
239 let start = i.saturating_sub(window_size);
240 let end = (i + window_size + 1).min(walk.nodes.len());
241
242 for j in start..end {
243 if i != j {
244 pairs.push(ContextPair {
245 target: target.clone(),
246 context: walk.nodes[j].clone(),
247 });
248 }
249 }
250 }
251 }
252
253 pairs
254 }
255
256 pub fn train_skip_gram(
258 &mut self,
259 pairs: &[ContextPair<N>],
260 negative_sampler: &NegativeSampler<N>,
261 learning_rate: f64,
262 negative_samples: usize,
263 rng: &mut impl Rng,
264 ) -> Result<()> {
265 for pair in pairs {
266 let target_emb = self
268 .embeddings
269 .get(&pair.target)
270 .ok_or(GraphError::node_not_found("node"))?
271 .clone();
272 let context_emb = self
273 .context_embeddings
274 .get(&pair.context)
275 .ok_or(GraphError::node_not_found("node"))?
276 .clone();
277
278 let positive_score = target_emb.dot_product(&context_emb)?;
280 let positive_prob = Embedding::sigmoid(positive_score);
281
282 let positive_error = 1.0 - positive_prob;
284 let mut target_gradient = vec![0.0; self.dimensions];
285 let mut context_gradient = vec![0.0; self.dimensions];
286
287 #[allow(clippy::needless_range_loop)]
288 for i in 0..self.dimensions {
289 target_gradient[i] += positive_error * context_emb.vector[i];
290 context_gradient[i] += positive_error * target_emb.vector[i];
291 }
292
293 let exclude_set: HashSet<&N> = [&pair.target, &pair.context].iter().cloned().collect();
295 let negatives = negative_sampler.sample_negatives(negative_samples, &exclude_set, rng);
296
297 for negative in &negatives {
298 if let Some(neg_context_emb) = self.context_embeddings.get(negative) {
299 let negative_score = target_emb.dot_product(neg_context_emb)?;
300 let negative_prob = Embedding::sigmoid(negative_score);
301
302 let negative_error = -negative_prob;
304
305 #[allow(clippy::needless_range_loop)]
306 for i in 0..self.dimensions {
307 target_gradient[i] += negative_error * neg_context_emb.vector[i];
308 }
309 }
310 }
311
312 for negative in &negatives {
314 if let Some(neg_context_emb_mut) = self.context_embeddings.get_mut(negative) {
315 let negative_score = target_emb.dot_product(neg_context_emb_mut)?;
316 let negative_prob = Embedding::sigmoid(negative_score);
317 let negative_error = -negative_prob;
318
319 #[allow(clippy::needless_range_loop)]
320 for i in 0..self.dimensions {
321 let neg_context_grad = negative_error * target_emb.vector[i];
322 neg_context_emb_mut.vector[i] -= learning_rate * neg_context_grad;
323 }
324 }
325 }
326
327 if let Some(target_emb_mut) = self.embeddings.get_mut(&pair.target) {
329 target_emb_mut.update_gradient(&target_gradient, learning_rate);
330 }
331 if let Some(context_emb_mut) = self.context_embeddings.get_mut(&pair.context) {
332 context_emb_mut.update_gradient(&context_gradient, learning_rate);
333 }
334 }
335
336 Ok(())
337 }
338}