1use scirs2_core::ndarray::{Array2, ArrayView2};
6use scirs2_core::random::rngs::StdRng;
7use scirs2_core::random::thread_rng;
8use scirs2_core::random::Rng;
9use scirs2_core::random::SeedableRng;
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 traits::{Estimator, Fit, Transform, Untrained},
13 types::Float,
14};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone)]
23pub struct Node2Vec<S = Untrained> {
24 state: S,
25 n_components: usize,
26 walk_length: usize,
27 num_walks: usize,
28 p: f64, q: f64, window_size: usize,
31 min_count: usize,
32 batch_words: usize,
33 epochs: usize,
34 learning_rate: f64,
35 negative_samples: usize,
36 random_state: Option<u64>,
37}
38
39impl Node2Vec<Untrained> {
40 pub fn new() -> Self {
42 Self {
43 state: Untrained,
44 n_components: 128,
45 walk_length: 80,
46 num_walks: 10,
47 p: 1.0,
48 q: 1.0,
49 window_size: 10,
50 min_count: 1,
51 batch_words: 4,
52 epochs: 1,
53 learning_rate: 0.025,
54 negative_samples: 5,
55 random_state: None,
56 }
57 }
58
59 pub fn n_components(mut self, n_components: usize) -> Self {
61 self.n_components = n_components;
62 self
63 }
64
65 pub fn walk_length(mut self, walk_length: usize) -> Self {
67 self.walk_length = walk_length;
68 self
69 }
70
71 pub fn num_walks(mut self, num_walks: usize) -> Self {
73 self.num_walks = num_walks;
74 self
75 }
76
77 pub fn p(mut self, p: f64) -> Self {
79 self.p = p;
80 self
81 }
82
83 pub fn q(mut self, q: f64) -> Self {
85 self.q = q;
86 self
87 }
88
89 pub fn window_size(mut self, window_size: usize) -> Self {
91 self.window_size = window_size;
92 self
93 }
94
95 pub fn min_count(mut self, min_count: usize) -> Self {
97 self.min_count = min_count;
98 self
99 }
100
101 pub fn batch_words(mut self, batch_words: usize) -> Self {
103 self.batch_words = batch_words;
104 self
105 }
106
107 pub fn epochs(mut self, epochs: usize) -> Self {
109 self.epochs = epochs;
110 self
111 }
112
113 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
115 self.learning_rate = learning_rate;
116 self
117 }
118
119 pub fn negative_samples(mut self, negative_samples: usize) -> Self {
121 self.negative_samples = negative_samples;
122 self
123 }
124
125 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
127 self.random_state = random_state;
128 self
129 }
130}
131
impl Default for Node2Vec<Untrained> {
    /// Equivalent to [`Node2Vec::new`].
    fn default() -> Self {
        Self::new()
    }
}
137
/// State held by a fitted [`Node2Vec`] estimator.
#[derive(Debug, Clone)]
pub struct Node2VecTrained {
    // Learned embeddings: one row per vocabulary entry (vocab_size x n_components).
    node_embeddings: Array2<f64>,
    // Maps an original node index to its row in `node_embeddings`.
    vocab: HashMap<usize, usize>,
}
146
impl Estimator for Node2Vec<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// All configuration lives in the builder fields, so the config type is
    /// the unit type. (`&()` is a promoted `'static` reference, so returning
    /// it is sound.)
    fn config(&self) -> &Self::Config {
        &()
    }
}
156
impl Estimator for Node2Vec<Node2VecTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// All configuration lives in the builder fields, so the config type is
    /// the unit type. (`&()` is a promoted `'static` reference, so returning
    /// it is sound.)
    fn config(&self) -> &Self::Config {
        &()
    }
}
166
167impl Fit<ArrayView2<'_, Float>, ()> for Node2Vec<Untrained> {
168 type Fitted = Node2Vec<Node2VecTrained>;
169
170 fn fit(self, x: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
171 let (n_samples, _) = x.dim();
172
173 if n_samples < 2 {
174 return Err(SklearsError::InvalidParameter {
175 name: "n_samples".to_string(),
176 reason: "Node2Vec requires at least 2 samples".to_string(),
177 });
178 }
179
180 let x_f64 = x.mapv(|v| v);
182
183 let adjacency = self.build_adjacency_matrix(&x_f64)?;
185
186 let walks = self.generate_node2vec_walks(&adjacency)?;
188
189 let (node_embeddings, vocab) = self.train_skipgram_on_walks(&walks)?;
191
192 Ok(Node2Vec {
193 state: Node2VecTrained {
194 node_embeddings,
195 vocab,
196 },
197 n_components: self.n_components,
198 walk_length: self.walk_length,
199 num_walks: self.num_walks,
200 p: self.p,
201 q: self.q,
202 window_size: self.window_size,
203 min_count: self.min_count,
204 batch_words: self.batch_words,
205 epochs: self.epochs,
206 learning_rate: self.learning_rate,
207 negative_samples: self.negative_samples,
208 random_state: self.random_state,
209 })
210 }
211}
212
213impl Node2Vec<Untrained> {
214 fn build_adjacency_matrix(&self, x: &Array2<f64>) -> SklResult<Array2<f64>> {
215 let n_samples = x.nrows();
216 let mut adjacency = Array2::zeros((n_samples, n_samples));
217
218 let k = 10.min(n_samples - 1);
220
221 for i in 0..n_samples {
222 let mut distances: Vec<(usize, f64)> = Vec::new();
223
224 for j in 0..n_samples {
225 if i != j {
226 let dist = (&x.row(i) - &x.row(j)).mapv(|v| v * v).sum().sqrt();
227 distances.push((j, dist));
228 }
229 }
230
231 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
232
233 for &(j, dist) in distances.iter().take(k) {
235 let weight = (-dist).exp(); adjacency[(i, j)] = weight;
237 adjacency[(j, i)] = weight; }
239 }
240
241 Ok(adjacency)
242 }
243
244 fn generate_node2vec_walks(&self, adjacency: &Array2<f64>) -> SklResult<Vec<Vec<usize>>> {
245 let n_nodes = adjacency.nrows();
246 let mut rng = if let Some(seed) = self.random_state {
247 StdRng::seed_from_u64(seed)
248 } else {
249 StdRng::seed_from_u64(thread_rng().random::<u64>())
250 };
251
252 let mut all_walks = Vec::new();
253
254 for start_node in 0..n_nodes {
255 for _ in 0..self.num_walks {
256 let walk = self.node2vec_walk(start_node, adjacency, &mut rng)?;
257 if walk.len() >= 2 {
258 all_walks.push(walk);
259 }
260 }
261 }
262
263 Ok(all_walks)
264 }
265
266 fn node2vec_walk(
267 &self,
268 start_node: usize,
269 adjacency: &Array2<f64>,
270 rng: &mut StdRng,
271 ) -> SklResult<Vec<usize>> {
272 let mut walk = vec![start_node];
273 let mut prev_node = None;
274 let mut current_node = start_node;
275
276 for _ in 1..self.walk_length {
277 let neighbors = self.get_neighbors(current_node, adjacency);
278
279 if neighbors.is_empty() {
280 break;
281 }
282
283 let next_node = if let Some(prev) = prev_node {
284 self.biased_choice(current_node, prev, &neighbors, adjacency, rng)?
285 } else {
286 neighbors[rng.gen_range(0..neighbors.len())]
288 };
289
290 walk.push(next_node);
291 prev_node = Some(current_node);
292 current_node = next_node;
293 }
294
295 Ok(walk)
296 }
297
298 fn get_neighbors(&self, node: usize, adjacency: &Array2<f64>) -> Vec<usize> {
299 adjacency
300 .row(node)
301 .iter()
302 .enumerate()
303 .filter_map(|(idx, &weight)| if weight > 0.0 { Some(idx) } else { None })
304 .collect()
305 }
306
307 fn biased_choice(
308 &self,
309 current: usize,
310 prev: usize,
311 neighbors: &[usize],
312 adjacency: &Array2<f64>,
313 rng: &mut StdRng,
314 ) -> SklResult<usize> {
315 let mut weights = Vec::new();
316 let mut total_weight = 0.0;
317
318 for &neighbor in neighbors {
319 let edge_weight = adjacency[(current, neighbor)];
320
321 let bias = if neighbor == prev {
322 1.0 / self.p
324 } else if adjacency[(prev, neighbor)] > 0.0 {
325 1.0
327 } else {
328 1.0 / self.q
330 };
331
332 let final_weight = edge_weight * bias;
333 weights.push(final_weight);
334 total_weight += final_weight;
335 }
336
337 if total_weight <= 0.0 {
338 return Ok(neighbors[rng.gen_range(0..neighbors.len())]);
340 }
341
342 let mut cumulative = 0.0;
344 let threshold = rng.gen::<f64>() * total_weight;
345
346 for (i, &weight) in weights.iter().enumerate() {
347 cumulative += weight;
348 if cumulative >= threshold {
349 return Ok(neighbors[i]);
350 }
351 }
352
353 Ok(neighbors[neighbors.len() - 1])
355 }
356
357 fn train_skipgram_on_walks(
358 &self,
359 walks: &[Vec<usize>],
360 ) -> SklResult<(Array2<f64>, HashMap<usize, usize>)> {
361 let mut word_count = HashMap::new();
363 for walk in walks {
364 for &word in walk {
365 *word_count.entry(word).or_insert(0) += 1;
366 }
367 }
368
369 let vocab: HashMap<usize, usize> = word_count
371 .iter()
372 .filter(|(_, &count)| count >= self.min_count)
373 .enumerate()
374 .map(|(idx, (&word, _))| (word, idx))
375 .collect();
376
377 let vocab_size = vocab.len();
378 if vocab_size == 0 {
379 return Err(SklearsError::InvalidInput(
380 "No words meet minimum count requirement".to_string(),
381 ));
382 }
383
384 let mut rng = if let Some(seed) = self.random_state {
385 StdRng::seed_from_u64(seed)
386 } else {
387 StdRng::seed_from_u64(thread_rng().random::<u64>())
388 };
389
390 let mut node_embeddings = Array2::zeros((vocab_size, self.n_components));
392 for i in 0..vocab_size {
393 for j in 0..self.n_components {
394 node_embeddings[(i, j)] = rng.sample::<f64, _>(scirs2_core::StandardNormal) * 0.1;
395 }
396 }
397
398 for _epoch in 0..self.epochs {
400 for walk in walks {
401 for (center_idx, ¢er_word) in walk.iter().enumerate() {
402 if let Some(¢er_vocab_idx) = vocab.get(¢er_word) {
403 let start = center_idx.saturating_sub(self.window_size);
405 let end = (center_idx + self.window_size + 1).min(walk.len());
406
407 for context_idx in start..end {
408 if context_idx != center_idx {
409 if let Some(&context_word) = walk.get(context_idx) {
410 if let Some(&context_vocab_idx) = vocab.get(&context_word) {
411 let dot_product: f64 = node_embeddings
413 .row(center_vocab_idx)
414 .iter()
415 .zip(node_embeddings.row(context_vocab_idx).iter())
416 .map(|(a, b)| a * b)
417 .sum();
418
419 let sigmoid = 1.0 / (1.0 + (-dot_product).exp());
420 let gradient = self.learning_rate * (1.0 - sigmoid);
421
422 for k in 0..self.n_components {
423 let center_val = node_embeddings[(center_vocab_idx, k)];
424 let context_val =
425 node_embeddings[(context_vocab_idx, k)];
426
427 node_embeddings[(center_vocab_idx, k)] +=
428 gradient * context_val;
429 node_embeddings[(context_vocab_idx, k)] +=
430 gradient * center_val;
431 }
432 }
433 }
434 }
435 }
436 }
437 }
438 }
439 }
440
441 Ok((node_embeddings, vocab))
442 }
443}
444
445impl Transform<ArrayView2<'_, Float>, Array2<Float>> for Node2Vec<Node2VecTrained> {
446 fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
447 let (n_samples, _) = x.dim();
448
449 if n_samples != self.state.vocab.len() {
450 return Err(SklearsError::InvalidInput(
451 "Input size must match training data size for Node2Vec".to_string(),
452 ));
453 }
454
455 Ok(self.state.node_embeddings.mapv(|v| v as Float))
457 }
458}
459
impl Node2Vec<Node2VecTrained> {
    /// Returns the learned embedding matrix (vocab_size x n_components).
    pub fn node_embeddings(&self) -> &Array2<f64> {
        &self.state.node_embeddings
    }

    /// Returns the mapping from original node index to embedding row.
    pub fn vocab(&self) -> &HashMap<usize, usize> {
        &self.state.vocab
    }
}