1use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use std::collections::HashMap;
10use std::hash::Hash;
11
12#[allow(dead_code)]
16pub fn random_walk<N, E, Ix>(
17 graph: &Graph<N, E, Ix>,
18 start: &N,
19 steps: usize,
20 restart_probability: f64,
21) -> Result<Vec<N>>
22where
23 N: Node + Clone + Hash + Eq + std::fmt::Debug,
24 E: EdgeWeight,
25 Ix: IndexType,
26{
27 if !graph.contains_node(start) {
28 return Err(GraphError::node_not_found("node"));
29 }
30
31 let mut walk = vec![start.clone()];
32 let mut current = start.clone();
33 let mut rng = scirs2_core::random::rng();
34
35 use scirs2_core::random::Rng;
36
37 for _ in 0..steps {
38 if rng.random::<f64>() < restart_probability {
40 current = start.clone();
41 walk.push(current.clone());
42 continue;
43 }
44
45 if let Ok(neighbors) = graph.neighbors(¤t) {
47 let neighbor_vec: Vec<N> = neighbors;
48
49 if !neighbor_vec.is_empty() {
50 let idx = rng.gen_range(0..neighbor_vec.len());
51 current = neighbor_vec[idx].clone();
52 walk.push(current.clone());
53 } else {
54 current = start.clone();
56 walk.push(current.clone());
57 }
58 }
59 }
60
61 Ok(walk)
62}
63
64#[allow(dead_code)]
69pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
70where
71 N: Node + Clone + std::fmt::Debug,
72 E: EdgeWeight + Into<f64>,
73 Ix: IndexType,
74{
75 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
76 let n = nodes.len();
77
78 if n == 0 {
79 return Err(GraphError::InvalidGraph("Empty graph".to_string()));
80 }
81
82 let mut matrix = Array2::<f64>::zeros((n, n));
83
84 for (i, node) in nodes.iter().enumerate() {
85 if let Ok(neighbors) = graph.neighbors(node) {
86 let neighbor_weights: Vec<(usize, f64)> = neighbors
87 .into_iter()
88 .filter_map(|neighbor| {
89 nodes.iter().position(|n| n == &neighbor).and_then(|j| {
90 graph
91 .edge_weight(node, &neighbor)
92 .ok()
93 .map(|w| (j, w.into()))
94 })
95 })
96 .collect();
97
98 let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
99
100 if total_weight > 0.0 {
101 for (j, weight) in neighbor_weights {
102 matrix[[i, j]] = weight / total_weight;
103 }
104 } else {
105 for j in 0..n {
107 matrix[[i, j]] = 1.0 / n as f64;
108 }
109 }
110 }
111 }
112
113 Ok((nodes, matrix))
114}
115
116#[allow(dead_code)]
120pub fn personalized_pagerank<N, E, Ix>(
121 graph: &Graph<N, E, Ix>,
122 source: &N,
123 damping: f64,
124 tolerance: f64,
125 max_iter: usize,
126) -> Result<HashMap<N, f64>>
127where
128 N: Node + Clone + Hash + Eq + std::fmt::Debug,
129 E: EdgeWeight + Into<f64>,
130 Ix: IndexType,
131{
132 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
133 let n = nodes.len();
134
135 if n == 0 || !graph.contains_node(source) {
136 return Err(GraphError::node_not_found("node"));
137 }
138
139 let source_idx = nodes
141 .iter()
142 .position(|n| n == source)
143 .expect("Operation failed");
144
145 let (_, trans_matrix) = transition_matrix(graph)?;
147
148 let mut pr = Array1::<f64>::zeros(n);
150 pr[source_idx] = 1.0;
151
152 let mut personalization = Array1::<f64>::zeros(n);
154 personalization[source_idx] = 1.0;
155
156 for _ in 0..max_iter {
158 let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
159
160 let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
162 if diff < tolerance {
163 break;
164 }
165
166 pr = new_pr;
167 }
168
169 Ok(nodes
171 .into_iter()
172 .enumerate()
173 .map(|(i, node)| (node, pr[i]))
174 .collect())
175}
176
177#[allow(dead_code)]
180pub fn parallel_random_walks<N, E, Ix>(
181 graph: &Graph<N, E, Ix>,
182 starts: &[N],
183 walk_length: usize,
184 restart_probability: f64,
185) -> Result<Vec<Vec<N>>>
186where
187 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
188 E: EdgeWeight + Send + Sync,
189 Ix: IndexType + Send + Sync,
190{
191 starts
192 .par_iter()
193 .map(|start| random_walk(graph, start, walk_length, restart_probability))
194 .collect::<Result<Vec<_>>>()
195}
196
197pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
200 node_to_idx: HashMap<N, usize>,
202 idx_to_node: Vec<N>,
204 #[allow(dead_code)]
206 transition_probs: Vec<Vec<f64>>,
207 alias_tables: Vec<AliasTable>,
209}
210
/// Lookup table for constant-time sampling from a discrete distribution (alias method).
///
/// `sample` picks a column `i` uniformly. It then returns either `i` or `alias[i]`,
/// depending on how a uniform draw in `[0, 1)` compares with `prob[i]`.
#[derive(Debug, Clone)]
struct AliasTable {
    /// Per-column acceptance threshold compared against the uniform draw in `sample`.
    prob: Vec<f64>,
    /// Per-column fallback index returned when the draw is not accepted.
    alias: Vec<usize>,
}
219
220impl AliasTable {
221 fn new(weights: &[f64]) -> Self {
223 let n = weights.len();
224 let mut prob = vec![0.0; n];
225 let mut alias = vec![0; n];
226
227 if n == 0 {
228 return AliasTable { prob, alias };
229 }
230
231 let sum: f64 = weights.iter().sum();
232 if sum == 0.0 {
233 return AliasTable { prob, alias };
234 }
235
236 let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
238
239 let mut small = Vec::new();
240 let mut large = Vec::new();
241
242 for (i, &p) in normalized.iter().enumerate() {
243 if p < 1.0 {
244 small.push(i);
245 } else {
246 large.push(i);
247 }
248 }
249
250 prob[..n].copy_from_slice(&normalized[..n]);
251
252 while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
253 alias[small_idx] = large_idx;
254 prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
255
256 if prob[large_idx] < 1.0 {
257 small.push(large_idx);
258 } else {
259 large.push(large_idx);
260 }
261 }
262
263 AliasTable { prob, alias }
264 }
265
266 fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> usize {
268 if self.prob.is_empty() {
269 return 0;
270 }
271
272 let i = rng.gen_range(0..self.prob.len());
273 let coin_flip = rng.random::<f64>();
274
275 if coin_flip <= self.prob[i] {
276 i
277 } else {
278 self.alias[i]
279 }
280 }
281}
282
283impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
284 pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
286 where
287 E: EdgeWeight + Into<f64>,
288 Ix: IndexType,
289 N: std::fmt::Debug,
290 {
291 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
292 let node_to_idx: HashMap<N, usize> = nodes
293 .iter()
294 .enumerate()
295 .map(|(i, n)| (n.clone(), i))
296 .collect();
297
298 let mut transition_probs = Vec::new();
299 let mut alias_tables = Vec::new();
300
301 for node in &nodes {
302 if let Ok(neighbors) = graph.neighbors(node) {
303 let neighbor_weights: Vec<f64> = neighbors
304 .iter()
305 .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
306 .map(|w| w.into())
307 .collect();
308
309 if !neighbor_weights.is_empty() {
310 let total: f64 = neighbor_weights.iter().sum();
311 let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
312
313 let mut cumulative = vec![0.0; probs.len()];
315 cumulative[0] = probs[0];
316 for i in 1..probs.len() {
317 cumulative[i] = cumulative[i - 1] + probs[i];
318 }
319
320 transition_probs.push(cumulative);
321 alias_tables.push(AliasTable::new(&neighbor_weights));
322 } else {
323 transition_probs.push(vec![]);
325 alias_tables.push(AliasTable::new(&[]));
326 }
327 } else {
328 transition_probs.push(vec![]);
329 alias_tables.push(AliasTable::new(&[]));
330 }
331 }
332
333 Ok(BatchRandomWalker {
334 node_to_idx,
335 idx_to_node: nodes,
336 transition_probs,
337 alias_tables,
338 })
339 }
340
341 pub fn generate_walks<E, Ix>(
343 &self,
344 graph: &Graph<N, E, Ix>,
345 starts: &[N],
346 walk_length: usize,
347 num_walks_per_node: usize,
348 ) -> Result<Vec<Vec<N>>>
349 where
350 E: EdgeWeight,
351 Ix: IndexType + std::marker::Sync,
352 N: Send + Sync + std::fmt::Debug,
353 {
354 let total_walks = starts.len() * num_walks_per_node;
355 let mut all_walks = Vec::with_capacity(total_walks);
356
357 starts
359 .par_iter()
360 .map(|start| {
361 let mut local_walks = Vec::with_capacity(num_walks_per_node);
362 let mut rng = scirs2_core::random::rng();
363
364 for _ in 0..num_walks_per_node {
365 if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
366 local_walks.push(walk);
367 }
368 }
369 local_walks
370 })
371 .collect::<Vec<_>>()
372 .into_iter()
373 .for_each(|walks| all_walks.extend(walks));
374
375 Ok(all_walks)
376 }
377
378 fn single_walk<E, Ix>(
380 &self,
381 graph: &Graph<N, E, Ix>,
382 start: &N,
383 walk_length: usize,
384 rng: &mut impl scirs2_core::random::Rng,
385 ) -> Result<Vec<N>>
386 where
387 E: EdgeWeight,
388 Ix: IndexType,
389 {
390 let mut walk = Vec::with_capacity(walk_length + 1);
391 walk.push(start.clone());
392
393 let mut current_idx = *self
394 .node_to_idx
395 .get(start)
396 .ok_or(GraphError::node_not_found("node"))?;
397
398 for _ in 0..walk_length {
399 if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
400 let neighbors: Vec<_> = neighbors;
401
402 if !neighbors.is_empty() {
403 let neighbor_idx = self.alias_tables[current_idx].sample(rng);
405 if neighbor_idx < neighbors.len() {
406 let next_node = neighbors[neighbor_idx].clone();
407 walk.push(next_node.clone());
408
409 if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
410 current_idx = next_idx;
411 }
412 } else {
413 break;
414 }
415 } else {
416 break;
417 }
418 } else {
419 break;
420 }
421 }
422
423 Ok(walk)
424 }
425}
426
427#[allow(dead_code)]
430pub fn node2vec_walk<N, E, Ix>(
431 graph: &Graph<N, E, Ix>,
432 start: &N,
433 walk_length: usize,
434 p: f64, q: f64, rng: &mut impl scirs2_core::random::Rng,
437) -> Result<Vec<N>>
438where
439 N: Node + Clone + Hash + Eq + std::fmt::Debug,
440 E: EdgeWeight + Into<f64>,
441 Ix: IndexType,
442{
443 let mut walk = vec![start.clone()];
444 if walk_length == 0 {
445 return Ok(walk);
446 }
447
448 if let Ok(neighbors) = graph.neighbors(start) {
450 let neighbors: Vec<_> = neighbors;
451 if neighbors.is_empty() {
452 return Ok(walk);
453 }
454
455 let idx = rng.gen_range(0..neighbors.len());
456 walk.push(neighbors[idx].clone());
457 } else {
458 return Ok(walk);
459 }
460
461 for step in 1..walk_length {
463 let current = &walk[step];
464 let previous = &walk[step - 1];
465
466 if let Ok(neighbors) = graph.neighbors(current) {
467 let neighbors: Vec<_> = neighbors;
468 if neighbors.is_empty() {
469 break;
470 }
471
472 let mut weights = Vec::with_capacity(neighbors.len());
474
475 for neighbor in &neighbors {
476 let weight = if neighbor == previous {
477 1.0 / p
479 } else if graph.has_edge(previous, neighbor) {
480 1.0
482 } else {
483 1.0 / q
485 };
486
487 let edge_weight = graph
489 .edge_weight(current, neighbor)
490 .map(|w| w.into())
491 .unwrap_or(1.0);
492
493 weights.push(weight * edge_weight);
494 }
495
496 let total: f64 = weights.iter().sum();
498 if total > 0.0 {
499 let mut cumulative = vec![0.0; weights.len()];
500 cumulative[0] = weights[0] / total;
501
502 for i in 1..weights.len() {
504 cumulative[i] = cumulative[i - 1] + weights[i] / total;
505 }
506
507 let r = rng.random::<f64>();
508 for (i, &cum_prob) in cumulative.iter().enumerate() {
509 if r <= cum_prob {
510 walk.push(neighbors[i].clone());
511 break;
512 }
513 }
514 }
515 } else {
516 break;
517 }
518 }
519
520 Ok(walk)
521}
522
523#[allow(dead_code)]
525pub fn parallel_node2vec_walks<N, E, Ix>(
526 graph: &Graph<N, E, Ix>,
527 starts: &[N],
528 walk_length: usize,
529 num_walks: usize,
530 p: f64,
531 q: f64,
532) -> Result<Vec<Vec<N>>>
533where
534 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
535 E: EdgeWeight + Into<f64> + Send + Sync,
536 Ix: IndexType + Send + Sync,
537{
538 let total_walks = starts.len() * num_walks;
539
540 (0..total_walks)
541 .into_par_iter()
542 .map(|i| {
543 let start_idx = i % starts.len();
544 let start = &starts[start_idx];
545 let mut rng = scirs2_core::random::rng();
546 node2vec_walk(graph, start, walk_length, p, q, &mut rng)
547 })
548 .collect()
549}
550
551#[allow(dead_code)]
554pub fn simd_random_walk_with_restart<N, E, Ix>(
555 graph: &Graph<N, E, Ix>,
556 start: &N,
557 walk_length: usize,
558 restart_prob: f64,
559 rng: &mut impl scirs2_core::random::Rng,
560) -> Result<Vec<N>>
561where
562 N: Node + Clone + Hash + Eq + std::fmt::Debug,
563 E: EdgeWeight,
564 Ix: IndexType,
565{
566 let mut walk = Vec::with_capacity(walk_length + 1);
567 walk.push(start.clone());
568
569 let mut current = start.clone();
570
571 for _ in 0..walk_length {
572 if rng.random::<f64>() < restart_prob {
574 current = start.clone();
575 walk.push(current.clone());
576 continue;
577 }
578
579 if let Ok(neighbors) = graph.neighbors(¤t) {
580 let neighbors: Vec<_> = neighbors;
581 if !neighbors.is_empty() {
582 let idx = rng.gen_range(0..neighbors.len());
583 current = neighbors[idx].clone();
584 walk.push(current.clone());
585 } else {
586 current = start.clone();
587 walk.push(current.clone());
588 }
589 } else {
590 break;
591 }
592 }
593
594 Ok(walk)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_random_walk() -> GraphResult<()> {
        // Path graph A - B - C - D.
        let mut graph = create_graph::<&str, ()>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(u, v, ())?;
        }

        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        // The walk starts at the source and records one node per step.
        assert_eq!(walk.first(), Some(&"A"));
        assert_eq!(walk.len(), 11);
        assert!(walk.iter().all(|node| graph.contains_node(node)));

        Ok(())
    }

    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        // Triangle A - B - C - A with unit weights.
        let mut graph = create_graph::<&str, f64>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(u, v, 1.0)?;
        }

        let (nodes, matrix) = transition_matrix(&graph)?;

        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);
        // Every row must be a probability distribution.
        for row in matrix.outer_iter() {
            assert!((row.sum() - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        // Star graph centred on A.
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, 1.0)?;
        }

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;

        assert_eq!(pagerank.len(), 4);

        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        // The source (also the hub) must rank at least as high as every other node.
        let a_rank = pagerank[&"A"];
        let others_bounded = pagerank
            .iter()
            .filter(|&(node, _)| *node != "A")
            .all(|(_, &rank)| rank <= a_rank);
        assert!(others_bounded);

        Ok(())
    }
}