Skip to main content

scirs2_graph/embeddings/
node2vec.rs

//! Node2Vec graph embedding algorithm
//!
//! Implements the Node2Vec algorithm from Grover & Leskovec (2016) for learning
//! continuous feature representations for nodes in networks. Uses biased random
//! walks with return parameter p and in-out parameter q to explore neighborhoods.
//!
//! # References
//! - Grover, A. & Leskovec, J. (2016). node2vec: Scalable Feature Learning for Networks. KDD 2016.

10use super::core::EmbeddingModel;
11use super::negative_sampling::NegativeSampler;
12use super::random_walk::RandomWalkGenerator;
13use super::types::{Node2VecConfig, RandomWalk};
14use crate::base::{DiGraph, EdgeWeight, Graph, Node};
15use crate::error::Result;
16use scirs2_core::random::seq::SliceRandom;
17use scirs2_core::random::RngExt;
18
19/// Node2Vec embedding algorithm
20///
21/// Learns node embeddings using biased second-order random walks followed
22/// by skip-gram optimization with negative sampling.
23pub struct Node2Vec<N: Node> {
24    config: Node2VecConfig,
25    model: EmbeddingModel<N>,
26    walk_generator: RandomWalkGenerator<N>,
27}
28
29impl<N: Node> Node2Vec<N> {
30    /// Create a new Node2Vec instance
31    pub fn new(config: Node2VecConfig) -> Self {
32        Node2Vec {
33            model: EmbeddingModel::new(config.dimensions),
34            config,
35            walk_generator: RandomWalkGenerator::new(),
36        }
37    }
38
39    /// Generate training data (biased random walks) for Node2Vec on undirected graph
40    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
41    where
42        N: Clone + std::fmt::Debug,
43        E: EdgeWeight + Into<f64>,
44        Ix: petgraph::graph::IndexType,
45    {
46        let mut all_walks = Vec::new();
47
48        for node in graph.nodes() {
49            for _ in 0..self.config.num_walks {
50                let walk = self.walk_generator.node2vec_walk(
51                    graph,
52                    node,
53                    self.config.walk_length,
54                    self.config.p,
55                    self.config.q,
56                )?;
57                all_walks.push(walk);
58            }
59        }
60
61        Ok(all_walks)
62    }
63
64    /// Generate training data (biased random walks) for Node2Vec on directed graph
65    pub fn generate_walks_digraph<E, Ix>(
66        &mut self,
67        graph: &DiGraph<N, E, Ix>,
68    ) -> Result<Vec<RandomWalk<N>>>
69    where
70        N: Clone + std::fmt::Debug,
71        E: EdgeWeight + Into<f64>,
72        Ix: petgraph::graph::IndexType,
73    {
74        let mut all_walks = Vec::new();
75
76        for node in graph.nodes() {
77            for _ in 0..self.config.num_walks {
78                let walk = self.walk_generator.node2vec_walk_digraph(
79                    graph,
80                    node,
81                    self.config.walk_length,
82                    self.config.p,
83                    self.config.q,
84                )?;
85                all_walks.push(walk);
86            }
87        }
88
89        Ok(all_walks)
90    }
91
92    /// Train the Node2Vec model on an undirected graph
93    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
94    where
95        N: Clone + std::fmt::Debug,
96        E: EdgeWeight + Into<f64>,
97        Ix: petgraph::graph::IndexType,
98    {
99        // Initialize random embeddings
100        let mut rng = scirs2_core::random::rng();
101        self.model.initialize_random(graph, &mut rng);
102
103        // Create negative sampler
104        let negative_sampler = NegativeSampler::new(graph);
105
106        // Training loop over epochs
107        for epoch in 0..self.config.epochs {
108            // Generate walks for this epoch
109            let walks = self.generate_walks(graph)?;
110
111            // Generate context pairs from walks
112            let context_pairs =
113                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
114
115            // Shuffle pairs for better training
116            let mut shuffled_pairs = context_pairs;
117            shuffled_pairs.shuffle(&mut rng);
118
119            // Train skip-gram model with negative sampling
120            // Linear learning rate decay
121            let current_lr = self.config.learning_rate
122                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
123
124            self.model.train_skip_gram(
125                &shuffled_pairs,
126                &negative_sampler,
127                current_lr,
128                self.config.negative_samples,
129                &mut rng,
130            )?;
131        }
132
133        Ok(())
134    }
135
136    /// Train the Node2Vec model on a directed graph
137    pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
138    where
139        N: Clone + std::fmt::Debug,
140        E: EdgeWeight + Into<f64>,
141        Ix: petgraph::graph::IndexType,
142    {
143        // Initialize random embeddings for directed graph
144        let mut rng = scirs2_core::random::rng();
145        self.model.initialize_random_digraph(graph, &mut rng);
146
147        // Create negative sampler from the undirected view
148        // For DiGraph, we build a temporary sampler from node degrees
149        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
150        let node_degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64).collect();
151
152        // Build cumulative distribution for negative sampling
153        let total_degree: f64 = node_degrees.iter().sum();
154        let frequencies: Vec<f64> = node_degrees
155            .iter()
156            .map(|d| (d / total_degree.max(1.0)).powf(0.75))
157            .collect();
158        let total_freq: f64 = frequencies.iter().sum();
159        let normalized: Vec<f64> = frequencies
160            .iter()
161            .map(|f| f / total_freq.max(1e-10))
162            .collect();
163
164        let mut cumulative = vec![0.0; normalized.len()];
165        if !cumulative.is_empty() {
166            cumulative[0] = normalized[0];
167            for i in 1..normalized.len() {
168                cumulative[i] = cumulative[i - 1] + normalized[i];
169            }
170        }
171
172        // Training loop
173        for epoch in 0..self.config.epochs {
174            let walks = self.generate_walks_digraph(graph)?;
175            let context_pairs =
176                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
177
178            let mut shuffled_pairs = context_pairs;
179            shuffled_pairs.shuffle(&mut rng);
180
181            let current_lr = self.config.learning_rate
182                * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
183
184            // Manual skip-gram training for directed graphs
185            // (since NegativeSampler is built for Graph, not DiGraph)
186            for pair in &shuffled_pairs {
187                self.train_pair_digraph(
188                    pair,
189                    &nodes,
190                    &cumulative,
191                    current_lr,
192                    self.config.negative_samples,
193                    &mut rng,
194                );
195            }
196        }
197
198        Ok(())
199    }
200
201    /// Train on a single context pair for directed graphs
202    fn train_pair_digraph(
203        &mut self,
204        pair: &super::types::ContextPair<N>,
205        nodes: &[N],
206        cumulative: &[f64],
207        learning_rate: f64,
208        num_negative: usize,
209        rng: &mut impl scirs2_core::random::Rng,
210    ) where
211        N: Clone,
212    {
213        let dim = self.config.dimensions;
214
215        // Get target embedding
216        let target_emb = match self.model.embeddings.get(&pair.target) {
217            Some(e) => e.clone(),
218            None => return,
219        };
220
221        // Get context embedding
222        let context_emb = match self.model.context_embeddings.get(&pair.context) {
223            Some(e) => e.clone(),
224            None => return,
225        };
226
227        // Positive sample gradient
228        let dot: f64 = target_emb
229            .vector
230            .iter()
231            .zip(context_emb.vector.iter())
232            .map(|(a, b)| a * b)
233            .sum();
234        let sig = 1.0 / (1.0 + (-dot).exp());
235        let g = learning_rate * (1.0 - sig);
236
237        let mut target_grad = vec![0.0; dim];
238        for d in 0..dim {
239            target_grad[d] += g * context_emb.vector[d];
240        }
241
242        // Update context embedding
243        if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
244            for d in 0..dim {
245                ctx.vector[d] += g * target_emb.vector[d];
246            }
247        }
248
249        // Negative samples
250        for _ in 0..num_negative {
251            let r = rng.random::<f64>();
252            let neg_idx = cumulative
253                .iter()
254                .position(|&c| r <= c)
255                .unwrap_or(cumulative.len().saturating_sub(1));
256
257            if neg_idx >= nodes.len() {
258                continue;
259            }
260
261            let neg_node = &nodes[neg_idx];
262            if neg_node == &pair.target || neg_node == &pair.context {
263                continue;
264            }
265
266            if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
267                let neg_dot: f64 = target_emb
268                    .vector
269                    .iter()
270                    .zip(neg_emb.vector.iter())
271                    .map(|(a, b)| a * b)
272                    .sum();
273                let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
274                let neg_g = learning_rate * (-neg_sig);
275
276                for d in 0..dim {
277                    target_grad[d] += neg_g * neg_emb.vector[d];
278                }
279
280                // Update negative context
281                if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
282                    for d in 0..dim {
283                        neg_ctx.vector[d] += neg_g * target_emb.vector[d];
284                    }
285                }
286            }
287        }
288
289        // Apply accumulated gradient to target
290        if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
291            for d in 0..dim {
292                target.vector[d] += target_grad[d];
293            }
294        }
295    }
296
297    /// Get the trained model
298    pub fn model(&self) -> &EmbeddingModel<N> {
299        &self.model
300    }
301
302    /// Get mutable reference to the model
303    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
304        &mut self.model
305    }
306
307    /// Get the configuration
308    pub fn config(&self) -> &Node2VecConfig {
309        &self.config
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// 3-node complete graph (triangle); all nodes structurally equivalent.
    fn make_triangle() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..3 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(0, 2, 1.0);
        g
    }

    /// 5-node star with node 0 as the hub.
    fn make_star_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        // Node 0 is the center
        for i in 1..5 {
            let _ = g.add_edge(0, i, 1.0);
        }
        g
    }

    /// Directed path 0 -> 1 -> 2 -> 3 -> 4.
    fn make_directed_chain() -> DiGraph<i32, f64> {
        let mut g = DiGraph::new();
        for i in 0..5 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        let _ = g.add_edge(3, 4, 1.0);
        g
    }

    #[test]
    fn test_node2vec_train_basic() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 5,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        let result = n2v.train(&g);
        assert!(result.is_ok(), "Node2Vec training should succeed");

        // All nodes should have embeddings
        for node in [0, 1, 2] {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding"
            );
        }
    }

    #[test]
    fn test_node2vec_walk_generation() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 10,
            num_walks: 2,
            p: 1.0,
            q: 1.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g);
        assert!(walks.is_ok());

        let walks = walks.expect("walks should be valid");
        // 3 nodes * 2 walks per node = 6 walks total
        assert_eq!(walks.len(), 6);

        // Each walk should have at most walk_length nodes
        for walk in &walks {
            assert!(walk.nodes.len() <= 10);
            assert!(!walk.nodes.is_empty());
        }
    }

    #[test]
    fn test_node2vec_biased_walks() {
        // With p=0.5 (low), walks should favor returning to previous nodes
        // With q=2.0 (high), walks should favor local (BFS-like) exploration
        let g = make_star_graph();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 20,
            num_walks: 5,
            p: 0.5,
            q: 2.0,
            ..Default::default()
        };

        let mut n2v = Node2Vec::new(config);
        let walks = n2v.generate_walks(&g);
        assert!(walks.is_ok());

        let walks = walks.expect("walks should be valid");
        assert!(!walks.is_empty());

        // Verify walks contain valid nodes
        for walk in &walks {
            for node in &walk.nodes {
                assert!(
                    (0..5).contains(node),
                    "Walk should only contain valid nodes, got {node}"
                );
            }
        }
    }

    #[test]
    fn test_node2vec_embedding_similarity() {
        let g = make_triangle();
        let config = Node2VecConfig {
            dimensions: 16,
            walk_length: 10,
            num_walks: 10,
            window_size: 3,
            p: 1.0,
            q: 1.0,
            epochs: 5,
            learning_rate: 0.05,
            negative_samples: 3,
        };

        let mut n2v = Node2Vec::new(config);
        let _ = n2v.train(&g);

        // In a triangle, all nodes are structurally equivalent
        // so similarities should be computable (not NaN)
        let model = n2v.model();
        let sim_01 = model.most_similar(&0, 2);
        assert!(sim_01.is_ok());

        let sim_01 = sim_01.expect("similarity should be valid");
        assert_eq!(sim_01.len(), 2, "Should find 2 most similar nodes");

        for (node, score) in &sim_01 {
            assert!(
                score.is_finite(),
                "Similarity for node {node} should be finite"
            );
        }
    }

    #[test]
    fn test_node2vec_digraph_train() {
        let g = make_directed_chain();
        let config = Node2VecConfig {
            dimensions: 8,
            walk_length: 4,
            num_walks: 3,
            window_size: 2,
            p: 1.0,
            q: 1.0,
            epochs: 2,
            learning_rate: 0.025,
            negative_samples: 2,
        };

        let mut n2v = Node2Vec::new(config);
        let result = n2v.train_digraph(&g);
        assert!(result.is_ok(), "DiGraph Node2Vec training should succeed");

        // All nodes should have embeddings
        for node in 0..5 {
            assert!(
                n2v.model().get_embedding(&node).is_some(),
                "Node {node} should have an embedding in directed graph"
            );
        }
    }

    #[test]
    fn test_node2vec_config() {
        let config = Node2VecConfig::default();
        assert_eq!(config.dimensions, 128);
        assert_eq!(config.walk_length, 80);
        assert_eq!(config.p, 1.0);
        assert_eq!(config.q, 1.0);

        let n2v: Node2Vec<i32> = Node2Vec::new(config);
        assert_eq!(n2v.config().dimensions, 128);
    }
}