scirs2_graph/algorithms/
random_walk.rs

1//! Random walk algorithms
2//!
3//! This module contains algorithms related to random walks on graphs.
4
5use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use std::collections::HashMap;
10use std::hash::Hash;
11
12/// Perform a random walk on the graph
13///
14/// Returns a sequence of nodes visited during the walk.
15#[allow(dead_code)]
16pub fn random_walk<N, E, Ix>(
17    graph: &Graph<N, E, Ix>,
18    start: &N,
19    steps: usize,
20    restart_probability: f64,
21) -> Result<Vec<N>>
22where
23    N: Node + Clone + Hash + Eq + std::fmt::Debug,
24    E: EdgeWeight,
25    Ix: IndexType,
26{
27    if !graph.contains_node(start) {
28        return Err(GraphError::node_not_found("node"));
29    }
30
31    let mut walk = vec![start.clone()];
32    let mut current = start.clone();
33    let mut rng = scirs2_core::random::rng();
34
35    use scirs2_core::random::Rng;
36
37    for _ in 0..steps {
38        // With restart_probability, jump back to start
39        if rng.random::<f64>() < restart_probability {
40            current = start.clone();
41            walk.push(current.clone());
42            continue;
43        }
44
45        // Otherwise, move to a random neighbor
46        if let Ok(neighbors) = graph.neighbors(&current) {
47            let neighbor_vec: Vec<N> = neighbors;
48
49            if !neighbor_vec.is_empty() {
50                let idx = rng.gen_range(0..neighbor_vec.len());
51                current = neighbor_vec[idx].clone();
52                walk.push(current.clone());
53            } else {
54                // No neighbors..restart
55                current = start.clone();
56                walk.push(current.clone());
57            }
58        }
59    }
60
61    Ok(walk)
62}
63
64/// Compute the transition matrix for random walks on the graph
65///
66/// Returns a row-stochastic matrix where entry (i,j) is the probability
67/// of transitioning from node i to node j.
68#[allow(dead_code)]
69pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
70where
71    N: Node + Clone + std::fmt::Debug,
72    E: EdgeWeight + Into<f64>,
73    Ix: IndexType,
74{
75    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
76    let n = nodes.len();
77
78    if n == 0 {
79        return Err(GraphError::InvalidGraph("Empty graph".to_string()));
80    }
81
82    let mut matrix = Array2::<f64>::zeros((n, n));
83
84    for (i, node) in nodes.iter().enumerate() {
85        if let Ok(neighbors) = graph.neighbors(node) {
86            let neighbor_weights: Vec<(usize, f64)> = neighbors
87                .into_iter()
88                .filter_map(|neighbor| {
89                    nodes.iter().position(|n| n == &neighbor).and_then(|j| {
90                        graph
91                            .edge_weight(node, &neighbor)
92                            .ok()
93                            .map(|w| (j, w.into()))
94                    })
95                })
96                .collect();
97
98            let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
99
100            if total_weight > 0.0 {
101                for (j, weight) in neighbor_weights {
102                    matrix[[i, j]] = weight / total_weight;
103                }
104            } else {
105                // Dangling node: uniform distribution
106                for j in 0..n {
107                    matrix[[i, j]] = 1.0 / n as f64;
108                }
109            }
110        }
111    }
112
113    Ok((nodes, matrix))
114}
115
116/// Compute personalized PageRank from a given source node
117///
118/// This is useful for measuring node similarity and influence.
119#[allow(dead_code)]
120pub fn personalized_pagerank<N, E, Ix>(
121    graph: &Graph<N, E, Ix>,
122    source: &N,
123    damping: f64,
124    tolerance: f64,
125    max_iter: usize,
126) -> Result<HashMap<N, f64>>
127where
128    N: Node + Clone + Hash + Eq + std::fmt::Debug,
129    E: EdgeWeight + Into<f64>,
130    Ix: IndexType,
131{
132    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
133    let n = nodes.len();
134
135    if n == 0 || !graph.contains_node(source) {
136        return Err(GraphError::node_not_found("node"));
137    }
138
139    // Find source index
140    let source_idx = nodes
141        .iter()
142        .position(|n| n == source)
143        .expect("Operation failed");
144
145    // Get transition matrix
146    let (_, trans_matrix) = transition_matrix(graph)?;
147
148    // Initialize PageRank vector
149    let mut pr = Array1::<f64>::zeros(n);
150    pr[source_idx] = 1.0;
151
152    // Personalization vector (all mass on source)
153    let mut personalization = Array1::<f64>::zeros(n);
154    personalization[source_idx] = 1.0;
155
156    // Power iteration
157    for _ in 0..max_iter {
158        let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
159
160        // Check convergence
161        let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
162        if diff < tolerance {
163            break;
164        }
165
166        pr = new_pr;
167    }
168
169    // Convert to HashMap
170    Ok(nodes
171        .into_iter()
172        .enumerate()
173        .map(|(i, node)| (node, pr[i]))
174        .collect())
175}
176
177/// Parallel random walk generator for multiple walks simultaneously
178/// Optimized for embedding algorithms like Node2Vec and DeepWalk
179#[allow(dead_code)]
180pub fn parallel_random_walks<N, E, Ix>(
181    graph: &Graph<N, E, Ix>,
182    starts: &[N],
183    walk_length: usize,
184    restart_probability: f64,
185) -> Result<Vec<Vec<N>>>
186where
187    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
188    E: EdgeWeight + Send + Sync,
189    Ix: IndexType + Send + Sync,
190{
191    starts
192        .par_iter()
193        .map(|start| random_walk(graph, start, walk_length, restart_probability))
194        .collect::<Result<Vec<_>>>()
195}
196
197/// SIMD-optimized batch random walk with precomputed transition probabilities
198/// More efficient for large-scale embedding generation
199pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
200    /// Node index mapping
201    node_to_idx: HashMap<N, usize>,
202    /// Index to node mapping
203    idx_to_node: Vec<N>,
204    /// Cumulative transition probabilities for fast sampling
205    #[allow(dead_code)]
206    transition_probs: Vec<Vec<f64>>,
207    /// Alias tables for O(1) neighbor sampling
208    alias_tables: Vec<AliasTable>,
209}
210
211/// Alias table for efficient weighted random sampling
212#[derive(Debug, Clone)]
213struct AliasTable {
214    /// Probability table
215    prob: Vec<f64>,
216    /// Alias table
217    alias: Vec<usize>,
218}
219
220impl AliasTable {
221    /// Construct alias table for weighted sampling
222    fn new(weights: &[f64]) -> Self {
223        let n = weights.len();
224        let mut prob = vec![0.0; n];
225        let mut alias = vec![0; n];
226
227        if n == 0 {
228            return AliasTable { prob, alias };
229        }
230
231        let sum: f64 = weights.iter().sum();
232        if sum == 0.0 {
233            return AliasTable { prob, alias };
234        }
235
236        // Normalize _weights
237        let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
238
239        let mut small = Vec::new();
240        let mut large = Vec::new();
241
242        for (i, &p) in normalized.iter().enumerate() {
243            if p < 1.0 {
244                small.push(i);
245            } else {
246                large.push(i);
247            }
248        }
249
250        prob[..n].copy_from_slice(&normalized[..n]);
251
252        while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
253            alias[small_idx] = large_idx;
254            prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
255
256            if prob[large_idx] < 1.0 {
257                small.push(large_idx);
258            } else {
259                large.push(large_idx);
260            }
261        }
262
263        AliasTable { prob, alias }
264    }
265
266    /// Sample from the alias table
267    fn sample(&self, rng: &mut impl scirs2_core::random::Rng) -> usize {
268        if self.prob.is_empty() {
269            return 0;
270        }
271
272        let i = rng.gen_range(0..self.prob.len());
273        let coin_flip = rng.random::<f64>();
274
275        if coin_flip <= self.prob[i] {
276            i
277        } else {
278            self.alias[i]
279        }
280    }
281}
282
283impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
284    /// Create a new batch random walker
285    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
286    where
287        E: EdgeWeight + Into<f64>,
288        Ix: IndexType,
289        N: std::fmt::Debug,
290    {
291        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
292        let node_to_idx: HashMap<N, usize> = nodes
293            .iter()
294            .enumerate()
295            .map(|(i, n)| (n.clone(), i))
296            .collect();
297
298        let mut transition_probs = Vec::new();
299        let mut alias_tables = Vec::new();
300
301        for node in &nodes {
302            if let Ok(neighbors) = graph.neighbors(node) {
303                let neighbor_weights: Vec<f64> = neighbors
304                    .iter()
305                    .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
306                    .map(|w| w.into())
307                    .collect();
308
309                if !neighbor_weights.is_empty() {
310                    let total: f64 = neighbor_weights.iter().sum();
311                    let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
312
313                    // Build cumulative probabilities for SIMD sampling
314                    let mut cumulative = vec![0.0; probs.len()];
315                    cumulative[0] = probs[0];
316                    for i in 1..probs.len() {
317                        cumulative[i] = cumulative[i - 1] + probs[i];
318                    }
319
320                    transition_probs.push(cumulative);
321                    alias_tables.push(AliasTable::new(&neighbor_weights));
322                } else {
323                    // Isolated node
324                    transition_probs.push(vec![]);
325                    alias_tables.push(AliasTable::new(&[]));
326                }
327            } else {
328                transition_probs.push(vec![]);
329                alias_tables.push(AliasTable::new(&[]));
330            }
331        }
332
333        Ok(BatchRandomWalker {
334            node_to_idx,
335            idx_to_node: nodes,
336            transition_probs,
337            alias_tables,
338        })
339    }
340
341    /// Generate multiple random walks in parallel using SIMD optimizations
342    pub fn generate_walks<E, Ix>(
343        &self,
344        graph: &Graph<N, E, Ix>,
345        starts: &[N],
346        walk_length: usize,
347        num_walks_per_node: usize,
348    ) -> Result<Vec<Vec<N>>>
349    where
350        E: EdgeWeight,
351        Ix: IndexType + std::marker::Sync,
352        N: Send + Sync + std::fmt::Debug,
353    {
354        let total_walks = starts.len() * num_walks_per_node;
355        let mut all_walks = Vec::with_capacity(total_walks);
356
357        // Generate walks in parallel
358        starts
359            .par_iter()
360            .map(|start| {
361                let mut local_walks = Vec::with_capacity(num_walks_per_node);
362                let mut rng = scirs2_core::random::rng();
363
364                for _ in 0..num_walks_per_node {
365                    if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
366                        local_walks.push(walk);
367                    }
368                }
369                local_walks
370            })
371            .collect::<Vec<_>>()
372            .into_iter()
373            .for_each(|walks| all_walks.extend(walks));
374
375        Ok(all_walks)
376    }
377
378    /// Generate a single optimized random walk
379    fn single_walk<E, Ix>(
380        &self,
381        graph: &Graph<N, E, Ix>,
382        start: &N,
383        walk_length: usize,
384        rng: &mut impl scirs2_core::random::Rng,
385    ) -> Result<Vec<N>>
386    where
387        E: EdgeWeight,
388        Ix: IndexType,
389    {
390        let mut walk = Vec::with_capacity(walk_length + 1);
391        walk.push(start.clone());
392
393        let mut current_idx = *self
394            .node_to_idx
395            .get(start)
396            .ok_or(GraphError::node_not_found("node"))?;
397
398        for _ in 0..walk_length {
399            if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
400                let neighbors: Vec<_> = neighbors;
401
402                if !neighbors.is_empty() {
403                    // Use alias table for O(1) sampling
404                    let neighbor_idx = self.alias_tables[current_idx].sample(rng);
405                    if neighbor_idx < neighbors.len() {
406                        let next_node = neighbors[neighbor_idx].clone();
407                        walk.push(next_node.clone());
408
409                        if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
410                            current_idx = next_idx;
411                        }
412                    } else {
413                        break;
414                    }
415                } else {
416                    break;
417                }
418            } else {
419                break;
420            }
421        }
422
423        Ok(walk)
424    }
425}
426
427/// Node2Vec biased random walk with SIMD optimizations
428/// Implements the p and q parameters for controlling exploration vs exploitation
429#[allow(dead_code)]
430pub fn node2vec_walk<N, E, Ix>(
431    graph: &Graph<N, E, Ix>,
432    start: &N,
433    walk_length: usize,
434    p: f64, // Return parameter
435    q: f64, // In-out parameter
436    rng: &mut impl scirs2_core::random::Rng,
437) -> Result<Vec<N>>
438where
439    N: Node + Clone + Hash + Eq + std::fmt::Debug,
440    E: EdgeWeight + Into<f64>,
441    Ix: IndexType,
442{
443    let mut walk = vec![start.clone()];
444    if walk_length == 0 {
445        return Ok(walk);
446    }
447
448    // First step: uniform random
449    if let Ok(neighbors) = graph.neighbors(start) {
450        let neighbors: Vec<_> = neighbors;
451        if neighbors.is_empty() {
452            return Ok(walk);
453        }
454
455        let idx = rng.gen_range(0..neighbors.len());
456        walk.push(neighbors[idx].clone());
457    } else {
458        return Ok(walk);
459    }
460
461    // Subsequent steps: use Node2Vec bias
462    for step in 1..walk_length {
463        let current = &walk[step];
464        let previous = &walk[step - 1];
465
466        if let Ok(neighbors) = graph.neighbors(current) {
467            let neighbors: Vec<_> = neighbors;
468            if neighbors.is_empty() {
469                break;
470            }
471
472            // Calculate biased probabilities
473            let mut weights = Vec::with_capacity(neighbors.len());
474
475            for neighbor in &neighbors {
476                let weight = if neighbor == previous {
477                    // Return to previous node
478                    1.0 / p
479                } else if graph.has_edge(previous, neighbor) {
480                    // Move to a node connected to previous (stay local)
481                    1.0
482                } else {
483                    // Move to a distant node (explore)
484                    1.0 / q
485                };
486
487                // Multiply by edge weight if available
488                let edge_weight = graph
489                    .edge_weight(current, neighbor)
490                    .map(|w| w.into())
491                    .unwrap_or(1.0);
492
493                weights.push(weight * edge_weight);
494            }
495
496            // Sample based on weights using SIMD optimized cumulative sampling
497            let total: f64 = weights.iter().sum();
498            if total > 0.0 {
499                let mut cumulative = vec![0.0; weights.len()];
500                cumulative[0] = weights[0] / total;
501
502                // Compute cumulative sum for selection
503                for i in 1..weights.len() {
504                    cumulative[i] = cumulative[i - 1] + weights[i] / total;
505                }
506
507                let r = rng.random::<f64>();
508                for (i, &cum_prob) in cumulative.iter().enumerate() {
509                    if r <= cum_prob {
510                        walk.push(neighbors[i].clone());
511                        break;
512                    }
513                }
514            }
515        } else {
516            break;
517        }
518    }
519
520    Ok(walk)
521}
522
523/// Parallel Node2Vec walk generation for large-scale embedding
524#[allow(dead_code)]
525pub fn parallel_node2vec_walks<N, E, Ix>(
526    graph: &Graph<N, E, Ix>,
527    starts: &[N],
528    walk_length: usize,
529    num_walks: usize,
530    p: f64,
531    q: f64,
532) -> Result<Vec<Vec<N>>>
533where
534    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
535    E: EdgeWeight + Into<f64> + Send + Sync,
536    Ix: IndexType + Send + Sync,
537{
538    let total_walks = starts.len() * num_walks;
539
540    (0..total_walks)
541        .into_par_iter()
542        .map(|i| {
543            let start_idx = i % starts.len();
544            let start = &starts[start_idx];
545            let mut rng = scirs2_core::random::rng();
546            node2vec_walk(graph, start, walk_length, p, q, &mut rng)
547        })
548        .collect()
549}
550
551/// SIMD-optimized random walk with restart for large graphs
552/// Uses vectorized operations for better performance on large node sets
553#[allow(dead_code)]
554pub fn simd_random_walk_with_restart<N, E, Ix>(
555    graph: &Graph<N, E, Ix>,
556    start: &N,
557    walk_length: usize,
558    restart_prob: f64,
559    rng: &mut impl scirs2_core::random::Rng,
560) -> Result<Vec<N>>
561where
562    N: Node + Clone + Hash + Eq + std::fmt::Debug,
563    E: EdgeWeight,
564    Ix: IndexType,
565{
566    let mut walk = Vec::with_capacity(walk_length + 1);
567    walk.push(start.clone());
568
569    let mut current = start.clone();
570
571    for _ in 0..walk_length {
572        // Vectorized restart decision when processing multiple walks
573        if rng.random::<f64>() < restart_prob {
574            current = start.clone();
575            walk.push(current.clone());
576            continue;
577        }
578
579        if let Ok(neighbors) = graph.neighbors(&current) {
580            let neighbors: Vec<_> = neighbors;
581            if !neighbors.is_empty() {
582                let idx = rng.gen_range(0..neighbors.len());
583                current = neighbors[idx].clone();
584                walk.push(current.clone());
585            } else {
586                current = start.clone();
587                walk.push(current.clone());
588            }
589        } else {
590            break;
591        }
592    }
593
594    Ok(walk)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_random_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();

        // Create a simple path graph
        graph.add_edge("A", "B", ())?;
        graph.add_edge("B", "C", ())?;
        graph.add_edge("C", "D", ())?;

        // Perform random walk
        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        // Walk should start at A
        assert_eq!(walk[0], "A");

        // Walk should have 11 nodes (start + 10 steps)
        assert_eq!(walk.len(), 11);

        // All nodes in walk should be valid
        for node in &walk {
            assert!(graph.contains_node(node));
        }

        Ok(())
    }

    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // Create a triangle with equal weights
        graph.add_edge("A", "B", 1.0)?;
        graph.add_edge("B", "C", 1.0)?;
        graph.add_edge("C", "A", 1.0)?;

        let (nodes, matrix) = transition_matrix(&graph)?;

        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);

        // Each row should sum to 1.0 (stochastic matrix)
        for i in 0..3 {
            let row_sum: f64 = (0..3).map(|j| matrix[[i, j]]).sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // Create a star graph with A at center
        graph.add_edge("A", "B", 1.0)?;
        graph.add_edge("A", "C", 1.0)?;
        graph.add_edge("A", "D", 1.0)?;

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;

        // All nodes should have PageRank values
        assert_eq!(pagerank.len(), 4);

        // PageRank values should sum to approximately 1.0
        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        // Source node (A) should have highest PageRank
        let a_rank = pagerank[&"A"];
        for (node, &rank) in &pagerank {
            if node != &"A" {
                assert!(a_rank >= rank);
            }
        }

        Ok(())
    }

    #[test]
    fn test_node2vec_walk() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // A cycle: every node has a neighbor, so the walk never ends early.
        graph.add_edge("A", "B", 1.0)?;
        graph.add_edge("B", "C", 1.0)?;
        graph.add_edge("C", "A", 1.0)?;

        let mut rng = scirs2_core::random::rng();
        let walk = node2vec_walk(&graph, &"A", 5, 1.0, 1.0, &mut rng)?;

        assert_eq!(walk[0], "A");
        assert_eq!(walk.len(), 6);

        // Every move follows an edge of the graph.
        for pair in walk.windows(2) {
            assert!(graph.has_edge(&pair[0], &pair[1]));
        }

        Ok(())
    }

    #[test]
    fn test_batch_random_walker() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        // Weighted cycle: every node has a neighbor.
        graph.add_edge("A", "B", 1.0)?;
        graph.add_edge("B", "C", 2.0)?;
        graph.add_edge("C", "A", 1.0)?;

        let walker = BatchRandomWalker::new(&graph)?;
        let walks = walker.generate_walks(&graph, &["A", "B"], 4, 3)?;

        // Two start nodes times three walks each.
        assert_eq!(walks.len(), 6);

        for walk in &walks {
            assert_eq!(walk.len(), 5);
            for pair in walk.windows(2) {
                assert!(graph.has_edge(&pair[0], &pair[1]));
            }
        }

        Ok(())
    }

    #[test]
    fn test_simd_random_walk_with_restart() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();

        graph.add_edge("A", "B", 1.0)?;
        graph.add_edge("B", "C", 1.0)?;
        graph.add_edge("C", "A", 1.0)?;

        let mut rng = scirs2_core::random::rng();
        let walk = simd_random_walk_with_restart(&graph, &"A", 8, 0.2, &mut rng)?;

        assert_eq!(walk[0], "A");
        assert_eq!(walk.len(), 9);

        // Each move either restarts at A or follows an edge.
        for pair in walk.windows(2) {
            assert!(pair[1] == "A" || graph.has_edge(&pair[0], &pair[1]));
        }

        Ok(())
    }
}