scirs2_graph/embeddings/
core.rs

//! Core embedding structures and operations
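//!
//! Provides the [`Embedding`] vector type with basic dense-vector operations and
//! the [`EmbeddingModel`] container, which stores per-node input and context
//! vectors and trains them with skip-gram plus negative sampling over context
//! pairs generated from random walks.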

use super::negative_sampling::NegativeSampler;
use super::types::ContextPair;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
use scirs2_core::random::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Node embedding vector
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The embedding vector
    pub vector: Vec<f64>,
}

impl Embedding {
    /// Create a new embedding with given dimensions
    pub fn new(dimensions: usize) -> Self {
        Embedding {
            vector: vec![0.0; dimensions],
        }
    }

    /// Create a random embedding
    pub fn random(dimensions: usize, rng: &mut impl Rng) -> Self {
        let vector: Vec<f64> = (0..dimensions)
            .map(|_| rng.random_range(-0.5..0.5))
            .collect();
        Embedding { vector }
    }

    /// Get the dimensionality of the embedding
    pub fn dimensions(&self) -> usize {
        self.vector.len()
    }

    /// Calculate cosine similarity with another embedding
    pub fn cosine_similarity(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot_product: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();

        let norm_a = self.norm();
        let norm_b = other.norm();

        if norm_a == 0.0 || norm_b == 0.0 {
            Ok(0.0)
        } else {
            Ok(dot_product / (norm_a * norm_b))
        }
    }

    /// Calculate the L2 norm of the embedding
    pub fn norm(&self) -> f64 {
        self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Normalize the embedding to unit length
    pub fn normalize(&mut self) {
        let norm = self.norm();
        if norm > 0.0 {
            for x in &mut self.vector {
                *x /= norm;
            }
        }
    }

    /// Add another embedding (element-wise)
    pub fn add(&mut self, other: &Embedding) -> Result<()> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        for (a, b) in self.vector.iter_mut().zip(other.vector.iter()) {
            *a += b;
        }
        Ok(())
    }

    /// Scale the embedding by a scalar
    pub fn scale(&mut self, factor: f64) {
        for x in &mut self.vector {
            *x *= factor;
        }
    }

    /// Compute the dot product with another embedding
    pub fn dot_product(&self, other: &Embedding) -> Result<f64> {
        if self.vector.len() != other.vector.len() {
            return Err(GraphError::InvalidGraph(
                "Embeddings must have same dimensions".to_string(),
            ));
        }

        let dot: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();
        Ok(dot)
    }

    /// Sigmoid activation function
    pub fn sigmoid(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Update the embedding in place with a gradient-descent step: `v[i] -= learning_rate * gradient[i]`
    pub fn update_gradient(&mut self, gradient: &[f64], learning_rate: f64) {
        for (emb, &grad) in self.vector.iter_mut().zip(gradient.iter()) {
            *emb -= learning_rate * grad;
        }
    }
}

/// Graph embedding model
#[derive(Debug)]
pub struct EmbeddingModel<N: Node> {
    /// Node embeddings (input vectors)
    pub embeddings: HashMap<N, Embedding>,
    /// Context embeddings (output vectors) for skip-gram
    pub context_embeddings: HashMap<N, Embedding>,
    /// Dimensionality of embeddings
    pub dimensions: usize,
}

impl<N: Node> EmbeddingModel<N> {
    /// Create a new embedding model
    pub fn new(dimensions: usize) -> Self {
        EmbeddingModel {
            embeddings: HashMap::new(),
            context_embeddings: HashMap::new(),
            dimensions,
        }
    }

    /// Get embedding for a node
    pub fn get_embedding(&self, node: &N) -> Option<&Embedding> {
        self.embeddings.get(node)
    }

    /// Set embedding for a node
    pub fn set_embedding(&mut self, node: N, embedding: Embedding) -> Result<()> {
        if embedding.dimensions() != self.dimensions {
            return Err(GraphError::InvalidGraph(
                "Embedding dimensions don't match model".to_string(),
            ));
        }
        self.embeddings.insert(node, embedding);
        Ok(())
    }

    /// Initialize random embeddings for all nodes
    pub fn initialize_random<E, Ix>(&mut self, graph: &Graph<N, E, Ix>, rng: &mut impl Rng)
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Initialize random embeddings for a directed graph
    pub fn initialize_random_digraph<E, Ix>(
        &mut self,
        graph: &DiGraph<N, E, Ix>,
        rng: &mut impl Rng,
    ) where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        for node in graph.nodes() {
            let embedding = Embedding::random(self.dimensions, rng);
            let context_embedding = Embedding::random(self.dimensions, rng);
            self.embeddings.insert(node.clone(), embedding);
            self.context_embeddings
                .insert(node.clone(), context_embedding);
        }
    }

    /// Find the k most similar nodes to a given node
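    ///
    /// Similarity is cosine similarity against every other embedded node; the
    /// result is sorted in descending order and truncated to `k` entries.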
    pub fn most_similar(&self, node: &N, k: usize) -> Result<Vec<(N, f64)>>
    where
        N: Clone,
    {
        let target_embedding = self
            .embeddings
            .get(node)
            .ok_or(GraphError::node_not_found("node"))?;

        let mut similarities = Vec::new();

        for (other_node, other_embedding) in &self.embeddings {
            if other_node != node {
                let similarity = target_embedding.cosine_similarity(other_embedding)?;
                similarities.push((other_node.clone(), similarity));
            }
        }

        similarities.sort_by(|a, b| b.1.total_cmp(&a.1));
        similarities.truncate(k);

        Ok(similarities)
    }

    /// Generate context pairs from random walks
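    ///
    /// For every position in a walk, each node within `window_size` steps on
    /// either side becomes a context for that position's target node. For
    /// example, the walk `[a, b, c]` with `window_size = 1` yields the pairs
    /// `(a, b)`, `(b, a)`, `(b, c)` and `(c, b)`.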
    pub fn generate_context_pairs(
        walks: &[super::types::RandomWalk<N>],
        window_size: usize,
    ) -> Vec<ContextPair<N>>
    where
        N: Clone,
    {
        let mut pairs = Vec::new();

        for walk in walks {
            for (i, target) in walk.nodes.iter().enumerate() {
                let start = i.saturating_sub(window_size);
                let end = (i + window_size + 1).min(walk.nodes.len());

                for j in start..end {
                    if i != j {
                        pairs.push(ContextPair {
                            target: target.clone(),
                            context: walk.nodes[j].clone(),
                        });
                    }
                }
            }
        }

        pairs
    }

    /// Train the skip-gram model on context pairs with negative sampling
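    ///
    /// For each `(target, context)` pair this performs one stochastic
    /// gradient-descent step that pushes the target's input vector towards the
    /// context's output vector and away from `negative_samples` sampled
    /// negative contexts, using `learning_rate` as the step size.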
    pub fn train_skip_gram(
        &mut self,
        pairs: &[ContextPair<N>],
        negative_sampler: &NegativeSampler<N>,
        learning_rate: f64,
        negative_samples: usize,
        rng: &mut impl Rng,
    ) -> Result<()> {
        for pair in pairs {
            // Get embeddings
            let target_emb = self
                .embeddings
                .get(&pair.target)
                .ok_or(GraphError::node_not_found("node"))?
                .clone();
            let context_emb = self
                .context_embeddings
                .get(&pair.context)
                .ok_or(GraphError::node_not_found("node"))?
                .clone();

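            // Skip-gram with negative sampling minimizes, for each pair,
            //   L = -ln σ(v_t · u_c) - Σ_n ln σ(-v_t · u_n),
            // where v_t is the target's input vector, u_c the context's output
            // vector, and u_n the sampled negative context vectors. The error
            // terms below are dL/d(score), so all updates are plain gradient
            // descent with step size `learning_rate`.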
            // Positive sample: maximize the probability of the observed
            // context given the target, i.e. minimize -ln σ(v_t · u_c)
            let positive_score = target_emb.dot_product(&context_emb)?;
            let positive_prob = Embedding::sigmoid(positive_score);

            // Accumulate the loss gradient for the positive sample;
            // dL/d(score) = σ(score) - 1 for the observed pair
            let positive_error = positive_prob - 1.0;
            let mut target_gradient = vec![0.0; self.dimensions];
            let mut context_gradient = vec![0.0; self.dimensions];

            #[allow(clippy::needless_range_loop)]
            for i in 0..self.dimensions {
                target_gradient[i] += positive_error * context_emb.vector[i];
                context_gradient[i] += positive_error * target_emb.vector[i];
            }

            // Negative samples: minimize the probability of the sampled negative contexts
            let exclude_set: HashSet<&N> = [&pair.target, &pair.context].iter().cloned().collect();
            let negatives = negative_sampler.sample_negatives(negative_samples, &exclude_set, rng);

            for negative in &negatives {
                if let Some(neg_context_emb) = self.context_embeddings.get(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb)?;
                    let negative_prob = Embedding::sigmoid(negative_score);

                    // Negative sample: dL/d(score) = σ(score) for a noise pair
                    let negative_error = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        target_gradient[i] += negative_error * neg_context_emb.vector[i];
                    }
                }
            }

            // Update negative context embeddings in a second pass so the
            // immutable borrows above end before we borrow mutably
            for negative in &negatives {
                if let Some(neg_context_emb_mut) = self.context_embeddings.get_mut(negative) {
                    let negative_score = target_emb.dot_product(neg_context_emb_mut)?;
                    let negative_prob = Embedding::sigmoid(negative_score);
                    let negative_error = negative_prob;

                    #[allow(clippy::needless_range_loop)]
                    for i in 0..self.dimensions {
                        let neg_context_grad = negative_error * target_emb.vector[i];
                        neg_context_emb_mut.vector[i] -= learning_rate * neg_context_grad;
                    }
                }
            }

            // Apply gradients
            if let Some(target_emb_mut) = self.embeddings.get_mut(&pair.target) {
                target_emb_mut.update_gradient(&target_gradient, learning_rate);
            }
            if let Some(context_emb_mut) = self.context_embeddings.get_mut(&pair.context) {
                context_emb_mut.update_gradient(&context_gradient, learning_rate);
            }
        }

        Ok(())
    }
}
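
#[cfg(test)]
mod tests {
    // Illustrative sanity checks (a minimal sketch) for the `Embedding` vector
    // operations defined above. They only use deterministic constructors, so no
    // RNG or graph types are required.
    use super::*;

    #[test]
    fn cosine_similarity_of_parallel_and_orthogonal_vectors() {
        let a = Embedding { vector: vec![1.0, 0.0] };
        let b = Embedding { vector: vec![2.0, 0.0] };
        let c = Embedding { vector: vec![0.0, 3.0] };

        // Parallel vectors have similarity 1, orthogonal vectors 0.
        assert!((a.cosine_similarity(&b).unwrap() - 1.0).abs() < 1e-12);
        assert!(a.cosine_similarity(&c).unwrap().abs() < 1e-12);
    }

    #[test]
    fn normalize_produces_unit_norm() {
        let mut e = Embedding { vector: vec![3.0, 4.0] };
        e.normalize();
        assert!((e.norm() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn dimension_mismatch_is_an_error() {
        let a = Embedding::new(2);
        let b = Embedding::new(3);
        assert!(a.dot_product(&b).is_err());
        assert!(a.cosine_similarity(&b).is_err());
    }

    #[test]
    fn update_gradient_moves_against_the_gradient() {
        let mut e = Embedding { vector: vec![1.0, 1.0] };
        e.update_gradient(&[0.5, -0.5], 0.1);
        assert!((e.vector[0] - 0.95).abs() < 1e-12);
        assert!((e.vector[1] - 1.05).abs() < 1e-12);
    }
}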