scirs2_graph/embeddings/
node2vec.rs1use super::core::EmbeddingModel;
4use super::negative_sampling::NegativeSampler;
5use super::random_walk::RandomWalkGenerator;
6use super::types::{Node2VecConfig, RandomWalk};
7use crate::base::{EdgeWeight, Graph, Node};
8use crate::error::Result;
9use scirs2_core::random::seq::SliceRandom;
10
11pub struct Node2Vec<N: Node> {
13 config: Node2VecConfig,
14 model: EmbeddingModel<N>,
15 walk_generator: RandomWalkGenerator<N>,
16}
17
18impl<N: Node> Node2Vec<N> {
19 pub fn new(config: Node2VecConfig) -> Self {
21 Node2Vec {
22 model: EmbeddingModel::new(config.dimensions),
23 config,
24 walk_generator: RandomWalkGenerator::new(),
25 }
26 }
27
28 pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
30 where
31 N: Clone + std::fmt::Debug,
32 E: EdgeWeight + Into<f64>,
33 Ix: petgraph::graph::IndexType,
34 {
35 let mut all_walks = Vec::new();
36
37 for node in graph.nodes() {
38 for _ in 0..self.config.num_walks {
39 let walk = self.walk_generator.node2vec_walk(
40 graph,
41 node,
42 self.config.walk_length,
43 self.config.p,
44 self.config.q,
45 )?;
46 all_walks.push(walk);
47 }
48 }
49
50 Ok(all_walks)
51 }
52
53 pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
55 where
56 N: Clone + std::fmt::Debug,
57 E: EdgeWeight + Into<f64>,
58 Ix: petgraph::graph::IndexType,
59 {
60 let mut rng = scirs2_core::random::rng();
62 self.model.initialize_random(graph, &mut rng);
63
64 let negative_sampler = NegativeSampler::new(graph);
66
67 for epoch in 0..self.config.epochs {
69 let walks = self.generate_walks(graph)?;
71
72 let context_pairs =
74 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
75
76 let mut shuffled_pairs = context_pairs;
78 shuffled_pairs.shuffle(&mut rng);
79
80 let current_lr =
82 self.config.learning_rate * (1.0 - epoch as f64 / self.config.epochs as f64);
83
84 self.model.train_skip_gram(
85 &shuffled_pairs,
86 &negative_sampler,
87 current_lr,
88 self.config.negative_samples,
89 &mut rng,
90 )?;
91
92 if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
93 println!(
94 "Node2Vec epoch {}/{}, generated {} walks, {} context pairs",
95 epoch + 1,
96 self.config.epochs,
97 walks.len(),
98 shuffled_pairs.len()
99 );
100 }
101 }
102
103 Ok(())
104 }
105
106 pub fn model(&self) -> &EmbeddingModel<N> {
108 &self.model
109 }
110
111 pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
113 &mut self.model
114 }
115}