scirs2_graph/algorithms/
random_walk.rs

1//! Random walk algorithms
2//!
3//! This module contains algorithms related to random walks on graphs.
4
5use crate::base::{EdgeWeight, Graph, IndexType, Node};
6use crate::error::{GraphError, Result};
7use ndarray::{Array1, Array2};
8use scirs2_core::parallel_ops::*;
9use std::collections::HashMap;
10use std::hash::Hash;
11
12/// Perform a random walk on the graph
13///
14/// Returns a sequence of nodes visited during the walk.
15#[allow(dead_code)]
16pub fn random_walk<N, E, Ix>(
17    graph: &Graph<N, E, Ix>,
18    start: &N,
19    steps: usize,
20    restart_probability: f64,
21) -> Result<Vec<N>>
22where
23    N: Node + Clone + Hash + Eq + std::fmt::Debug,
24    E: EdgeWeight,
25    Ix: IndexType,
26{
27    if !graph.contains_node(start) {
28        return Err(GraphError::node_not_found("node"));
29    }
30
31    let mut walk = vec![start.clone()];
32    let mut current = start.clone();
33    let mut rng = rand::rng();
34
35    use rand::Rng;
36
37    for _ in 0..steps {
38        // With restart_probability, jump back to start
39        if rng.random::<f64>() < restart_probability {
40            current = start.clone();
41            walk.push(current.clone());
42            continue;
43        }
44
45        // Otherwise, move to a random neighbor
46        if let Ok(neighbors) = graph.neighbors(&current) {
47            let neighbor_vec: Vec<N> = neighbors;
48
49            if !neighbor_vec.is_empty() {
50                let idx = rng.gen_range(0..neighbor_vec.len());
51                current = neighbor_vec[idx].clone();
52                walk.push(current.clone());
53            } else {
54                // No neighbors..restart
55                current = start.clone();
56                walk.push(current.clone());
57            }
58        }
59    }
60
61    Ok(walk)
62}
63
64/// Compute the transition matrix for random walks on the graph
65///
66/// Returns a row-stochastic matrix where entry (i,j) is the probability
67/// of transitioning from node i to node j.
68#[allow(dead_code)]
69pub fn transition_matrix<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<(Vec<N>, Array2<f64>)>
70where
71    N: Node + Clone + std::fmt::Debug,
72    E: EdgeWeight + Into<f64>,
73    Ix: IndexType,
74{
75    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
76    let n = nodes.len();
77
78    if n == 0 {
79        return Err(GraphError::InvalidGraph("Empty graph".to_string()));
80    }
81
82    let mut matrix = Array2::<f64>::zeros((n, n));
83
84    for (i, node) in nodes.iter().enumerate() {
85        if let Ok(neighbors) = graph.neighbors(node) {
86            let neighbor_weights: Vec<(usize, f64)> = neighbors
87                .into_iter()
88                .filter_map(|neighbor| {
89                    nodes.iter().position(|n| n == &neighbor).and_then(|j| {
90                        graph
91                            .edge_weight(node, &neighbor)
92                            .ok()
93                            .map(|w| (j, w.into()))
94                    })
95                })
96                .collect();
97
98            let total_weight: f64 = neighbor_weights.iter().map(|(_, w)| w).sum();
99
100            if total_weight > 0.0 {
101                for (j, weight) in neighbor_weights {
102                    matrix[[i, j]] = weight / total_weight;
103                }
104            } else {
105                // Dangling node: uniform distribution
106                for j in 0..n {
107                    matrix[[i, j]] = 1.0 / n as f64;
108                }
109            }
110        }
111    }
112
113    Ok((nodes, matrix))
114}
115
116/// Compute personalized PageRank from a given source node
117///
118/// This is useful for measuring node similarity and influence.
119#[allow(dead_code)]
120pub fn personalized_pagerank<N, E, Ix>(
121    graph: &Graph<N, E, Ix>,
122    source: &N,
123    damping: f64,
124    tolerance: f64,
125    max_iter: usize,
126) -> Result<HashMap<N, f64>>
127where
128    N: Node + Clone + Hash + Eq + std::fmt::Debug,
129    E: EdgeWeight + Into<f64>,
130    Ix: IndexType,
131{
132    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
133    let n = nodes.len();
134
135    if n == 0 || !graph.contains_node(source) {
136        return Err(GraphError::node_not_found("node"));
137    }
138
139    // Find source index
140    let source_idx = nodes.iter().position(|n| n == source).unwrap();
141
142    // Get transition matrix
143    let (_, trans_matrix) = transition_matrix(graph)?;
144
145    // Initialize PageRank vector
146    let mut pr = Array1::<f64>::zeros(n);
147    pr[source_idx] = 1.0;
148
149    // Personalization vector (all mass on source)
150    let mut personalization = Array1::<f64>::zeros(n);
151    personalization[source_idx] = 1.0;
152
153    // Power iteration
154    for _ in 0..max_iter {
155        let new_pr = damping * trans_matrix.t().dot(&pr) + (1.0 - damping) * &personalization;
156
157        // Check convergence
158        let diff: f64 = (&new_pr - &pr).iter().map(|x| x.abs()).sum();
159        if diff < tolerance {
160            break;
161        }
162
163        pr = new_pr;
164    }
165
166    // Convert to HashMap
167    Ok(nodes
168        .into_iter()
169        .enumerate()
170        .map(|(i, node)| (node, pr[i]))
171        .collect())
172}
173
174/// Parallel random walk generator for multiple walks simultaneously
175/// Optimized for embedding algorithms like Node2Vec and DeepWalk
176#[allow(dead_code)]
177pub fn parallel_random_walks<N, E, Ix>(
178    graph: &Graph<N, E, Ix>,
179    starts: &[N],
180    walk_length: usize,
181    restart_probability: f64,
182) -> Result<Vec<Vec<N>>>
183where
184    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
185    E: EdgeWeight + Send + Sync,
186    Ix: IndexType + Send + Sync,
187{
188    starts
189        .par_iter()
190        .map(|start| random_walk(graph, start, walk_length, restart_probability))
191        .collect::<Result<Vec<_>>>()
192}
193
194/// SIMD-optimized batch random walk with precomputed transition probabilities
195/// More efficient for large-scale embedding generation
196pub struct BatchRandomWalker<N: Node + std::fmt::Debug> {
197    /// Node index mapping
198    node_to_idx: HashMap<N, usize>,
199    /// Index to node mapping
200    idx_to_node: Vec<N>,
201    /// Cumulative transition probabilities for fast sampling
202    #[allow(dead_code)]
203    transition_probs: Vec<Vec<f64>>,
204    /// Alias tables for O(1) neighbor sampling
205    alias_tables: Vec<AliasTable>,
206}
207
208/// Alias table for efficient weighted random sampling
209#[derive(Debug, Clone)]
210struct AliasTable {
211    /// Probability table
212    prob: Vec<f64>,
213    /// Alias table
214    alias: Vec<usize>,
215}
216
217impl AliasTable {
218    /// Construct alias table for weighted sampling
219    fn new(weights: &[f64]) -> Self {
220        let n = weights.len();
221        let mut prob = vec![0.0; n];
222        let mut alias = vec![0; n];
223
224        if n == 0 {
225            return AliasTable { prob, alias };
226        }
227
228        let sum: f64 = weights.iter().sum();
229        if sum == 0.0 {
230            return AliasTable { prob, alias };
231        }
232
233        // Normalize _weights
234        let normalized: Vec<f64> = weights.iter().map(|w| w * n as f64 / sum).collect();
235
236        let mut small = Vec::new();
237        let mut large = Vec::new();
238
239        for (i, &p) in normalized.iter().enumerate() {
240            if p < 1.0 {
241                small.push(i);
242            } else {
243                large.push(i);
244            }
245        }
246
247        prob[..n].copy_from_slice(&normalized[..n]);
248
249        while let (Some(small_idx), Some(large_idx)) = (small.pop(), large.pop()) {
250            alias[small_idx] = large_idx;
251            prob[large_idx] = prob[large_idx] + prob[small_idx] - 1.0;
252
253            if prob[large_idx] < 1.0 {
254                small.push(large_idx);
255            } else {
256                large.push(large_idx);
257            }
258        }
259
260        AliasTable { prob, alias }
261    }
262
263    /// Sample from the alias table
264    fn sample(&self, rng: &mut impl rand::Rng) -> usize {
265        if self.prob.is_empty() {
266            return 0;
267        }
268
269        let i = rng.gen_range(0..self.prob.len());
270        let coin_flip = rng.random::<f64>();
271
272        if coin_flip <= self.prob[i] {
273            i
274        } else {
275            self.alias[i]
276        }
277    }
278}
279
280impl<N: Node + Clone + Hash + Eq + std::fmt::Debug> BatchRandomWalker<N> {
281    /// Create a new batch random walker
282    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Result<Self>
283    where
284        E: EdgeWeight + Into<f64>,
285        Ix: IndexType,
286        N: std::fmt::Debug,
287    {
288        let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
289        let node_to_idx: HashMap<N, usize> = nodes
290            .iter()
291            .enumerate()
292            .map(|(i, n)| (n.clone(), i))
293            .collect();
294
295        let mut transition_probs = Vec::new();
296        let mut alias_tables = Vec::new();
297
298        for node in &nodes {
299            if let Ok(neighbors) = graph.neighbors(node) {
300                let neighbor_weights: Vec<f64> = neighbors
301                    .iter()
302                    .filter_map(|neighbor| graph.edge_weight(node, neighbor).ok())
303                    .map(|w| w.into())
304                    .collect();
305
306                if !neighbor_weights.is_empty() {
307                    let total: f64 = neighbor_weights.iter().sum();
308                    let probs: Vec<f64> = neighbor_weights.iter().map(|w| w / total).collect();
309
310                    // Build cumulative probabilities for SIMD sampling
311                    let mut cumulative = vec![0.0; probs.len()];
312                    cumulative[0] = probs[0];
313                    for i in 1..probs.len() {
314                        cumulative[i] = cumulative[i - 1] + probs[i];
315                    }
316
317                    transition_probs.push(cumulative);
318                    alias_tables.push(AliasTable::new(&neighbor_weights));
319                } else {
320                    // Isolated node
321                    transition_probs.push(vec![]);
322                    alias_tables.push(AliasTable::new(&[]));
323                }
324            } else {
325                transition_probs.push(vec![]);
326                alias_tables.push(AliasTable::new(&[]));
327            }
328        }
329
330        Ok(BatchRandomWalker {
331            node_to_idx,
332            idx_to_node: nodes,
333            transition_probs,
334            alias_tables,
335        })
336    }
337
338    /// Generate multiple random walks in parallel using SIMD optimizations
339    pub fn generate_walks<E, Ix>(
340        &self,
341        graph: &Graph<N, E, Ix>,
342        starts: &[N],
343        walk_length: usize,
344        num_walks_per_node: usize,
345    ) -> Result<Vec<Vec<N>>>
346    where
347        E: EdgeWeight,
348        Ix: IndexType + std::marker::Sync,
349        N: Send + Sync + std::fmt::Debug,
350    {
351        let total_walks = starts.len() * num_walks_per_node;
352        let mut all_walks = Vec::with_capacity(total_walks);
353
354        // Generate walks in parallel
355        starts
356            .par_iter()
357            .map(|start| {
358                let mut local_walks = Vec::with_capacity(num_walks_per_node);
359                let mut rng = rand::rng();
360
361                for _ in 0..num_walks_per_node {
362                    if let Ok(walk) = self.single_walk(graph, start, walk_length, &mut rng) {
363                        local_walks.push(walk);
364                    }
365                }
366                local_walks
367            })
368            .collect::<Vec<_>>()
369            .into_iter()
370            .for_each(|walks| all_walks.extend(walks));
371
372        Ok(all_walks)
373    }
374
375    /// Generate a single optimized random walk
376    fn single_walk<E, Ix>(
377        &self,
378        graph: &Graph<N, E, Ix>,
379        start: &N,
380        walk_length: usize,
381        rng: &mut impl rand::Rng,
382    ) -> Result<Vec<N>>
383    where
384        E: EdgeWeight,
385        Ix: IndexType,
386    {
387        let mut walk = Vec::with_capacity(walk_length + 1);
388        walk.push(start.clone());
389
390        let mut current_idx = *self
391            .node_to_idx
392            .get(start)
393            .ok_or(GraphError::node_not_found("node"))?;
394
395        for _ in 0..walk_length {
396            if let Ok(neighbors) = graph.neighbors(&self.idx_to_node[current_idx]) {
397                let neighbors: Vec<_> = neighbors;
398
399                if !neighbors.is_empty() {
400                    // Use alias table for O(1) sampling
401                    let neighbor_idx = self.alias_tables[current_idx].sample(rng);
402                    if neighbor_idx < neighbors.len() {
403                        let next_node = neighbors[neighbor_idx].clone();
404                        walk.push(next_node.clone());
405
406                        if let Some(&next_idx) = self.node_to_idx.get(&next_node) {
407                            current_idx = next_idx;
408                        }
409                    } else {
410                        break;
411                    }
412                } else {
413                    break;
414                }
415            } else {
416                break;
417            }
418        }
419
420        Ok(walk)
421    }
422}
423
424/// Node2Vec biased random walk with SIMD optimizations
425/// Implements the p and q parameters for controlling exploration vs exploitation
426#[allow(dead_code)]
427pub fn node2vec_walk<N, E, Ix>(
428    graph: &Graph<N, E, Ix>,
429    start: &N,
430    walk_length: usize,
431    p: f64, // Return parameter
432    q: f64, // In-out parameter
433    rng: &mut impl rand::Rng,
434) -> Result<Vec<N>>
435where
436    N: Node + Clone + Hash + Eq + std::fmt::Debug,
437    E: EdgeWeight + Into<f64>,
438    Ix: IndexType,
439{
440    let mut walk = vec![start.clone()];
441    if walk_length == 0 {
442        return Ok(walk);
443    }
444
445    // First step: uniform random
446    if let Ok(neighbors) = graph.neighbors(start) {
447        let neighbors: Vec<_> = neighbors;
448        if neighbors.is_empty() {
449            return Ok(walk);
450        }
451
452        let idx = rng.gen_range(0..neighbors.len());
453        walk.push(neighbors[idx].clone());
454    } else {
455        return Ok(walk);
456    }
457
458    // Subsequent steps: use Node2Vec bias
459    for step in 1..walk_length {
460        let current = &walk[step];
461        let previous = &walk[step - 1];
462
463        if let Ok(neighbors) = graph.neighbors(current) {
464            let neighbors: Vec<_> = neighbors;
465            if neighbors.is_empty() {
466                break;
467            }
468
469            // Calculate biased probabilities
470            let mut weights = Vec::with_capacity(neighbors.len());
471
472            for neighbor in &neighbors {
473                let weight = if neighbor == previous {
474                    // Return to previous node
475                    1.0 / p
476                } else if graph.has_edge(previous, neighbor) {
477                    // Move to a node connected to previous (stay local)
478                    1.0
479                } else {
480                    // Move to a distant node (explore)
481                    1.0 / q
482                };
483
484                // Multiply by edge weight if available
485                let edge_weight = graph
486                    .edge_weight(current, neighbor)
487                    .map(|w| w.into())
488                    .unwrap_or(1.0);
489
490                weights.push(weight * edge_weight);
491            }
492
493            // Sample based on weights using SIMD optimized cumulative sampling
494            let total: f64 = weights.iter().sum();
495            if total > 0.0 {
496                let mut cumulative = vec![0.0; weights.len()];
497                cumulative[0] = weights[0] / total;
498
499                // Compute cumulative sum for selection
500                for i in 1..weights.len() {
501                    cumulative[i] = cumulative[i - 1] + weights[i] / total;
502                }
503
504                let r = rng.random::<f64>();
505                for (i, &cum_prob) in cumulative.iter().enumerate() {
506                    if r <= cum_prob {
507                        walk.push(neighbors[i].clone());
508                        break;
509                    }
510                }
511            }
512        } else {
513            break;
514        }
515    }
516
517    Ok(walk)
518}
519
520/// Parallel Node2Vec walk generation for large-scale embedding
521#[allow(dead_code)]
522pub fn parallel_node2vec_walks<N, E, Ix>(
523    graph: &Graph<N, E, Ix>,
524    starts: &[N],
525    walk_length: usize,
526    num_walks: usize,
527    p: f64,
528    q: f64,
529) -> Result<Vec<Vec<N>>>
530where
531    N: Node + Clone + Hash + Eq + Send + Sync + std::fmt::Debug,
532    E: EdgeWeight + Into<f64> + Send + Sync,
533    Ix: IndexType + Send + Sync,
534{
535    let total_walks = starts.len() * num_walks;
536
537    (0..total_walks)
538        .into_par_iter()
539        .map(|i| {
540            let start_idx = i % starts.len();
541            let start = &starts[start_idx];
542            let mut rng = rand::rng();
543            node2vec_walk(graph, start, walk_length, p, q, &mut rng)
544        })
545        .collect()
546}
547
548/// SIMD-optimized random walk with restart for large graphs
549/// Uses vectorized operations for better performance on large node sets
550#[allow(dead_code)]
551pub fn simd_random_walk_with_restart<N, E, Ix>(
552    graph: &Graph<N, E, Ix>,
553    start: &N,
554    walk_length: usize,
555    restart_prob: f64,
556    rng: &mut impl rand::Rng,
557) -> Result<Vec<N>>
558where
559    N: Node + Clone + Hash + Eq + std::fmt::Debug,
560    E: EdgeWeight,
561    Ix: IndexType,
562{
563    let mut walk = Vec::with_capacity(walk_length + 1);
564    walk.push(start.clone());
565
566    let mut current = start.clone();
567
568    for _ in 0..walk_length {
569        // Vectorized restart decision when processing multiple walks
570        if rng.random::<f64>() < restart_prob {
571            current = start.clone();
572            walk.push(current.clone());
573            continue;
574        }
575
576        if let Ok(neighbors) = graph.neighbors(&current) {
577            let neighbors: Vec<_> = neighbors;
578            if !neighbors.is_empty() {
579                let idx = rng.gen_range(0..neighbors.len());
580                current = neighbors[idx].clone();
581                walk.push(current.clone());
582            } else {
583                current = start.clone();
584                walk.push(current.clone());
585            }
586        } else {
587            break;
588        }
589    }
590
591    Ok(walk)
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    /// A walk of 10 steps on a path graph starts at the start node, records
    /// 11 nodes and only ever visits nodes of the graph.
    #[test]
    fn test_random_walk() -> GraphResult<()> {
        // Path graph A - B - C - D
        let mut graph = create_graph::<&str, ()>();
        for (from, to) in [("A", "B"), ("B", "C"), ("C", "D")] {
            graph.add_edge(from, to, ())?;
        }

        let walk = random_walk(&graph, &"A", 10, 0.1)?;

        // Start node first, then one entry per step
        assert_eq!(walk[0], "A");
        assert_eq!(walk.len(), 11);

        // Every visited node belongs to the graph
        assert!(walk.iter().all(|node| graph.contains_node(node)));

        Ok(())
    }

    /// The transition matrix of a unit-weight triangle is 3x3 with rows that
    /// each sum to one.
    #[test]
    fn test_transition_matrix() -> GraphResult<()> {
        // Triangle with equal weights
        let mut graph = create_graph::<&str, f64>();
        for (from, to) in [("A", "B"), ("B", "C"), ("C", "A")] {
            graph.add_edge(from, to, 1.0)?;
        }

        let (nodes, matrix) = transition_matrix(&graph)?;

        assert_eq!(nodes.len(), 3);
        assert_eq!(matrix.shape(), &[3, 3]);

        // Row-stochastic: every row sums to 1.0
        for i in 0..3 {
            let row_sum: f64 = matrix.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-6);
        }

        Ok(())
    }

    /// Personalized PageRank on a star graph covers all nodes, sums to about
    /// one and ranks the source (the center) at least as high as any leaf.
    #[test]
    fn test_personalized_pagerank() -> GraphResult<()> {
        // Star graph with A at the center
        let mut graph = create_graph::<&str, f64>();
        for leaf in ["B", "C", "D"] {
            graph.add_edge("A", leaf, 1.0)?;
        }

        let pagerank = personalized_pagerank(&graph, &"A", 0.85, 1e-6, 100)?;

        // One score per node
        assert_eq!(pagerank.len(), 4);

        // Scores sum to approximately 1.0
        let total: f64 = pagerank.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        // The source node (A) has the highest score
        let a_rank = pagerank[&"A"];
        for leaf in ["B", "C", "D"] {
            assert!(a_rank >= pagerank[&leaf]);
        }

        Ok(())
    }
}