scirs2_graph/embeddings/
deepwalk.rs1use super::core::EmbeddingModel;
4use super::negative_sampling::NegativeSampler;
5use super::random_walk::RandomWalkGenerator;
6use super::types::{DeepWalkConfig, RandomWalk};
7use crate::base::{EdgeWeight, Graph, Node};
8use crate::error::Result;
9use scirs2_core::random::seq::SliceRandom;
10
11pub struct DeepWalk<N: Node> {
13 config: DeepWalkConfig,
14 model: EmbeddingModel<N>,
15 walk_generator: RandomWalkGenerator<N>,
16}
17
18impl<N: Node> DeepWalk<N> {
19 pub fn new(config: DeepWalkConfig) -> Self {
21 DeepWalk {
22 model: EmbeddingModel::new(config.dimensions),
23 config,
24 walk_generator: RandomWalkGenerator::new(),
25 }
26 }
27
28 pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
30 where
31 N: Clone + std::fmt::Debug,
32 E: EdgeWeight,
33 Ix: petgraph::graph::IndexType,
34 {
35 let mut all_walks = Vec::new();
36
37 for node in graph.nodes() {
38 for _ in 0..self.config.num_walks {
39 let walk =
40 self.walk_generator
41 .simple_random_walk(graph, node, self.config.walk_length)?;
42 all_walks.push(walk);
43 }
44 }
45
46 Ok(all_walks)
47 }
48
49 pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
51 where
52 N: Clone + std::fmt::Debug,
53 E: EdgeWeight,
54 Ix: petgraph::graph::IndexType,
55 {
56 let mut rng = scirs2_core::random::rng();
58 self.model.initialize_random(graph, &mut rng);
59
60 let negative_sampler = NegativeSampler::new(graph);
62
63 for epoch in 0..self.config.epochs {
65 let walks = self.generate_walks(graph)?;
67
68 let context_pairs =
70 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
71
72 let mut shuffled_pairs = context_pairs;
74 shuffled_pairs.shuffle(&mut rng);
75
76 let current_lr =
78 self.config.learning_rate * (1.0 - epoch as f64 / self.config.epochs as f64);
79
80 self.model.train_skip_gram(
81 &shuffled_pairs,
82 &negative_sampler,
83 current_lr,
84 self.config.negative_samples,
85 &mut rng,
86 )?;
87
88 if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
89 println!(
90 "DeepWalk epoch {}/{}, generated {} walks, {} context pairs",
91 epoch + 1,
92 self.config.epochs,
93 walks.len(),
94 shuffled_pairs.len()
95 );
96 }
97 }
98
99 Ok(())
100 }
101
102 pub fn model(&self) -> &EmbeddingModel<N> {
104 &self.model
105 }
106
107 pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
109 &mut self.model
110 }
111}