// sklears_manifold/node2vec.rs

//! Node2Vec Algorithm implementation
//!
//! This module provides Node2Vec for learning continuous feature representations
//! for nodes in networks.

5use scirs2_core::ndarray::{Array2, ArrayView2};
6use scirs2_core::random::rngs::StdRng;
7use scirs2_core::random::thread_rng;
8use scirs2_core::random::Rng;
9use scirs2_core::random::SeedableRng;
10use sklears_core::{
11    error::{Result as SklResult, SklearsError},
12    traits::{Estimator, Fit, Transform, Untrained},
13    types::Float,
14};
15use std::collections::HashMap;
16
/// Node2Vec Algorithm
///
/// Node2vec is a framework for learning continuous feature representations for nodes
/// in networks. It uses biased random walks to generate contexts that preserve both
/// local and global network structure.
#[derive(Debug, Clone)]
pub struct Node2Vec<S = Untrained> {
    // Typestate marker: `Untrained` before `fit`, `Node2VecTrained` after.
    state: S,
    // Dimensionality of the learned embeddings.
    n_components: usize,
    // Number of steps taken in each random walk.
    walk_length: usize,
    // Number of walks started from every node.
    num_walks: usize,
    p: f64, // Return parameter: higher p discourages revisiting the previous node
    q: f64, // In-out parameter: higher q biases walks toward local exploration
    // Skip-gram context window radius (in walk positions).
    window_size: usize,
    // Nodes appearing fewer than this many times across all walks are dropped.
    min_count: usize,
    // Mini-batch size parameter; NOTE(review): currently unused by training.
    batch_words: usize,
    // Number of passes over the generated walks during skip-gram training.
    epochs: usize,
    // Step size for the skip-gram gradient updates.
    learning_rate: f64,
    // Negative samples per positive pair; NOTE(review): currently unused by training.
    negative_samples: usize,
    // Optional RNG seed for reproducible walks and initialization.
    random_state: Option<u64>,
}

39impl Node2Vec<Untrained> {
40    /// Create a new Node2Vec instance
41    pub fn new() -> Self {
42        Self {
43            state: Untrained,
44            n_components: 128,
45            walk_length: 80,
46            num_walks: 10,
47            p: 1.0,
48            q: 1.0,
49            window_size: 10,
50            min_count: 1,
51            batch_words: 4,
52            epochs: 1,
53            learning_rate: 0.025,
54            negative_samples: 5,
55            random_state: None,
56        }
57    }
58
59    /// Set the number of components
60    pub fn n_components(mut self, n_components: usize) -> Self {
61        self.n_components = n_components;
62        self
63    }
64
65    /// Set the walk length
66    pub fn walk_length(mut self, walk_length: usize) -> Self {
67        self.walk_length = walk_length;
68        self
69    }
70
71    /// Set the number of walks
72    pub fn num_walks(mut self, num_walks: usize) -> Self {
73        self.num_walks = num_walks;
74        self
75    }
76
77    /// Set the return parameter (p)
78    pub fn p(mut self, p: f64) -> Self {
79        self.p = p;
80        self
81    }
82
83    /// Set the in-out parameter (q)
84    pub fn q(mut self, q: f64) -> Self {
85        self.q = q;
86        self
87    }
88
89    /// Set the window size
90    pub fn window_size(mut self, window_size: usize) -> Self {
91        self.window_size = window_size;
92        self
93    }
94
95    /// Set the minimum count
96    pub fn min_count(mut self, min_count: usize) -> Self {
97        self.min_count = min_count;
98        self
99    }
100
101    /// Set the batch words
102    pub fn batch_words(mut self, batch_words: usize) -> Self {
103        self.batch_words = batch_words;
104        self
105    }
106
107    /// Set the number of epochs
108    pub fn epochs(mut self, epochs: usize) -> Self {
109        self.epochs = epochs;
110        self
111    }
112
113    /// Set the learning rate
114    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
115        self.learning_rate = learning_rate;
116        self
117    }
118
119    /// Set the number of negative samples
120    pub fn negative_samples(mut self, negative_samples: usize) -> Self {
121        self.negative_samples = negative_samples;
122        self
123    }
124
125    /// Set the random state
126    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
127        self.random_state = random_state;
128        self
129    }
130}
131
impl Default for Node2Vec<Untrained> {
    /// Equivalent to [`Node2Vec::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Trained state for Node2Vec
#[derive(Debug, Clone)]
pub struct Node2VecTrained {
    /// Node embeddings: one row per vocabulary entry, `n_components` columns.
    node_embeddings: Array2<f64>,
    /// Vocabulary mapping from original node index to embedding row index.
    vocab: HashMap<usize, usize>,
}

impl Estimator for Node2Vec<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Node2Vec keeps all hyper-parameters on the struct itself, so the
    /// estimator config is the unit type. `&()` is a constant expression,
    /// so the returned reference is promoted to `'static`.
    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for Node2Vec<Node2VecTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Same unit config as the untrained estimator; see the note there.
    fn config(&self) -> &Self::Config {
        &()
    }
}

