scirs2_graph/embeddings/core.rs

//! Core embedding structures and operations

use super::negative_sampling::NegativeSampler;
use super::types::ContextPair;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
use scirs2_core::random::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Node embedding vector
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The embedding vector
    pub vector: Vec<f64>,
}

impl Embedding {
    /// Create a new embedding with given dimensions
    pub fn new(dimensions: usize) -> Self {
        Embedding {
            vector: vec![0.0; dimensions],
        }
    }

    /// Create a random embedding
    pub fn random(dimensions: usize, rng: &mut impl Rng) -> Self {
        let vector: Vec<f64> = (0..dimensions).map(|_| rng.gen_range(-0.5..0.5)).collect();
        Embedding { vector }
    }

    /// Get the dimensionality of the embedding
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }

    /// Calculate cosine similarity with another embedding
    pub fn cosine_similarity(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot_product: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();

        let norm_a = self.norm();
        let norm_b = other.norm();

        if norm_a == 0.0 || norm_b == 0.0 {
            Ok(0.0)
        } else {
            Ok(dot_product / (norm_a * norm_b))
        }
    }

    /// Calculate the L2 norm of the embedding
    pub fn norm(&self) -> f64 {
        self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Normalize the embedding to unit length
    pub fn normalize(&mut self) {
        let norm = self.norm();
        if norm > 0.0 {
            for x in &mut self.vector {
                *x /= norm;
            }
        }
    }

    /// Add another embedding (element-wise)
    pub fn add(&mut self, other: &Embedding) -> Result<()> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        for (a, b) in self.vector.iter_mut().zip(other.vector.iter()) {
            *a += b;
        }
        Ok(())
    }

    /// Scale the embedding by a scalar
    pub fn scale(&mut self, factor: f64) {
        for x in &mut self.vector {
            *x *= factor;
        }
    }

    /// Compute the dot product with another embedding
    pub fn dot_product(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();
        Ok(dot)
    }

    /// Sigmoid activation function
    pub fn sigmoid(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Apply a gradient-descent update: `vector -= learning_rate * gradient`
    pub fn update_gradient(&mut self, gradient: &[f64], learning_rate: f64) {
        for (emb, &grad) in self.vector.iter_mut().zip(gradient.iter()) {
            *emb -= learning_rate * grad;
        }
    }
}
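
// Illustrative sanity checks for the vector operations above: a minimal
// sketch, with arbitrary values and tolerances, exercising only code defined
// in this module.
#[cfg(test)]
mod embedding_vector_ops_tests {
    use super::Embedding;

    #[test]
    fn cosine_similarity_of_parallel_vectors_is_one() {
        // Scaling a vector changes its length but not its direction,
        // so cosine similarity with the original stays 1.
        let a = Embedding {
            vector: vec![1.0, 2.0, 2.0],
        };
        let mut b = a.clone();
        b.scale(0.5);
        let sim = a.cosine_similarity(&b).unwrap();
        assert!((sim - 1.0).abs() < 1e-12);
    }

    #[test]
    fn normalize_yields_unit_norm_and_sigmoid_is_centered() {
        let mut e = Embedding {
            vector: vec![3.0, 4.0],
        };
        // ||(3, 4)|| = 5 before normalization and 1 afterwards.
        assert!((e.norm() - 5.0).abs() < 1e-12);
        e.normalize();
        assert!((e.norm() - 1.0).abs() < 1e-12);
        // sigmoid(0) = 0.5 by definition.
        assert!((Embedding::sigmoid(0.0) - 0.5).abs() < 1e-12);
    }
}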

/// Graph embedding model
#[derive(Debug)]
pub struct EmbeddingModel<N: Node> {
    /// Node embeddings (input vectors)
    pub embeddings: HashMap<N, Embedding>,
    /// Context embeddings (output vectors) for skip-gram
    pub context_embeddings: HashMap<N, Embedding>,
    /// Dimensionality of embeddings
    pub dimensions: usize,
}

impl<N: Node> EmbeddingModel<N> {
    /// Create a new embedding model
    pub fn new(dimensions: usize) -> Self {
        EmbeddingModel {
            embeddings: HashMap::new(),
            context_embeddings: HashMap::new(),
            dimensions,
        }
    }

    /// Get embedding for a node
    pub fn get_embedding(&self, node: &N) -> Option<&Embedding> {
        self.embeddings.get(node)
    }

    /// Set embedding for a node
    pub fn set_embedding(&mut self, node: N, embedding: Embedding) -> Result<()> {
        if embedding.dimensions() != self.dimensions {
            return Err(GraphError::InvalidGraph(
                "Embedding dimensions don't match model".to_string(),
            ));
        }
        self.embeddings.insert(node, embedding);
        Ok(())
    }

    /// Initialize random embeddings for all nodes
    pub fn initialize_random<E, Ix>(&mut self, graph: &Graph<N, E, Ix>, rng: &mut impl Rng)
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Initialize random embeddings for a directed graph
    pub fn initialize_random_digraph<E, Ix>(
        &mut self,
        graph: &DiGraph<N, E, Ix>,
        rng: &mut impl Rng,
    ) where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Find k most similar nodes to a given node
    pub fn most_similar(&self, node: &N, k: usize) -> Result<Vec<(N, f64)>>
    where
        N: Clone,
    {
        let target_embedding = self
            .embeddings
            .get(node)
            .ok_or(GraphError::node_not_found("node"))?;

        let mut similarities = Vec::new();

        for (other_node, other_embedding) in &self.embeddings {
            if other_node != node {
                let similarity = target_embedding.cosine_similarity(other_embedding)?;
                similarities.push((other_node.clone(), similarity));
            }
        }

        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        similarities.truncate(k);

        Ok(similarities)
    }

    /// Generate context pairs from random walks
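    ///
    /// For example, a walk over nodes `[a, b, c]` with `window_size = 1`
    /// yields the (target, context) pairs `(a, b)`, `(b, a)`, `(b, c)`, and
    /// `(c, b)`.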
    pub fn generate_context_pairs(
        walks: &[super::types::RandomWalk<N>],
        window_size: usize,
    ) -> Vec<ContextPair<N>>
    where
        N: Clone,
    {
        let mut pairs = Vec::new();

        for walk in walks {
            for (i, target) in walk.nodes.iter().enumerate() {
                let start = i.saturating_sub(window_size);
                let end = (i + window_size + 1).min(walk.nodes.len());

                for j in start..end {
                    if i != j {
                        pairs.push(ContextPair {
                            target: target.clone(),
                            context: walk.nodes[j].clone(),
                        });
                    }
                }
            }
        }

        pairs
    }

    /// Train skip-gram model on context pairs with negative sampling
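    ///
    /// For each (target, context) pair, one stochastic gradient-descent step
    /// is taken on the negative-sampling loss
    /// `-ln sigmoid(t · c) - Σ_k ln sigmoid(-t · n_k)`, where `t` is the
    /// target embedding, `c` the context embedding, and `n_k` the context
    /// embeddings of the sampled negatives. The gradient coefficients used
    /// below follow from `d/dx[-ln sigmoid(x)] = sigmoid(x) - 1` and
    /// `d/dx[-ln sigmoid(-x)] = sigmoid(x)`.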
    pub fn train_skip_gram(
        &mut self,
        pairs: &[ContextPair<N>],
        negative_sampler: &NegativeSampler<N>,
        learning_rate: f64,
        negative_samples: usize,
        rng: &mut impl Rng,
    ) -> Result<()> {
        for pair in pairs {
            // Get embeddings
            let target_emb = self
                .embeddings
                .get(&pair.target)
                .ok_or(GraphError::node_not_found("node"))?
                .clone();
            let context_emb = self
                .context_embeddings
                .get(&pair.context)
                .ok_or(GraphError::node_not_found("node"))?
                .clone();

            // Positive sample: maximize probability of context given target
            let positive_score = target_emb.dot_product(&context_emb)?;
            let positive_prob = Embedding::sigmoid(positive_score);

            // Gradient of the loss w.r.t. the positive score is sigmoid(score) - 1.
            // `update_gradient` performs descent (it subtracts), so we accumulate
            // loss gradients here rather than prediction errors.
            let positive_grad = positive_prob - 1.0;
            let mut target_gradient = vec![0.0; self.dimensions];
            let mut context_gradient = vec![0.0; self.dimensions];

            #[allow(clippy::needless_range_loop)]
            for i in 0..self.dimensions {
                target_gradient[i] += positive_grad * context_emb.vector[i];
                context_gradient[i] += positive_grad * target_emb.vector[i];
            }

            // Negative samples: minimize probability of negative contexts
            let exclude_set: HashSet<&N> = [&pair.target, &pair.context].iter().cloned().collect();
            let negatives = negative_sampler.sample_negatives(negative_samples, &exclude_set, rng);

            for negative in &negatives {
                if let Some(neg_context_emb) = self.context_embeddings.get(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb)?;
                    let negative_prob = Embedding::sigmoid(negative_score);

                    // Gradient of the loss w.r.t. a negative score is sigmoid(score).
                    let negative_grad = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        target_gradient[i] += negative_grad * neg_context_emb.vector[i];
                    }
                }
            }

            // Update negative context embeddings separately to avoid borrowing issues
            for negative in &negatives {
                if let Some(neg_context_emb_mut) = self.context_embeddings.get_mut(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb_mut)?;
                    let negative_prob = Embedding::sigmoid(negative_score);
                    // Same coefficient as above: sigmoid(score) for a negative pair.
                    let negative_grad = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        let neg_context_grad = negative_grad * target_emb.vector[i];
                        neg_context_emb_mut.vector[i] -= learning_rate * neg_context_grad;
                    }
                }
            }

            // Apply gradients
            if let Some(target_emb_mut) = self.embeddings.get_mut(&pair.target) {
                target_emb_mut.update_gradient(&target_gradient, learning_rate);
            }
            if let Some(context_emb_mut) = self.context_embeddings.get_mut(&pair.context) {
                context_emb_mut.update_gradient(&context_gradient, learning_rate);
            }
        }

        Ok(())
    }
}
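
// A minimal end-to-end sketch of how the pieces in this module are typically
// wired together. The function name, epoch count, and hyperparameters are
// illustrative placeholders; random-walk generation and `NegativeSampler`
// construction live outside this module and are taken here as inputs.
#[cfg(test)]
mod skip_gram_training_sketch {
    use super::super::negative_sampling::NegativeSampler;
    use super::super::types::RandomWalk;
    use super::EmbeddingModel;
    use crate::base::{EdgeWeight, Graph, Node};
    use crate::error::Result;
    use scirs2_core::random::Rng;

    /// Train skip-gram embeddings from precomputed walks and a sampler.
    #[allow(dead_code)]
    fn train_node_embeddings<N, E, Ix>(
        graph: &Graph<N, E, Ix>,
        walks: &[RandomWalk<N>],
        sampler: &NegativeSampler<N>,
        rng: &mut impl Rng,
    ) -> Result<EmbeddingModel<N>>
    where
        N: Node + Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        // Small random input and context vectors for every node in the graph.
        let mut model = EmbeddingModel::new(64);
        model.initialize_random(graph, rng);

        // Turn the walks into (target, context) pairs with a +/-5 window.
        let pairs = EmbeddingModel::<N>::generate_context_pairs(walks, 5);

        // Several SGD passes over the pairs, 5 negatives per positive pair.
        for _epoch in 0..10 {
            model.train_skip_gram(&pairs, sampler, 0.025, 5, rng)?;
        }
        Ok(model)
    }
}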