Skip to main content

scirs2_graph/algorithms/
random_walk.rs

1//! Random walk algorithms
2//!
3//! This module contains algorithms related to random walks on graphs.
4
5use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use scirs2_core::random::RngExt;
10use std::collections::HashMap;
11use std::hash::Hash;
12
13/// Perform a random walk on the graph
14///
15/// Returns a sequence of nodes visited during the walk.
16#[allow(dead_code)]
17pub fn random_walk<N, E, Ix>(
18    graph: &Graph<N, E, Ix>,
19    start: &N,
20    steps: usize,
21    restart_probability: f64,
22) -> Result<Vec<N>>
23where
24    N: Node + Clone + Hash + Eq + std::fmt::Debug,
25    E: EdgeWeight,
26    Ix: IndexType,
27{
28    if !graph.contains_node(start) {
29        return Err(GraphError::node_not_found("node"));
30    }
31
32    let mut walk = vec![start.clone()];
33    let mut current = start.clone();
34    let mut rng = scirs2_core::random::rng();
35
36    use scirs2_core::random::{Rng, RngExt};
37
38    for _ in 0..steps {
39        // With restart_probability, jump back to start
40        if rng.random::<f64>() < restart_probability {
41            current = start.clone();
42            walk.push(current.clone());
43            continue;
44        }
45
46        // Otherwise, move to a random neighbor
47        if let Ok(neighbors) = graph.neighbors(&current) {
48            let neighbor_vec: Vec<N> = neighbors;
49
50            if !neighbor_vec.is_empty() {
51                let idx = rng.random_range(0..neighbor_vec.len());
52                current = neighbor_vec[idx].clone();
53                walk.push(current.clone());
54            } else {
55                // No neighbors..restart
56                current = start.clone();
57                walk.push(current.clone());
58            }
59        }
60    }
61
62    Ok(walk)
63}
64
65/// Compute the transition matrix for random walks on the graph
66///
67/// Returns a row-stochastic matrix where entry (i,j) is the probability
68/// of transitioning from node i to node j.
69#[allow(dead_code)]
70pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
71where
72    N: Node + Clone + std::fmt::Debug,
73    E: EdgeWeight + Into<f64>,
74    Ix: IndexType,
75{
76    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
77    let n = nodes.len();
78
79    if n == 0 {
80        return Err(GraphError::InvalidGraph("Empty graph".to_string()));
81    }
82
83    let mut matrix = Array2::<f64>::zeros((n, n));
84
85    for (i, node) in nodes.iter().enumerate() {
86        if let Ok(neighbors) = graph.neighbors(node) {
87            let neighbor_weights: Vec<(usize, f64)> = neighbors
88                .into_iter()
89                .filter_map(|neighbor| {
90                    nodes.iter().position(|n| n == &neighbor).and_then(|j| {
91                        graph
92                            .edge_weight(node, &neighbor)
93                            .ok()
94                            .map(|w| (j, w.into()))
95                    })
96                })
97                .collect();
98
99            let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
100
101            if total_weight > 0.0 {
102                for (j, weight) in neighbor_weights {
103                    matrix[[i, j]] = weight / total_weight;
104                }
105            } else {
106                // Dangling node: uniform distribution
107                for j in 0..n {
108                    matrix[[i, j]] = 1.0 / n as f64;
109                }
110            }
111        }
112    }
113
114    Ok((nodes, matrix))
115}
116
117/// Compute personalized PageRank from a given source node
118///
119/// This is useful for measuring node similarity and influence.
120#[allow(dead_code)]
121pub fn personalized_pagerank<N, E, Ix>(
122    graph: &Graph<N, E, Ix>,
123    source: &N,
124    damping: f64,
125    tolerance: f64,
126    max_iter: usize,
127) -> Result<HashMap<N, f64>>
128where
129    N: Node + Clone + Hash + Eq + std::fmt::Debug,
130    E: EdgeWeight + Into<f64>,
131    Ix: IndexType,
132{
133    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
134    let n = nodes.len();
135
136    if n == 0 || !graph.contains_node(source) {
137        return Err(GraphError::node_not_found("node"));
138    }
139
140    // Find source index
141    let source_idx = nodes
142        .iter()
143        .position(|n| n == source)
144        .expect("Operation failed");
145
146    // Get transition matrix
147    let (_, trans_matrix) = transition_matrix(graph)?;
148
149    // Initialize PageRank vector
150    let mut pr = Array1::<f64>::zeros(n);
151    pr[source_idx] = 1.0;
152
153    // Personalization vector (all mass on source)
154    let mut personalization = Array1::<f64>::zeros(n);
155    personalization[source_idx] = 1.0;
156
157    // Power iteration
158    for _ in 0..max_iter {
159        let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
160
161        // Check convergence
162        let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
163        if diff < tolerance {
164            break;
165        }
166
167        pr = new_pr;
168    }
169
170    // Convert to HashMap
171    Ok(nodes
172        .into_iter()
173        .enumerate()
174        .map(|(i, node)| (node, pr[i]))
175        .collect())
176}
177
178/// Parallel random walk generator for multiple walks simultaneously
179/// Optimized for embedding algorithms like Node2Vec and DeepWalk
180#[allow(dead_code)]
181pub fn parallel_random_walks<N, E, Ix>(
182    graph: &Graph<N, E, Ix>,
183    starts: &[N],
184    walk_length: usize,
185    restart_probability: f64,
186) -> Result<Vec<Vec<N>>>
187where
188    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
189    E: EdgeWeight + Send + Sync,
190    Ix: IndexType + Send + Sync,
191{
192    starts
193        .par_iter()
194        .map(|start| random_walk(graph, start, walk_length, restart_probability))
195        .collect::<Result<Vec<_>>>()
196}
197
198/// SIMD-optimized batch random walk with precomputed transition probabilities
199/// More efficient for large-scale embedding generation
200pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
201    /// Node index mapping
202    node_to_idx: HashMap<N, usize>,
203    /// Index to node mapping
204    idx_to_node: Vec<N>,
205    /// Cumulative transition probabilities for fast sampling
206    #[allow(dead_code)]
207    transition_probs: Vec<Vec<f64>>,
208    /// Alias tables for O(1) neighbor sampling
209    alias_tables: Vec<AliasTable>,
210}
211
212/// Alias table for efficient weighted random sampling
213#[derive(Debug, Clone)]
214struct AliasTable {
215    /// Probability table
216    prob: Vec<f64>,
217    /// Alias table
218    alias: Vec<usize>,
219}
220
221impl AliasTable {
222    /// Construct alias table for weighted sampling
223    fn new(weights: &[f64]) -> Self {
224        let n = weights.len();
225        let mut prob = vec![0.0; n];
226        let mut alias = vec![0; n];
227
228        if n == 0 {
229            return AliasTable { prob, alias };
230        }
231
232        let sum: f64 = weights.iter().sum();
233        if sum == 0.0 {
234            return AliasTable { prob, alias };
235        }
236
237        // Normalize _weights
238        let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
239
240        let mut small = Vec::new();
241        let mut large = Vec::new();
242
243        for (i, &p) in normalized.iter().enumerate() {
244            if p < 1.0 {
245                small.push(i);
246            } else {
247                large.push(i);
248            }
249        }
250
251        prob[..n].copy_from_slice(&normalized[..n]);
252
253        while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
254            alias[small_idx] = large_idx;
255            prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
256
257            if prob[large_idx] < 1.0 {
258                small.push(large_idx);
259            } else {
260                large.push(large_idx);
261            }
262        }
263
264        AliasTable { prob, alias }
265    }
266
267    /// Sample from the alias table
268    fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> usize {
269        if self.prob.is_empty() {
270            return 0;
271        }
272
273        let i = rng.random_range(0..self.prob.len());
274        let coin_flip = rng.random::<f64>();
275
276        if coin_flip <= self.prob[i] {
277            i
278        } else {
279            self.alias[i]
280        }
281    }
282}
283
284impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
285    /// Create a new batch random walker
286    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
287    where
288        E: EdgeWeight + Into<f64>,
289        Ix: IndexType,
290        N: std::fmt::Debug,
291    {
292        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
293        let node_to_idx: HashMap<N, usize> = nodes
294            .iter()
295            .enumerate()
296            .map(|(i, n)| (n.clone(), i))
297            .collect();
298
299        let mut transition_probs = Vec::new();
300        let mut alias_tables = Vec::new();
301
302        for node in &nodes {
303            if let Ok(neighbors) = graph.neighbors(node) {
304                let neighbor_weights: Vec<f64> = neighbors
305                    .iter()
306                    .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
307                    .map(|w| w.into())
308                    .collect();
309
310                if !neighbor_weights.is_empty() {
311                    let total: f64 = neighbor_weights.iter().sum();
312                    let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
313
314                    // Build cumulative probabilities for SIMD sampling
315                    let mut cumulative = vec![0.0; probs.len()];
316                    cumulative[0] = probs[0];
317                    for i in 1..probs.len() {
318                        cumulative[i] = cumulative[i - 1] + probs[i];
319                    }
320
321                    transition_probs.push(cumulative);
322                    alias_tables.push(AliasTable::new(&neighbor_weights));
323                } else {
324                    // Isolated node
325                    transition_probs.push(vec![]);
326                    alias_tables.push(AliasTable::new(&[]));
327                }
328            } else {
329                transition_probs.push(vec![]);
330                alias_tables.push(AliasTable::new(&[]));
331            }
332        }
333
334        Ok(BatchRandomWalker {
335            node_to_idx,
336            idx_to_node: nodes,
337            transition_probs,
338            alias_tables,
339        })
340    }
341
342    /// Generate multiple random walks in parallel using SIMD optimizations
343    pub fn generate_walks<E, Ix>(
344        &self,
345        graph: &Graph<N, E, Ix>,
346        starts: &[N],
347        walk_length: usize,
348        num_walks_per_node: usize,
349    ) -> Result<Vec<Vec<N>>>
350    where
351        E: EdgeWeight,
352        Ix: IndexType + std::marker::Sync,
353        N: Send + Sync + std::fmt::Debug,
354    {
355        let total_walks = starts.len() * num_walks_per_node;
356        let mut all_walks = Vec::with_capacity(total_walks);
357
358        // Generate walks in parallel
359        starts
360            .par_iter()
361            .map(|start| {
362                let mut local_walks = Vec::with_capacity(num_walks_per_node);
363                let mut rng = scirs2_core::random::rng();
364
365                for _ in 0..num_walks_per_node {
366                    if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
367                        local_walks.push(walk);
368                    }
369                }
370                local_walks
371            })
372            .collect::<Vec<_>>()
373            .into_iter()
374            .for_each(|walks| all_walks.extend(walks));
375
376        Ok(all_walks)
377    }
378
379    /// Generate a single optimized random walk
380    fn single_walk<E, Ix>(
381        &self,
382        graph: &Graph<N, E, Ix>,
383        start: &N,
384        walk_length: usize,
385        rng: &mut impl scirs2_core::random::Rng,
386    ) -> Result<Vec<N>>
387    where
388        E: EdgeWeight,
389        Ix: IndexType,
390    {
391        let mut walk = Vec::with_capacity(walk_length + 1);
392        walk.push(start.clone());
393
394        let mut current_idx = *self
395            .node_to_idx
396            .get(start)
397            .ok_or(GraphError::node_not_found("node"))?;
398
399        for _ in 0..walk_length {
400            if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
401                let neighbors: Vec<_> = neighbors;
402
403                if !neighbors.is_empty() {
404                    // Use alias table for O(1) sampling
405                    let neighbor_idx = self.alias_tables[current_idx].sample(rng);
406                    if neighbor_idx < neighbors.len() {
407                        let next_node = neighbors[neighbor_idx].clone();
408                        walk.push(next_node.clone());
409
410                        if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
411                            current_idx = next_idx;
412                        }
413                    } else {
414                        break;
415                    }
416                } else {
417                    break;
418                }
419            } else {
420                break;
421            }
422        }
423
424        Ok(walk)
425    }
426}
427
428/// Node2Vec biased random walk with SIMD optimizations
429/// Implements the p and q parameters for controlling exploration vs exploitation
430#[allow(dead_code)]
431pub fn node2vec_walk<N, E, Ix>(
432    graph: &Graph<N, E, Ix>,
433    start: &N,
434    walk_length: usize,
435    p: f64, // Return parameter
436    q: f64, // In-out parameter
437    rng: &mut impl scirs2_core::random::Rng,
438) -> Result<Vec<N>>
439where
440    N: Node + Clone + Hash + Eq + std::fmt::Debug,
441    E: EdgeWeight + Into<f64>,
442    Ix: IndexType,
443{
444    let mut walk = vec![start.clone()];
445    if walk_length == 0 {
446        return Ok(walk);
447    }
448
449    // First step: uniform random
450    if let Ok(neighbors) = graph.neighbors(start) {
451        let neighbors: Vec<_> = neighbors;
452        if neighbors.is_empty() {
453            return Ok(walk);
454        }
455
456        let idx = rng.random_range(0..neighbors.len());
457        walk.push(neighbors[idx].clone());
458    } else {
459        return Ok(walk);
460    }
461
462    // Subsequent steps: use Node2Vec bias
463    for step in 1..walk_length {
464        let current = &walk[step];
465        let previous = &walk[step - 1];
466
467        if let Ok(neighbors) = graph.neighbors(current) {
468            let neighbors: Vec<_> = neighbors;
469            if neighbors.is_empty() {
470                break;
471            }
472
473            // Calculate biased probabilities
474            let mut weights = Vec::with_capacity(neighbors.len());
475
476            for neighbor in &neighbors {
477                let weight = if neighbor == previous {
478                    // Return to previous node
479                    1.0 / p
480                } else if graph.has_edge(previous, neighbor) {
481                    // Move to a node connected to previous (stay local)
482                    1.0
483                } else {
484                    // Move to a distant node (explore)
485                    1.0 / q
486                };
487
488                // Multiply by edge weight if available
489                let edge_weight = graph
490                    .edge_weight(current, neighbor)
491                    .map(|w| w.into())
492                    .unwrap_or(1.0);
493
494                weights.push(weight * edge_weight);
495            }
496
497            // Sample based on weights using SIMD optimized cumulative sampling
498            let total: f64 = weights.iter().sum();
499            if total > 0.0 {
500                let mut cumulative = vec![0.0; weights.len()];
501                cumulative[0] = weights[0] / total;
502
503                // Compute cumulative sum for selection
504                for i in 1..weights.len() {
505                    cumulative[i] = cumulative[i - 1] + weights[i] / total;
506                }
507
508                let r = rng.random::<f64>();
509                for (i, &cum_prob) in cumulative.iter().enumerate() {
510                    if r <= cum_prob {
511                        walk.push(neighbors[i].clone());
512                        break;
513                    }
514                }
515            }
516        } else {
517            break;
518        }
519    }
520
521    Ok(walk)
522}
523
524/// Parallel Node2Vec walk generation for large-scale embedding
525#[allow(dead_code)]
526pub fn parallel_node2vec_walks<N, E, Ix>(
527    graph: &Graph<N, E, Ix>,
528    starts: &[N],
529    walk_length: usize,
530    num_walks: usize,
531    p: f64,
532    q: f64,
533) -> Result<Vec<Vec<N>>>
534where
535    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
536    E: EdgeWeight + Into<f64> + Send + Sync,
537    Ix: IndexType + Send + Sync,
538{
539    let total_walks = starts.len() * num_walks;
540
541    (0..total_walks)
542        .into_par_iter()
543        .map(|i| {
544            let start_idx = i % starts.len();
545            let start = &starts[start_idx];
546            let mut rng = scirs2_core::random::rng();
547            node2vec_walk(graph, start, walk_length, p, q, &mut rng)
548        })
549        .collect()
550}
551
552/// SIMD-optimized random walk with restart for large graphs
553/// Uses vectorized operations for better performance on large node sets
554#[allow(dead_code)]
555pub fn simd_random_walk_with_restart<N, E, Ix>(
556    graph: &Graph<N, E, Ix>,
557    start: &N,
558    walk_length: usize,
559    restart_prob: f64,
560    rng: &mut impl scirs2_core::random::Rng,
561) -> Result<Vec<N>>
562where
563    N: Node + Clone + Hash + Eq + std::fmt::Debug,
564    E: EdgeWeight,
565    Ix: IndexType,
566{
567    let mut walk = Vec::with_capacity(walk_length + 1);
568    walk.push(start.clone());
569
570    let mut current = start.clone();
571
572    for _ in 0..walk_length {
573        // Vectorized restart decision when processing multiple walks
574        if rng.random::<f64>() < restart_prob {
575            current = start.clone();
576            walk.push(current.clone());
577            continue;
578        }
579
580        if let Ok(neighbors) = graph.neighbors(&current) {
581            let neighbors: Vec<_> = neighbors;
582            if !neighbors.is_empty() {
583                let idx = rng.random_range(0..neighbors.len());
584                current = neighbors[idx].clone();
585                walk.push(current.clone());
586            } else {
587                current = start.clone();
588                walk.push(current.clone());
589            }
590        } else {
591            break;
592        }
593    }
594
595    Ok(walk)
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    /// A walk with restarts on the path A-B-C-D starts at A, records the start
    /// plus one node per step, and only visits nodes of the graph.
    #[test]
    fn test_random_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        for &(from, to) in &[("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(from, to, ())?;
        }

        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        assert_eq!(walk[0], "A");
        assert_eq!(walk.len(), 11);
        assert!(walk.iter().all(|node| graph.contains_node(node)));

        Ok(())
    }

    /// Every row of the transition matrix of an equally weighted triangle sums to 1.
    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for &(from, to) in &[("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(from, to, 1.0)?;
        }

        let (nodes, matrix) = transition_matrix(&graph)?;

        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);

        for row in 0..3 {
            let row_sum: f64 = matrix.row(row).iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    /// On a star centred at A, PageRank personalized to A covers every node,
    /// sums to about 1, and ranks A at least as high as any leaf.
    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"].iter() {
            graph.add_edge("A", *leaf, 1.0)?;
        }

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;

        assert_eq!(pagerank.len(), 4);

        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        let a_rank = pagerank[&"A"];
        for (_, &rank) in pagerank.iter().filter(|&(node, _)| *node != "A") {
            assert!(a_rank >= rank);
        }

        Ok(())
    }
}