167impl Fit<ArrayView2<'_, Float>, ()> for Node2Vec<Untrained> {
168    type Fitted = Node2Vec<Node2VecTrained>;
169
170    fn fit(self, x: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
171        let (n_samples, _) = x.dim();
172
173        if n_samples < 2 {
174            return Err(SklearsError::InvalidParameter {
175                name: "n_samples".to_string(),
176                reason: "Node2Vec requires at least 2 samples".to_string(),
177            });
178        }
179
180        // Convert to f64 for computation
181        let x_f64 = x.mapv(|v| v);
182
183        // Build adjacency matrix from data
184        let adjacency = self.build_adjacency_matrix(&x_f64)?;
185
186        // Generate biased random walks using Node2Vec parameters
187        let walks = self.generate_node2vec_walks(&adjacency)?;
188
189        // Train skip-gram model on walks
190        let (node_embeddings, vocab) = self.train_skipgram_on_walks(&walks)?;
191
192        Ok(Node2Vec {
193            state: Node2VecTrained {
194                node_embeddings,
195                vocab,
196            },
197            n_components: self.n_components,
198            walk_length: self.walk_length,
199            num_walks: self.num_walks,
200            p: self.p,
201            q: self.q,
202            window_size: self.window_size,
203            min_count: self.min_count,
204            batch_words: self.batch_words,
205            epochs: self.epochs,
206            learning_rate: self.learning_rate,
207            negative_samples: self.negative_samples,
208            random_state: self.random_state,
209        })
210    }
211}
212
213impl Node2Vec<Untrained> {
214    fn build_adjacency_matrix(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
215        let n_samples = x.nrows();
216        let mut adjacency = Array2::zeros((n_samples, n_samples));
217
218        // Simple k-NN graph construction
219        let k = 10.min(n_samples - 1);
220
221        for i in 0..n_samples {
222            let mut distances: Vec<(usize, f64)> = Vec::new();
223
224            for j in 0..n_samples {
225                if i != j {
226                    let dist = (&x.row(i) - &x.row(j)).mapv(|v| v * v).sum().sqrt();
227                    distances.push((j, dist));
228                }
229            }
230
231            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
232
233            // Connect to k nearest neighbors
234            for &(j, dist) in distances.iter().take(k) {
235                let weight = (-dist).exp(); // Gaussian weight
236                adjacency[(i, j)] = weight;
237                adjacency[(j, i)] = weight; // Make symmetric
238            }
239        }
240
241        Ok(adjacency)
242    }
243
244    fn generate_node2vec_walks(&self, adjacency: &Array2<f64>) -> SklResult<Vec<Vec<usize>>> {
245        let n_nodes = adjacency.nrows();
246        let mut rng = if let Some(seed) = self.random_state {
247            StdRng::seed_from_u64(seed)
248        } else {
249            StdRng::seed_from_u64(thread_rng().random::<u64>())
250        };
251
252        let mut all_walks = Vec::new();
253
254        for start_node in 0..n_nodes {
255            for _ in 0..self.num_walks {
256                let walk = self.node2vec_walk(start_node, adjacency, &mut rng)?;
257                if walk.len() >= 2 {
258                    all_walks.push(walk);
259                }
260            }
261        }
262
263        Ok(all_walks)
264    }
265
266    fn node2vec_walk(
267        &self,
268        start_node: usize,
269        adjacency: &Array2<f64>,
270        rng: &mut StdRng,
271    ) -> SklResult<Vec<usize>> {
272        let mut walk = vec![start_node];
273        let mut prev_node = None;
274        let mut current_node = start_node;
275
276        for _ in 1..self.walk_length {
277            let neighbors = self.get_neighbors(current_node, adjacency);
278
279            if neighbors.is_empty() {
280                break;
281            }
282
283            let next_node = if let Some(prev) = prev_node {
284                self.biased_choice(current_node, prev, &neighbors, adjacency, rng)?
285            } else {
286                // First step: uniform random choice
287                neighbors[rng.gen_range(0..neighbors.len())]
288            };
289
290            walk.push(next_node);
291            prev_node = Some(current_node);
292            current_node = next_node;
293        }
294
295        Ok(walk)
296    }
297
298    fn get_neighbors(&self, node: usize, adjacency: &Array2<f64>) -> Vec<usize> {
299        adjacency
300            .row(node)
301            .iter()
302            .enumerate()
303            .filter_map(|(idx, &weight)| if weight > 0.0 { Some(idx) } else { None })
304            .collect()
305    }
306
307    fn biased_choice(
308        &self,
309        current: usize,
310        prev: usize,
311        neighbors: &[usize],
312        adjacency: &Array2<f64>,
313        rng: &mut StdRng,
314    ) -> SklResult<usize> {
315        let mut weights = Vec::new();
316        let mut total_weight = 0.0;
317
318        for &neighbor in neighbors {
319            let edge_weight = adjacency[(current, neighbor)];
320
321            let bias = if neighbor == prev {
322                // Return to previous node
323                1.0 / self.p
324            } else if adjacency[(prev, neighbor)] > 0.0 {
325                // Neighbor is also connected to previous node (local exploration)
326                1.0
327            } else {
328                // Move away from previous node (global exploration)
329                1.0 / self.q
330            };
331
332            let final_weight = edge_weight * bias;
333            weights.push(final_weight);
334            total_weight += final_weight;
335        }
336
337        if total_weight <= 0.0 {
338            // Fallback to uniform choice
339            return Ok(neighbors[rng.gen_range(0..neighbors.len())]);
340        }
341
342        // Weighted random choice
343        let mut cumulative = 0.0;
344        let threshold = rng.gen::<f64>() * total_weight;
345
346        for (i, &weight) in weights.iter().enumerate() {
347            cumulative += weight;
348            if cumulative >= threshold {
349                return Ok(neighbors[i]);
350            }
351        }
352
353        // Fallback (should not reach here)
354        Ok(neighbors[neighbors.len() - 1])
355    }
356
357    fn train_skipgram_on_walks(
358        &self,
359        walks: &[Vec<usize>],
360    ) -> SklResult<(Array2<f64>, HashMap<usize, usize>)> {
361        // Build vocabulary
362        let mut word_count = HashMap::new();
363        for walk in walks {
364            for &word in walk {
365                *word_count.entry(word).or_insert(0) += 1;
366            }
367        }
368
369        // Filter by min_count
370        let vocab: HashMap<usize, usize> = word_count
371            .iter()
372            .filter(|(_, &count)| count >= self.min_count)
373            .enumerate()
374            .map(|(idx, (&word, _))| (word, idx))
375            .collect();
376
377        let vocab_size = vocab.len();
378        if vocab_size == 0 {
379            return Err(SklearsError::InvalidInput(
380                "No words meet minimum count requirement".to_string(),
381            ));
382        }
383
384        let mut rng = if let Some(seed) = self.random_state {
385            StdRng::seed_from_u64(seed)
386        } else {
387            StdRng::seed_from_u64(thread_rng().random::<u64>())
388        };
389
390        // Initialize embeddings
391        let mut node_embeddings = Array2::zeros((vocab_size, self.n_components));
392        for i in 0..vocab_size {
393            for j in 0..self.n_components {
394                node_embeddings[(i, j)] = rng.sample::<f64, _>(scirs2_core::StandardNormal) * 0.1;
395            }
396        }
397
398        // Simplified skip-gram training
399        for _epoch in 0..self.epochs {
400            for walk in walks {
401                for (center_idx, &center_word) in walk.iter().enumerate() {
402                    if let Some(&center_vocab_idx) = vocab.get(&center_word) {
403                        // Context window
404                        let start = center_idx.saturating_sub(self.window_size);
405                        let end = (center_idx + self.window_size + 1).min(walk.len());
406
407                        for context_idx in start..end {
408                            if context_idx != center_idx {
409                                if let Some(&context_word) = walk.get(context_idx) {
410                                    if let Some(&context_vocab_idx) = vocab.get(&context_word) {
411                                        // Simplified gradient update
412                                        let dot_product: f64 = node_embeddings
413                                            .row(center_vocab_idx)
414                                            .iter()
415                                            .zip(node_embeddings.row(context_vocab_idx).iter())
416                                            .map(|(a, b)| a * b)
417                                            .sum();
418
419                                        let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
420                                        let gradient = self.learning_rate * (1.0 - sigmoid);
421
422                                        for k in 0..self.n_components {
423                                            let center_val = node_embeddings[(center_vocab_idx, k)];
424                                            let context_val =
425                                                node_embeddings[(context_vocab_idx, k)];
426
427                                            node_embeddings[(center_vocab_idx, k)] +=
428                                                gradient * context_val;
429                                            node_embeddings[(context_vocab_idx, k)] +=
430                                                gradient * center_val;
431                                        }
432                                    }
433                                }
434                            }
435                        }
436                    }
437                }
438            }
439        }
440
441        Ok((node_embeddings, vocab))
442    }
443}
444
445impl Transform<ArrayView2<'_, Float>, Array2<Float>> for Node2Vec<Node2VecTrained> {
446    fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
447        let (n_samples, _) = x.dim();
448
449        if n_samples != self.state.vocab.len() {
450            return Err(SklearsError::InvalidInput(
451                "Input size must match training data size for Node2Vec".to_string(),
452            ));
453        }
454
455        // Return the learned embeddings
456        Ok(self.state.node_embeddings.mapv(|v| v as Float))
457    }
458}
459
impl Node2Vec<Node2VecTrained> {
    /// Get the learned node embeddings
    /// (one row per vocabulary entry, `n_components` columns).
    pub fn node_embeddings(&self) -> &Array2<f64> {
        &self.state.node_embeddings
    }

    /// Get the vocabulary mapping from original node index to the
    /// corresponding row in [`Self::node_embeddings`].
    pub fn vocab(&self) -> &HashMap<usize, usize> {
        &self.state.vocab
    }
}