// scirs2_graph/embeddings/deepwalk.rs

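//! DeepWalk graph embedding algorithm: learns node embeddings by training a
//! skip-gram model (with negative sampling) on random walks over the graph.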
2
3use super::core::EmbeddingModel;
4use super::negative_sampling::NegativeSampler;
5use super::random_walk::RandomWalkGenerator;
6use super::types::{DeepWalkConfig, RandomWalk};
7use crate::base::{EdgeWeight, Graph, Node};
8use crate::error::Result;
9use scirs2_core::random::seq::SliceRandom;
10
11/// Basic DeepWalk implementation foundation
12pub struct DeepWalk<N: Node> {
13    config: DeepWalkConfig,
14    model: EmbeddingModel<N>,
15    walk_generator: RandomWalkGenerator<N>,
16}
17
18impl<N: Node> DeepWalk<N> {
19    /// Create a new DeepWalk instance
20    pub fn new(config: DeepWalkConfig) -> Self {
21        DeepWalk {
22            model: EmbeddingModel::new(config.dimensions),
23            config,
24            walk_generator: RandomWalkGenerator::new(),
25        }
26    }
27
28    /// Generate training data (random walks) for DeepWalk
29    pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
30    where
31        N: Clone + std::fmt::Debug,
32        E: EdgeWeight,
33        Ix: petgraph::graph::IndexType,
34    {
35        let mut all_walks = Vec::new();
36
37        for node in graph.nodes() {
38            for _ in 0..self.config.num_walks {
39                let walk =
40                    self.walk_generator
41                        .simple_random_walk(graph, node, self.config.walk_length)?;
42                all_walks.push(walk);
43            }
44        }
45
46        Ok(all_walks)
47    }
48
49    /// Train the DeepWalk model with complete skip-gram implementation
50    pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
51    where
52        N: Clone + std::fmt::Debug,
53        E: EdgeWeight,
54        Ix: petgraph::graph::IndexType,
55    {
56        // Initialize random embeddings
57        let mut rng = scirs2_core::random::rng();
58        self.model.initialize_random(graph, &mut rng);
59
60        // Create negative sampler
61        let negative_sampler = NegativeSampler::new(graph);
62
63        // Training loop over epochs
64        for epoch in 0..self.config.epochs {
65            // Generate walks for this epoch
66            let walks = self.generate_walks(graph)?;
67
68            // Generate context pairs from walks
69            let context_pairs =
70                EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
71
72            // Shuffle pairs for better training
73            let mut shuffled_pairs = context_pairs;
74            shuffled_pairs.shuffle(&mut rng);
75
76            // Train skip-gram model with negative sampling
77            let current_lr =
78                self.config.learning_rate * (1.0 - epoch as f64 / self.config.epochs as f64);
79
80            self.model.train_skip_gram(
81                &shuffled_pairs,
82                &negative_sampler,
83                current_lr,
84                self.config.negative_samples,
85                &mut rng,
86            )?;
87
88            if epoch % 10 == 0 || epoch == self.config.epochs - 1 {
89                println!(
90                    "DeepWalk epoch {}/{}, generated {} walks, {} context pairs",
91                    epoch + 1,
92                    self.config.epochs,
93                    walks.len(),
94                    shuffled_pairs.len()
95                );
96            }
97        }
98
99        Ok(())
100    }
101
102    /// Get the trained model
103    pub fn model(&self) -> &EmbeddingModel<N> {
104        &self.model
105    }
106
107    /// Get mutable reference to the model
108    pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
109        &mut self.model
110    }
111}