scirs2_graph/embeddings/
node2vec.rs

1//! Node2Vec graph embedding algorithm
2
3use super::core::EmbeddingModel;
4use super::negative_sampling::NegativeSampler;
5use super::random_walk::RandomWalkGenerator;
6use super::types::{Node2VecConfig, RandomWalk};
7use crate::base::{EdgeWeight, Graph, Node};
8use crate::error::Result;
9use scirs2_core::random::seq::SliceRandom;
10
11/// Basic Node2Vec implementation foundation
12pub struct Node2Vec<N: Node> {
13    config: Node2VecConfig,
14    model: EmbeddingModel<N>,
15    walk_generator: RandomWalkGenerator<N>,
16}
17
18impl<N: Node> Node2Vec<N> {
19    /// Create a new Node2Vec instance
20    pub fn new(config: Node2VecConfig) -> Self {
21        Node2Vec {
22            model: EmbeddingModel::new(config.dimensions),
23            config,
24            walk_generator: RandomWalkGenerator::new(),
25        }
26    }
27
28    /// Generate training data (random walks) for Node2Vec
29    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
30    where
31        N: Clone + std::fmt::Debug,
32        E: EdgeWeight + Into<f64>,
33        Ix: petgraph::graph::IndexType,
34    {
35        let mut all_walks = Vec::new();
36
37        for node in graph.nodes() {
38            for _ in 0..self.config.num_walks {
39                let walk = self.walk_generator.node2vec_walk(
40                    graph,
41                    node,
42                    self.config.walk_length,
43                    self.config.p,
44                    self.config.q,
45                )?;
46                all_walks.push(walk);
47            }
48        }
49
50        Ok(all_walks)
51    }
52
53    /// Train the Node2Vec model with complete skip-gram implementation
54    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
55    where
56        N: Clone + std::fmt::Debug,
57        E: EdgeWeight + Into<f64>,
58        Ix: petgraph::graph::IndexType,
59    {
60        // Initialize random embeddings
61        let mut rng = scirs2_core::random::rng();
62        self.model.initialize_random(graph, &mut rng);
63
64        // Create negative sampler
65        let negative_sampler = NegativeSampler::new(graph);
66
67        // Training loop over epochs
68        for epoch in 0..self.config.epochs {
69            // Generate walks for this epoch
70            let walks = self.generate_walks(graph)?;
71
72            // Generate context pairs from walks
73            let context_pairs =
74                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
75
76            // Shuffle pairs for better training
77            let mut shuffled_pairs = context_pairs;
78            shuffled_pairs.shuffle(&mut rng);
79
80            // Train skip-gram model with negative sampling
81            let current_lr =
82                self.config.learning_rate * (1.0 - epoch as f64 / self.config.epochs as f64);
83
84            self.model.train_skip_gram(
85                &shuffled_pairs,
86                &negative_sampler,
87                current_lr,
88                self.config.negative_samples,
89                &mut rng,
90            )?;
91
92            if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
93                println!(
94                    "Node2Vec epoch {}/{}, generated {} walks, {} context pairs",
95                    epoch + 1,
96                    self.config.epochs,
97                    walks.len(),
98                    shuffled_pairs.len()
99                );
100            }
101        }
102
103        Ok(())
104    }
105
106    /// Get the trained model
107    pub fn model(&self) -> &EmbeddingModel<N> {
108        &self.model
109    }
110
111    /// Get mutable reference to the model
112    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
113        &mut self.model
114    }
115}