1use super::core::EmbeddingModel;
11use super::negative_sampling::NegativeSampler;
12use super::random_walk::RandomWalkGenerator;
13use super::types::{Node2VecConfig, RandomWalk};
14use crate::base::{DiGraph, EdgeWeight, Graph, Node};
15use crate::error::Result;
16use scirs2_core::random::seq::SliceRandom;
17use scirs2_core::random::RngExt;
18
19pub struct Node2Vec<N: Node> {
24 config: Node2VecConfig,
25 model: EmbeddingModel<N>,
26 walk_generator: RandomWalkGenerator<N>,
27}
28
29impl<N: Node> Node2Vec<N> {
30 pub fn new(config: Node2VecConfig) -> Self {
32 Node2Vec {
33 model: EmbeddingModel::new(config.dimensions),
34 config,
35 walk_generator: RandomWalkGenerator::new(),
36 }
37 }
38
39 pub fn generate_walks<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<Vec<RandomWalk<N>>>
41 where
42 N: Clone + std::fmt::Debug,
43 E: EdgeWeight + Into<f64>,
44 Ix: petgraph::graph::IndexType,
45 {
46 let mut all_walks = Vec::new();
47
48 for node in graph.nodes() {
49 for _ in 0..self.config.num_walks {
50 let walk = self.walk_generator.node2vec_walk(
51 graph,
52 node,
53 self.config.walk_length,
54 self.config.p,
55 self.config.q,
56 )?;
57 all_walks.push(walk);
58 }
59 }
60
61 Ok(all_walks)
62 }
63
64 pub fn generate_walks_digraph<E, Ix>(
66 &mut self,
67 graph: &DiGraph<N, E, Ix>,
68 ) -> Result<Vec<RandomWalk<N>>>
69 where
70 N: Clone + std::fmt::Debug,
71 E: EdgeWeight + Into<f64>,
72 Ix: petgraph::graph::IndexType,
73 {
74 let mut all_walks = Vec::new();
75
76 for node in graph.nodes() {
77 for _ in 0..self.config.num_walks {
78 let walk = self.walk_generator.node2vec_walk_digraph(
79 graph,
80 node,
81 self.config.walk_length,
82 self.config.p,
83 self.config.q,
84 )?;
85 all_walks.push(walk);
86 }
87 }
88
89 Ok(all_walks)
90 }
91
92 pub fn train<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
94 where
95 N: Clone + std::fmt::Debug,
96 E: EdgeWeight + Into<f64>,
97 Ix: petgraph::graph::IndexType,
98 {
99 let mut rng = scirs2_core::random::rng();
101 self.model.initialize_random(graph, &mut rng);
102
103 let negative_sampler = NegativeSampler::new(graph);
105
106 for epoch in 0..self.config.epochs {
108 let walks = self.generate_walks(graph)?;
110
111 let context_pairs =
113 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
114
115 let mut shuffled_pairs = context_pairs;
117 shuffled_pairs.shuffle(&mut rng);
118
119 let current_lr = self.config.learning_rate
122 * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
123
124 self.model.train_skip_gram(
125 &shuffled_pairs,
126 &negative_sampler,
127 current_lr,
128 self.config.negative_samples,
129 &mut rng,
130 )?;
131 }
132
133 Ok(())
134 }
135
136 pub fn train_digraph<E, Ix>(&mut self, graph: &DiGraph<N, E, Ix>) -> Result<()>
138 where
139 N: Clone + std::fmt::Debug,
140 E: EdgeWeight + Into<f64>,
141 Ix: petgraph::graph::IndexType,
142 {
143 let mut rng = scirs2_core::random::rng();
145 self.model.initialize_random_digraph(graph, &mut rng);
146
147 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
150 let node_degrees: Vec<f64> = nodes.iter().map(|n| graph.degree(n) as f64).collect();
151
152 let total_degree: f64 = node_degrees.iter().sum();
154 let frequencies: Vec<f64> = node_degrees
155 .iter()
156 .map(|d| (d / total_degree.max(1.0)).powf(0.75))
157 .collect();
158 let total_freq: f64 = frequencies.iter().sum();
159 let normalized: Vec<f64> = frequencies
160 .iter()
161 .map(|f| f / total_freq.max(1e-10))
162 .collect();
163
164 let mut cumulative = vec![0.0; normalized.len()];
165 if !cumulative.is_empty() {
166 cumulative[0] = normalized[0];
167 for i in 1..normalized.len() {
168 cumulative[i] = cumulative[i - 1] + normalized[i];
169 }
170 }
171
172 for epoch in 0..self.config.epochs {
174 let walks = self.generate_walks_digraph(graph)?;
175 let context_pairs =
176 EmbeddingModel::generate_context_pairs(&walks, self.config.window_size);
177
178 let mut shuffled_pairs = context_pairs;
179 shuffled_pairs.shuffle(&mut rng);
180
181 let current_lr = self.config.learning_rate
182 * (1.0 - epoch as f64 / self.config.epochs as f64).max(0.0001);
183
184 for pair in &shuffled_pairs {
187 self.train_pair_digraph(
188 pair,
189 &nodes,
190 &cumulative,
191 current_lr,
192 self.config.negative_samples,
193 &mut rng,
194 );
195 }
196 }
197
198 Ok(())
199 }
200
201 fn train_pair_digraph(
203 &mut self,
204 pair: &super::types::ContextPair<N>,
205 nodes: &[N],
206 cumulative: &[f64],
207 learning_rate: f64,
208 num_negative: usize,
209 rng: &mut impl scirs2_core::random::Rng,
210 ) where
211 N: Clone,
212 {
213 let dim = self.config.dimensions;
214
215 let target_emb = match self.model.embeddings.get(&pair.target) {
217 Some(e) => e.clone(),
218 None => return,
219 };
220
221 let context_emb = match self.model.context_embeddings.get(&pair.context) {
223 Some(e) => e.clone(),
224 None => return,
225 };
226
227 let dot: f64 = target_emb
229 .vector
230 .iter()
231 .zip(context_emb.vector.iter())
232 .map(|(a, b)| a * b)
233 .sum();
234 let sig = 1.0 / (1.0 + (-dot).exp());
235 let g = learning_rate * (1.0 - sig);
236
237 let mut target_grad = vec![0.0; dim];
238 for d in 0..dim {
239 target_grad[d] += g * context_emb.vector[d];
240 }
241
242 if let Some(ctx) = self.model.context_embeddings.get_mut(&pair.context) {
244 for d in 0..dim {
245 ctx.vector[d] += g * target_emb.vector[d];
246 }
247 }
248
249 for _ in 0..num_negative {
251 let r = rng.random::<f64>();
252 let neg_idx = cumulative
253 .iter()
254 .position(|&c| r <= c)
255 .unwrap_or(cumulative.len().saturating_sub(1));
256
257 if neg_idx >= nodes.len() {
258 continue;
259 }
260
261 let neg_node = &nodes[neg_idx];
262 if neg_node == &pair.target || neg_node == &pair.context {
263 continue;
264 }
265
266 if let Some(neg_emb) = self.model.context_embeddings.get(neg_node) {
267 let neg_dot: f64 = target_emb
268 .vector
269 .iter()
270 .zip(neg_emb.vector.iter())
271 .map(|(a, b)| a * b)
272 .sum();
273 let neg_sig = 1.0 / (1.0 + (-neg_dot).exp());
274 let neg_g = learning_rate * (-neg_sig);
275
276 for d in 0..dim {
277 target_grad[d] += neg_g * neg_emb.vector[d];
278 }
279
280 if let Some(neg_ctx) = self.model.context_embeddings.get_mut(neg_node) {
282 for d in 0..dim {
283 neg_ctx.vector[d] += neg_g * target_emb.vector[d];
284 }
285 }
286 }
287 }
288
289 if let Some(target) = self.model.embeddings.get_mut(&pair.target) {
291 for d in 0..dim {
292 target.vector[d] += target_grad[d];
293 }
294 }
295 }
296
297 pub fn model(&self) -> &EmbeddingModel<N> {
299 &self.model
300 }
301
302 pub fn model_mut(&mut self) -> &mut EmbeddingModel<N> {
304 &mut self.model
305 }
306
307 pub fn config(&self) -> &Node2VecConfig {
309 &self.config
310 }
311}
312
313#[cfg(test)]
314mod tests {
315 use super::*;
316
317 fn make_triangle() -> Graph<i32, f64> {
318 let mut g = Graph::new();
319 for i in 0..3 {
320 g.add_node(i);
321 }
322 let _ = g.add_edge(0, 1, 1.0);
323 let _ = g.add_edge(1, 2, 1.0);
324 let _ = g.add_edge(0, 2, 1.0);
325 g
326 }
327
328 fn make_star_graph() -> Graph<i32, f64> {
329 let mut g = Graph::new();
330 for i in 0..5 {
331 g.add_node(i);
332 }
333 for i in 1..5 {
335 let _ = g.add_edge(0, i, 1.0);
336 }
337 g
338 }
339
340 fn make_directed_chain() -> DiGraph<i32, f64> {
341 let mut g = DiGraph::new();
342 for i in 0..5 {
343 g.add_node(i);
344 }
345 let _ = g.add_edge(0, 1, 1.0);
346 let _ = g.add_edge(1, 2, 1.0);
347 let _ = g.add_edge(2, 3, 1.0);
348 let _ = g.add_edge(3, 4, 1.0);
349 g
350 }
351
352 #[test]
353 fn test_node2vec_train_basic() {
354 let g = make_triangle();
355 let config = Node2VecConfig {
356 dimensions: 8,
357 walk_length: 5,
358 num_walks: 3,
359 window_size: 2,
360 p: 1.0,
361 q: 1.0,
362 epochs: 2,
363 learning_rate: 0.025,
364 negative_samples: 2,
365 };
366
367 let mut n2v = Node2Vec::new(config);
368 let result = n2v.train(&g);
369 assert!(result.is_ok(), "Node2Vec training should succeed");
370
371 for node in [0, 1, 2] {
373 assert!(
374 n2v.model().get_embedding(&node).is_some(),
375 "Node {node} should have an embedding"
376 );
377 }
378 }
379
380 #[test]
381 fn test_node2vec_walk_generation() {
382 let g = make_triangle();
383 let config = Node2VecConfig {
384 dimensions: 8,
385 walk_length: 10,
386 num_walks: 2,
387 p: 1.0,
388 q: 1.0,
389 ..Default::default()
390 };
391
392 let mut n2v = Node2Vec::new(config);
393 let walks = n2v.generate_walks(&g);
394 assert!(walks.is_ok());
395
396 let walks = walks.expect("walks should be valid");
397 assert_eq!(walks.len(), 6);
399
400 for walk in &walks {
402 assert!(walk.nodes.len() <= 10);
403 assert!(!walk.nodes.is_empty());
404 }
405 }
406
407 #[test]
408 fn test_node2vec_biased_walks() {
409 let g = make_star_graph();
412 let config = Node2VecConfig {
413 dimensions: 8,
414 walk_length: 20,
415 num_walks: 5,
416 p: 0.5,
417 q: 2.0,
418 ..Default::default()
419 };
420
421 let mut n2v = Node2Vec::new(config);
422 let walks = n2v.generate_walks(&g);
423 assert!(walks.is_ok());
424
425 let walks = walks.expect("walks should be valid");
426 assert!(!walks.is_empty());
427
428 for walk in &walks {
430 for node in &walk.nodes {
431 assert!(
432 (0..5).contains(node),
433 "Walk should only contain valid nodes, got {node}"
434 );
435 }
436 }
437 }
438
439 #[test]
440 fn test_node2vec_embedding_similarity() {
441 let g = make_triangle();
442 let config = Node2VecConfig {
443 dimensions: 16,
444 walk_length: 10,
445 num_walks: 10,
446 window_size: 3,
447 p: 1.0,
448 q: 1.0,
449 epochs: 5,
450 learning_rate: 0.05,
451 negative_samples: 3,
452 };
453
454 let mut n2v = Node2Vec::new(config);
455 let _ = n2v.train(&g);
456
457 let model = n2v.model();
460 let sim_01 = model.most_similar(&0, 2);
461 assert!(sim_01.is_ok());
462
463 let sim_01 = sim_01.expect("similarity should be valid");
464 assert_eq!(sim_01.len(), 2, "Should find 2 most similar nodes");
465
466 for (node, score) in &sim_01 {
467 assert!(
468 score.is_finite(),
469 "Similarity for node {node} should be finite"
470 );
471 }
472 }
473
474 #[test]
475 fn test_node2vec_digraph_train() {
476 let g = make_directed_chain();
477 let config = Node2VecConfig {
478 dimensions: 8,
479 walk_length: 4,
480 num_walks: 3,
481 window_size: 2,
482 p: 1.0,
483 q: 1.0,
484 epochs: 2,
485 learning_rate: 0.025,
486 negative_samples: 2,
487 };
488
489 let mut n2v = Node2Vec::new(config);
490 let result = n2v.train_digraph(&g);
491 assert!(result.is_ok(), "DiGraph Node2Vec training should succeed");
492
493 for node in 0..5 {
495 assert!(
496 n2v.model().get_embedding(&node).is_some(),
497 "Node {node} should have an embedding in directed graph"
498 );
499 }
500 }
501
502 #[test]
503 fn test_node2vec_config() {
504 let config = Node2VecConfig::default();
505 assert_eq!(config.dimensions, 128);
506 assert_eq!(config.walk_length, 80);
507 assert_eq!(config.p, 1.0);
508 assert_eq!(config.q, 1.0);
509
510 let n2v: Node2Vec<i32> = Node2Vec::new(config);
511 assert_eq!(n2v.config().dimensions, 128);
512 }
513}