1use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use scirs2_core::random::RngExt;
10use std::collections::HashMap;
11use std::hash::Hash;
12
13#[allow(dead_code)]
17pub fn random_walk<N, E, Ix>(
18 graph: &Graph<N, E, Ix>,
19 start: &N,
20 steps: usize,
21 restart_probability: f64,
22) -> Result<Vec<N>>
23where
24 N: Node + Clone + Hash + Eq + std::fmt::Debug,
25 E: EdgeWeight,
26 Ix: IndexType,
27{
28 if !graph.contains_node(start) {
29 return Err(GraphError::node_not_found("node"));
30 }
31
32 let mut walk = vec![start.clone()];
33 let mut current = start.clone();
34 let mut rng = scirs2_core::random::rng();
35
36 use scirs2_core::random::{Rng, RngExt};
37
38 for _ in 0..steps {
39 if rng.random::<f64>() < restart_probability {
41 current = start.clone();
42 walk.push(current.clone());
43 continue;
44 }
45
46 if let Ok(neighbors) = graph.neighbors(¤t) {
48 let neighbor_vec: Vec<N> = neighbors;
49
50 if !neighbor_vec.is_empty() {
51 let idx = rng.random_range(0..neighbor_vec.len());
52 current = neighbor_vec[idx].clone();
53 walk.push(current.clone());
54 } else {
55 current = start.clone();
57 walk.push(current.clone());
58 }
59 }
60 }
61
62 Ok(walk)
63}
64
65#[allow(dead_code)]
70pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
71where
72 N: Node + Clone + std::fmt::Debug,
73 E: EdgeWeight + Into<f64>,
74 Ix: IndexType,
75{
76 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
77 let n = nodes.len();
78
79 if n == 0 {
80 return Err(GraphError::InvalidGraph("Empty graph".to_string()));
81 }
82
83 let mut matrix = Array2::<f64>::zeros((n, n));
84
85 for (i, node) in nodes.iter().enumerate() {
86 if let Ok(neighbors) = graph.neighbors(node) {
87 let neighbor_weights: Vec<(usize, f64)> = neighbors
88 .into_iter()
89 .filter_map(|neighbor| {
90 nodes.iter().position(|n| n == &neighbor).and_then(|j| {
91 graph
92 .edge_weight(node, &neighbor)
93 .ok()
94 .map(|w| (j, w.into()))
95 })
96 })
97 .collect();
98
99 let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
100
101 if total_weight > 0.0 {
102 for (j, weight) in neighbor_weights {
103 matrix[[i, j]] = weight / total_weight;
104 }
105 } else {
106 for j in 0..n {
108 matrix[[i, j]] = 1.0 / n as f64;
109 }
110 }
111 }
112 }
113
114 Ok((nodes, matrix))
115}
116
117#[allow(dead_code)]
121pub fn personalized_pagerank<N, E, Ix>(
122 graph: &Graph<N, E, Ix>,
123 source: &N,
124 damping: f64,
125 tolerance: f64,
126 max_iter: usize,
127) -> Result<HashMap<N, f64>>
128where
129 N: Node + Clone + Hash + Eq + std::fmt::Debug,
130 E: EdgeWeight + Into<f64>,
131 Ix: IndexType,
132{
133 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
134 let n = nodes.len();
135
136 if n == 0 || !graph.contains_node(source) {
137 return Err(GraphError::node_not_found("node"));
138 }
139
140 let source_idx = nodes
142 .iter()
143 .position(|n| n == source)
144 .expect("Operation failed");
145
146 let (_, trans_matrix) = transition_matrix(graph)?;
148
149 let mut pr = Array1::<f64>::zeros(n);
151 pr[source_idx] = 1.0;
152
153 let mut personalization = Array1::<f64>::zeros(n);
155 personalization[source_idx] = 1.0;
156
157 for _ in 0..max_iter {
159 let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
160
161 let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
163 if diff < tolerance {
164 break;
165 }
166
167 pr = new_pr;
168 }
169
170 Ok(nodes
172 .into_iter()
173 .enumerate()
174 .map(|(i, node)| (node, pr[i]))
175 .collect())
176}
177
178#[allow(dead_code)]
181pub fn parallel_random_walks<N, E, Ix>(
182 graph: &Graph<N, E, Ix>,
183 starts: &[N],
184 walk_length: usize,
185 restart_probability: f64,
186) -> Result<Vec<Vec<N>>>
187where
188 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
189 E: EdgeWeight + Send + Sync,
190 Ix: IndexType + Send + Sync,
191{
192 starts
193 .par_iter()
194 .map(|start| random_walk(graph, start, walk_length, restart_probability))
195 .collect::<Result<Vec<_>>>()
196}
197
198pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
201 node_to_idx: HashMap<N, usize>,
203 idx_to_node: Vec<N>,
205 #[allow(dead_code)]
207 transition_probs: Vec<Vec<f64>>,
208 alias_tables: Vec<AliasTable>,
210}
211
/// Alias table for O(1) sampling from a discrete distribution.
///
/// When column `i` is drawn, it is kept with probability `prob[i]`;
/// otherwise the draw is redirected to column `alias[i]`.
#[derive(Clone, Debug)]
struct AliasTable {
    /// Probability of keeping each column once it has been selected.
    prob: Vec<f64>,
    /// Column to fall back to when column `i` is rejected.
    alias: Vec<usize>,
}
220
221impl AliasTable {
222 fn new(weights: &[f64]) -> Self {
224 let n = weights.len();
225 let mut prob = vec![0.0; n];
226 let mut alias = vec![0; n];
227
228 if n == 0 {
229 return AliasTable { prob, alias };
230 }
231
232 let sum: f64 = weights.iter().sum();
233 if sum == 0.0 {
234 return AliasTable { prob, alias };
235 }
236
237 let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
239
240 let mut small = Vec::new();
241 let mut large = Vec::new();
242
243 for (i, &p) in normalized.iter().enumerate() {
244 if p < 1.0 {
245 small.push(i);
246 } else {
247 large.push(i);
248 }
249 }
250
251 prob[..n].copy_from_slice(&normalized[..n]);
252
253 while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
254 alias[small_idx] = large_idx;
255 prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
256
257 if prob[large_idx] < 1.0 {
258 small.push(large_idx);
259 } else {
260 large.push(large_idx);
261 }
262 }
263
264 AliasTable { prob, alias }
265 }
266
267 fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> usize {
269 if self.prob.is_empty() {
270 return 0;
271 }
272
273 let i = rng.random_range(0..self.prob.len());
274 let coin_flip = rng.random::<f64>();
275
276 if coin_flip <= self.prob[i] {
277 i
278 } else {
279 self.alias[i]
280 }
281 }
282}
283
284impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
285 pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
287 where
288 E: EdgeWeight + Into<f64>,
289 Ix: IndexType,
290 N: std::fmt::Debug,
291 {
292 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
293 let node_to_idx: HashMap<N, usize> = nodes
294 .iter()
295 .enumerate()
296 .map(|(i, n)| (n.clone(), i))
297 .collect();
298
299 let mut transition_probs = Vec::new();
300 let mut alias_tables = Vec::new();
301
302 for node in &nodes {
303 if let Ok(neighbors) = graph.neighbors(node) {
304 let neighbor_weights: Vec<f64> = neighbors
305 .iter()
306 .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
307 .map(|w| w.into())
308 .collect();
309
310 if !neighbor_weights.is_empty() {
311 let total: f64 = neighbor_weights.iter().sum();
312 let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
313
314 let mut cumulative = vec![0.0; probs.len()];
316 cumulative[0] = probs[0];
317 for i in 1..probs.len() {
318 cumulative[i] = cumulative[i - 1] + probs[i];
319 }
320
321 transition_probs.push(cumulative);
322 alias_tables.push(AliasTable::new(&neighbor_weights));
323 } else {
324 transition_probs.push(vec![]);
326 alias_tables.push(AliasTable::new(&[]));
327 }
328 } else {
329 transition_probs.push(vec![]);
330 alias_tables.push(AliasTable::new(&[]));
331 }
332 }
333
334 Ok(BatchRandomWalker {
335 node_to_idx,
336 idx_to_node: nodes,
337 transition_probs,
338 alias_tables,
339 })
340 }
341
342 pub fn generate_walks<E, Ix>(
344 &self,
345 graph: &Graph<N, E, Ix>,
346 starts: &[N],
347 walk_length: usize,
348 num_walks_per_node: usize,
349 ) -> Result<Vec<Vec<N>>>
350 where
351 E: EdgeWeight,
352 Ix: IndexType + std::marker::Sync,
353 N: Send + Sync + std::fmt::Debug,
354 {
355 let total_walks = starts.len() * num_walks_per_node;
356 let mut all_walks = Vec::with_capacity(total_walks);
357
358 starts
360 .par_iter()
361 .map(|start| {
362 let mut local_walks = Vec::with_capacity(num_walks_per_node);
363 let mut rng = scirs2_core::random::rng();
364
365 for _ in 0..num_walks_per_node {
366 if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
367 local_walks.push(walk);
368 }
369 }
370 local_walks
371 })
372 .collect::<Vec<_>>()
373 .into_iter()
374 .for_each(|walks| all_walks.extend(walks));
375
376 Ok(all_walks)
377 }
378
379 fn single_walk<E, Ix>(
381 &self,
382 graph: &Graph<N, E, Ix>,
383 start: &N,
384 walk_length: usize,
385 rng: &mut impl scirs2_core::random::Rng,
386 ) -> Result<Vec<N>>
387 where
388 E: EdgeWeight,
389 Ix: IndexType,
390 {
391 let mut walk = Vec::with_capacity(walk_length + 1);
392 walk.push(start.clone());
393
394 let mut current_idx = *self
395 .node_to_idx
396 .get(start)
397 .ok_or(GraphError::node_not_found("node"))?;
398
399 for _ in 0..walk_length {
400 if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
401 let neighbors: Vec<_> = neighbors;
402
403 if !neighbors.is_empty() {
404 let neighbor_idx = self.alias_tables[current_idx].sample(rng);
406 if neighbor_idx < neighbors.len() {
407 let next_node = neighbors[neighbor_idx].clone();
408 walk.push(next_node.clone());
409
410 if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
411 current_idx = next_idx;
412 }
413 } else {
414 break;
415 }
416 } else {
417 break;
418 }
419 } else {
420 break;
421 }
422 }
423
424 Ok(walk)
425 }
426}
427
428#[allow(dead_code)]
431pub fn node2vec_walk<N, E, Ix>(
432 graph: &Graph<N, E, Ix>,
433 start: &N,
434 walk_length: usize,
435 p: f64, q: f64, rng: &mut impl scirs2_core::random::Rng,
438) -> Result<Vec<N>>
439where
440 N: Node + Clone + Hash + Eq + std::fmt::Debug,
441 E: EdgeWeight + Into<f64>,
442 Ix: IndexType,
443{
444 let mut walk = vec![start.clone()];
445 if walk_length == 0 {
446 return Ok(walk);
447 }
448
449 if let Ok(neighbors) = graph.neighbors(start) {
451 let neighbors: Vec<_> = neighbors;
452 if neighbors.is_empty() {
453 return Ok(walk);
454 }
455
456 let idx = rng.random_range(0..neighbors.len());
457 walk.push(neighbors[idx].clone());
458 } else {
459 return Ok(walk);
460 }
461
462 for step in 1..walk_length {
464 let current = &walk[step];
465 let previous = &walk[step - 1];
466
467 if let Ok(neighbors) = graph.neighbors(current) {
468 let neighbors: Vec<_> = neighbors;
469 if neighbors.is_empty() {
470 break;
471 }
472
473 let mut weights = Vec::with_capacity(neighbors.len());
475
476 for neighbor in &neighbors {
477 let weight = if neighbor == previous {
478 1.0 / p
480 } else if graph.has_edge(previous, neighbor) {
481 1.0
483 } else {
484 1.0 / q
486 };
487
488 let edge_weight = graph
490 .edge_weight(current, neighbor)
491 .map(|w| w.into())
492 .unwrap_or(1.0);
493
494 weights.push(weight * edge_weight);
495 }
496
497 let total: f64 = weights.iter().sum();
499 if total > 0.0 {
500 let mut cumulative = vec![0.0; weights.len()];
501 cumulative[0] = weights[0] / total;
502
503 for i in 1..weights.len() {
505 cumulative[i] = cumulative[i - 1] + weights[i] / total;
506 }
507
508 let r = rng.random::<f64>();
509 for (i, &cum_prob) in cumulative.iter().enumerate() {
510 if r <= cum_prob {
511 walk.push(neighbors[i].clone());
512 break;
513 }
514 }
515 }
516 } else {
517 break;
518 }
519 }
520
521 Ok(walk)
522}
523
524#[allow(dead_code)]
526pub fn parallel_node2vec_walks<N, E, Ix>(
527 graph: &Graph<N, E, Ix>,
528 starts: &[N],
529 walk_length: usize,
530 num_walks: usize,
531 p: f64,
532 q: f64,
533) -> Result<Vec<Vec<N>>>
534where
535 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
536 E: EdgeWeight + Into<f64> + Send + Sync,
537 Ix: IndexType + Send + Sync,
538{
539 let total_walks = starts.len() * num_walks;
540
541 (0..total_walks)
542 .into_par_iter()
543 .map(|i| {
544 let start_idx = i % starts.len();
545 let start = &starts[start_idx];
546 let mut rng = scirs2_core::random::rng();
547 node2vec_walk(graph, start, walk_length, p, q, &mut rng)
548 })
549 .collect()
550}
551
552#[allow(dead_code)]
555pub fn simd_random_walk_with_restart<N, E, Ix>(
556 graph: &Graph<N, E, Ix>,
557 start: &N,
558 walk_length: usize,
559 restart_prob: f64,
560 rng: &mut impl scirs2_core::random::Rng,
561) -> Result<Vec<N>>
562where
563 N: Node + Clone + Hash + Eq + std::fmt::Debug,
564 E: EdgeWeight,
565 Ix: IndexType,
566{
567 let mut walk = Vec::with_capacity(walk_length + 1);
568 walk.push(start.clone());
569
570 let mut current = start.clone();
571
572 for _ in 0..walk_length {
573 if rng.random::<f64>() < restart_prob {
575 current = start.clone();
576 walk.push(current.clone());
577 continue;
578 }
579
580 if let Ok(neighbors) = graph.neighbors(¤t) {
581 let neighbors: Vec<_> = neighbors;
582 if !neighbors.is_empty() {
583 let idx = rng.random_range(0..neighbors.len());
584 current = neighbors[idx].clone();
585 walk.push(current.clone());
586 } else {
587 current = start.clone();
588 walk.push(current.clone());
589 }
590 } else {
591 break;
592 }
593 }
594
595 Ok(walk)
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    /// A 10-step walk from "A" on a path graph has 11 nodes, starts at "A",
    /// and visits only nodes that are in the graph.
    #[test]
    fn test_random_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(u, v, ())?;
        }

        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        assert_eq!(walk.first(), Some(&"A"));
        assert_eq!(walk.len(), 11);
        assert!(walk.iter().all(|node| graph.contains_node(node)));

        Ok(())
    }

    /// On a weighted triangle, the transition matrix is 3x3 and each row
    /// sums to one.
    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for (u, v) in [("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(u, v, 1.0)?;
        }

        let (nodes, matrix) = transition_matrix(&graph)?;

        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);
        for row in matrix.outer_iter() {
            assert!((row.sum() - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    /// On a star centered at the source "A", the scores cover all four
    /// nodes, sum to about one, and no leaf outranks the source.
    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, 1.0)?;
        }

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;

        assert_eq!(pagerank.len(), 4);

        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        let a_rank = pagerank[&"A"];
        let others_bounded = pagerank
            .iter()
            .filter(|&(node, _)| *node != "A")
            .all(|(_, &rank)| a_rank >= rank);
        assert!(others_bounded);

        Ok(())
    }
}