1use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use std::collections::HashMap;
10use std::hash::Hash;
11
12#[allow(dead_code)]
16pub fn random_walk<N, E, Ix>(
17 graph: &Graph<N, E, Ix>,
18 start: &N,
19 steps: usize,
20 restart_probability: f64,
21) -> Result<Vec<N>>
22where
23 N: Node + Clone + Hash + Eq + std::fmt::Debug,
24 E: EdgeWeight,
25 Ix: IndexType,
26{
27 if !graph.contains_node(start) {
28 return Err(GraphError::node_not_found("node"));
29 }
30
31 let mut walk = vec![start.clone()];
32 let mut current = start.clone();
33 let mut rng = rand::rng();
34
35 use rand::Rng;
36
37 for _ in 0..steps {
38 if rng.random::<f64>() < restart_probability {
40 current = start.clone();
41 walk.push(current.clone());
42 continue;
43 }
44
45 if let Ok(neighbors) = graph.neighbors(¤t) {
47 let neighbor_vec: Vec<N> = neighbors;
48
49 if !neighbor_vec.is_empty() {
50 let idx = rng.gen_range(0..neighbor_vec.len());
51 current = neighbor_vec[idx].clone();
52 walk.push(current.clone());
53 } else {
54 current = start.clone();
56 walk.push(current.clone());
57 }
58 }
59 }
60
61 Ok(walk)
62}
63
64#[allow(dead_code)]
69pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
70where
71 N: Node + Clone + std::fmt::Debug,
72 E: EdgeWeight + Into<f64>,
73 Ix: IndexType,
74{
75 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
76 let n = nodes.len();
77
78 if n == 0 {
79 return Err(GraphError::InvalidGraph("Empty graph".to_string()));
80 }
81
82 let mut matrix = Array2::<f64>::zeros((n, n));
83
84 for (i, node) in nodes.iter().enumerate() {
85 if let Ok(neighbors) = graph.neighbors(node) {
86 let neighbor_weights: Vec<(usize, f64)> = neighbors
87 .into_iter()
88 .filter_map(|neighbor| {
89 nodes.iter().position(|n| n == &neighbor).and_then(|j| {
90 graph
91 .edge_weight(node, &neighbor)
92 .ok()
93 .map(|w| (j, w.into()))
94 })
95 })
96 .collect();
97
98 let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
99
100 if total_weight > 0.0 {
101 for (j, weight) in neighbor_weights {
102 matrix[[i, j]] = weight / total_weight;
103 }
104 } else {
105 for j in 0..n {
107 matrix[[i, j]] = 1.0 / n as f64;
108 }
109 }
110 }
111 }
112
113 Ok((nodes, matrix))
114}
115
116#[allow(dead_code)]
120pub fn personalized_pagerank<N, E, Ix>(
121 graph: &Graph<N, E, Ix>,
122 source: &N,
123 damping: f64,
124 tolerance: f64,
125 max_iter: usize,
126) -> Result<HashMap<N, f64>>
127where
128 N: Node + Clone + Hash + Eq + std::fmt::Debug,
129 E: EdgeWeight + Into<f64>,
130 Ix: IndexType,
131{
132 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
133 let n = nodes.len();
134
135 if n == 0 || !graph.contains_node(source) {
136 return Err(GraphError::node_not_found("node"));
137 }
138
139 let source_idx = nodes.iter().position(|n| n == source).unwrap();
141
142 let (_, trans_matrix) = transition_matrix(graph)?;
144
145 let mut pr = Array1::<f64>::zeros(n);
147 pr[source_idx] = 1.0;
148
149 let mut personalization = Array1::<f64>::zeros(n);
151 personalization[source_idx] = 1.0;
152
153 for _ in 0..max_iter {
155 let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
156
157 let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
159 if diff < tolerance {
160 break;
161 }
162
163 pr = new_pr;
164 }
165
166 Ok(nodes
168 .into_iter()
169 .enumerate()
170 .map(|(i, node)| (node, pr[i]))
171 .collect())
172}
173
174#[allow(dead_code)]
177pub fn parallel_random_walks<N, E, Ix>(
178 graph: &Graph<N, E, Ix>,
179 starts: &[N],
180 walk_length: usize,
181 restart_probability: f64,
182) -> Result<Vec<Vec<N>>>
183where
184 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
185 E: EdgeWeight + Send + Sync,
186 Ix: IndexType + Send + Sync,
187{
188 starts
189 .par_iter()
190 .map(|start| random_walk(graph, start, walk_length, restart_probability))
191 .collect::<Result<Vec<_>>>()
192}
193
194pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
197 node_to_idx: HashMap<N, usize>,
199 idx_to_node: Vec<N>,
201 #[allow(dead_code)]
203 transition_probs: Vec<Vec<f64>>,
204 alias_tables: Vec<AliasTable>,
206}
207
208#[derive(Debug, Clone)]
210struct AliasTable {
211 prob: Vec<f64>,
213 alias: Vec<usize>,
215}
216
217impl AliasTable {
218 fn new(weights: &[f64]) -> Self {
220 let n = weights.len();
221 let mut prob = vec![0.0; n];
222 let mut alias = vec![0; n];
223
224 if n == 0 {
225 return AliasTable { prob, alias };
226 }
227
228 let sum: f64 = weights.iter().sum();
229 if sum == 0.0 {
230 return AliasTable { prob, alias };
231 }
232
233 let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
235
236 let mut small = Vec::new();
237 let mut large = Vec::new();
238
239 for (i, &p) in normalized.iter().enumerate() {
240 if p < 1.0 {
241 small.push(i);
242 } else {
243 large.push(i);
244 }
245 }
246
247 prob[..n].copy_from_slice(&normalized[..n]);
248
249 while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
250 alias[small_idx] = large_idx;
251 prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
252
253 if prob[large_idx] < 1.0 {
254 small.push(large_idx);
255 } else {
256 large.push(large_idx);
257 }
258 }
259
260 AliasTable { prob, alias }
261 }
262
263 fn sample(&self, rng: &mut impl rand::Rng) -> usize {
265 if self.prob.is_empty() {
266 return 0;
267 }
268
269 let i = rng.gen_range(0..self.prob.len());
270 let coin_flip = rng.random::<f64>();
271
272 if coin_flip <= self.prob[i] {
273 i
274 } else {
275 self.alias[i]
276 }
277 }
278}
279
280impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
281 pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
283 where
284 E: EdgeWeight + Into<f64>,
285 Ix: IndexType,
286 N: std::fmt::Debug,
287 {
288 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
289 let node_to_idx: HashMap<N, usize> = nodes
290 .iter()
291 .enumerate()
292 .map(|(i, n)| (n.clone(), i))
293 .collect();
294
295 let mut transition_probs = Vec::new();
296 let mut alias_tables = Vec::new();
297
298 for node in &nodes {
299 if let Ok(neighbors) = graph.neighbors(node) {
300 let neighbor_weights: Vec<f64> = neighbors
301 .iter()
302 .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
303 .map(|w| w.into())
304 .collect();
305
306 if !neighbor_weights.is_empty() {
307 let total: f64 = neighbor_weights.iter().sum();
308 let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
309
310 let mut cumulative = vec![0.0; probs.len()];
312 cumulative[0] = probs[0];
313 for i in 1..probs.len() {
314 cumulative[i] = cumulative[i - 1] + probs[i];
315 }
316
317 transition_probs.push(cumulative);
318 alias_tables.push(AliasTable::new(&neighbor_weights));
319 } else {
320 transition_probs.push(vec![]);
322 alias_tables.push(AliasTable::new(&[]));
323 }
324 } else {
325 transition_probs.push(vec![]);
326 alias_tables.push(AliasTable::new(&[]));
327 }
328 }
329
330 Ok(BatchRandomWalker {
331 node_to_idx,
332 idx_to_node: nodes,
333 transition_probs,
334 alias_tables,
335 })
336 }
337
338 pub fn generate_walks<E, Ix>(
340 &self,
341 graph: &Graph<N, E, Ix>,
342 starts: &[N],
343 walk_length: usize,
344 num_walks_per_node: usize,
345 ) -> Result<Vec<Vec<N>>>
346 where
347 E: EdgeWeight,
348 Ix: IndexType + std::marker::Sync,
349 N: Send + Sync + std::fmt::Debug,
350 {
351 let total_walks = starts.len() * num_walks_per_node;
352 let mut all_walks = Vec::with_capacity(total_walks);
353
354 starts
356 .par_iter()
357 .map(|start| {
358 let mut local_walks = Vec::with_capacity(num_walks_per_node);
359 let mut rng = rand::rng();
360
361 for _ in 0..num_walks_per_node {
362 if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
363 local_walks.push(walk);
364 }
365 }
366 local_walks
367 })
368 .collect::<Vec<_>>()
369 .into_iter()
370 .for_each(|walks| all_walks.extend(walks));
371
372 Ok(all_walks)
373 }
374
375 fn single_walk<E, Ix>(
377 &self,
378 graph: &Graph<N, E, Ix>,
379 start: &N,
380 walk_length: usize,
381 rng: &mut impl rand::Rng,
382 ) -> Result<Vec<N>>
383 where
384 E: EdgeWeight,
385 Ix: IndexType,
386 {
387 let mut walk = Vec::with_capacity(walk_length + 1);
388 walk.push(start.clone());
389
390 let mut current_idx = *self
391 .node_to_idx
392 .get(start)
393 .ok_or(GraphError::node_not_found("node"))?;
394
395 for _ in 0..walk_length {
396 if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
397 let neighbors: Vec<_> = neighbors;
398
399 if !neighbors.is_empty() {
400 let neighbor_idx = self.alias_tables[current_idx].sample(rng);
402 if neighbor_idx < neighbors.len() {
403 let next_node = neighbors[neighbor_idx].clone();
404 walk.push(next_node.clone());
405
406 if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
407 current_idx = next_idx;
408 }
409 } else {
410 break;
411 }
412 } else {
413 break;
414 }
415 } else {
416 break;
417 }
418 }
419
420 Ok(walk)
421 }
422}
423
424#[allow(dead_code)]
427pub fn node2vec_walk<N, E, Ix>(
428 graph: &Graph<N, E, Ix>,
429 start: &N,
430 walk_length: usize,
431 p: f64, q: f64, rng: &mut impl rand::Rng,
434) -> Result<Vec<N>>
435where
436 N: Node + Clone + Hash + Eq + std::fmt::Debug,
437 E: EdgeWeight + Into<f64>,
438 Ix: IndexType,
439{
440 let mut walk = vec![start.clone()];
441 if walk_length == 0 {
442 return Ok(walk);
443 }
444
445 if let Ok(neighbors) = graph.neighbors(start) {
447 let neighbors: Vec<_> = neighbors;
448 if neighbors.is_empty() {
449 return Ok(walk);
450 }
451
452 let idx = rng.gen_range(0..neighbors.len());
453 walk.push(neighbors[idx].clone());
454 } else {
455 return Ok(walk);
456 }
457
458 for step in 1..walk_length {
460 let current = &walk[step];
461 let previous = &walk[step - 1];
462
463 if let Ok(neighbors) = graph.neighbors(current) {
464 let neighbors: Vec<_> = neighbors;
465 if neighbors.is_empty() {
466 break;
467 }
468
469 let mut weights = Vec::with_capacity(neighbors.len());
471
472 for neighbor in &neighbors {
473 let weight = if neighbor == previous {
474 1.0 / p
476 } else if graph.has_edge(previous, neighbor) {
477 1.0
479 } else {
480 1.0 / q
482 };
483
484 let edge_weight = graph
486 .edge_weight(current, neighbor)
487 .map(|w| w.into())
488 .unwrap_or(1.0);
489
490 weights.push(weight * edge_weight);
491 }
492
493 let total: f64 = weights.iter().sum();
495 if total > 0.0 {
496 let mut cumulative = vec![0.0; weights.len()];
497 cumulative[0] = weights[0] / total;
498
499 for i in 1..weights.len() {
501 cumulative[i] = cumulative[i - 1] + weights[i] / total;
502 }
503
504 let r = rng.random::<f64>();
505 for (i, &cum_prob) in cumulative.iter().enumerate() {
506 if r <= cum_prob {
507 walk.push(neighbors[i].clone());
508 break;
509 }
510 }
511 }
512 } else {
513 break;
514 }
515 }
516
517 Ok(walk)
518}
519
520#[allow(dead_code)]
522pub fn parallel_node2vec_walks<N, E, Ix>(
523 graph: &Graph<N, E, Ix>,
524 starts: &[N],
525 walk_length: usize,
526 num_walks: usize,
527 p: f64,
528 q: f64,
529) -> Result<Vec<Vec<N>>>
530where
531 N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
532 E: EdgeWeight + Into<f64> + Send + Sync,
533 Ix: IndexType + Send + Sync,
534{
535 let total_walks = starts.len() * num_walks;
536
537 (0..total_walks)
538 .into_par_iter()
539 .map(|i| {
540 let start_idx = i % starts.len();
541 let start = &starts[start_idx];
542 let mut rng = rand::rng();
543 node2vec_walk(graph, start, walk_length, p, q, &mut rng)
544 })
545 .collect()
546}
547
548#[allow(dead_code)]
551pub fn simd_random_walk_with_restart<N, E, Ix>(
552 graph: &Graph<N, E, Ix>,
553 start: &N,
554 walk_length: usize,
555 restart_prob: f64,
556 rng: &mut impl rand::Rng,
557) -> Result<Vec<N>>
558where
559 N: Node + Clone + Hash + Eq + std::fmt::Debug,
560 E: EdgeWeight,
561 Ix: IndexType,
562{
563 let mut walk = Vec::with_capacity(walk_length + 1);
564 walk.push(start.clone());
565
566 let mut current = start.clone();
567
568 for _ in 0..walk_length {
569 if rng.random::<f64>() < restart_prob {
571 current = start.clone();
572 walk.push(current.clone());
573 continue;
574 }
575
576 if let Ok(neighbors) = graph.neighbors(¤t) {
577 let neighbors: Vec<_> = neighbors;
578 if !neighbors.is_empty() {
579 let idx = rng.gen_range(0..neighbors.len());
580 current = neighbors[idx].clone();
581 walk.push(current.clone());
582 } else {
583 current = start.clone();
584 walk.push(current.clone());
585 }
586 } else {
587 break;
588 }
589 }
590
591 Ok(walk)
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_random_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        for (from, to) in [("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(from, to, ())?;
        }

        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        // The walk starts at the requested node and records one node per step.
        assert_eq!(walk.first(), Some(&"A"));
        assert_eq!(walk.len(), 11);
        assert!(walk.iter().all(|node| graph.contains_node(node)));

        Ok(())
    }

    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for (from, to) in [("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(from, to, 1.0)?;
        }

        let (nodes, matrix) = transition_matrix(&graph)?;
        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);

        // Every row must be a probability distribution.
        for row in matrix.outer_iter() {
            let row_sum: f64 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, 1.0)?;
        }

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;
        assert_eq!(pagerank.len(), 4);

        // Scores form a distribution over all nodes.
        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        // The personalization source outranks every other node.
        let a_rank = pagerank[&"A"];
        assert!(pagerank
            .iter()
            .filter(|&(node, _)| *node != "A")
            .all(|(_, &rank)| a_rank >= rank));

        Ok(())
    }
}