scirs2_graph/embeddings/
random_walk.rs

1//! Random walk generation for graph embeddings
2
3use super::types::RandomWalk;
4use crate::base::{EdgeWeight, Graph, Node};
5use crate::error::{GraphError, Result};
6use rand::prelude::IndexedRandom;
7use rand::seq::SliceRandom;
8use rand::Rng;
9
10/// Random walk generator for graphs
11pub struct RandomWalkGenerator<N: Node> {
12    /// Random number generator
13    rng: rand::rngs::ThreadRng,
14    /// Phantom marker for node type
15    _phantom: std::marker::PhantomData<N>,
16}
17
18impl<N: Node> Default for RandomWalkGenerator<N> {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl<N: Node> RandomWalkGenerator<N> {
25    /// Create a new random walk generator
26    pub fn new() -> Self {
27        RandomWalkGenerator {
28            rng: rand::rng(),
29            _phantom: std::marker::PhantomData,
30        }
31    }
32
33    /// Generate a simple random walk from a starting node
34    pub fn simple_random_walk<E, Ix>(
35        &mut self,
36        graph: &Graph<N, E, Ix>,
37        start: &N,
38        length: usize,
39    ) -> Result<RandomWalk<N>>
40    where
41        N: Clone + std::fmt::Debug,
42        E: EdgeWeight,
43        Ix: petgraph::graph::IndexType,
44    {
45        if !graph.contains_node(start) {
46            return Err(GraphError::node_not_found("node"));
47        }
48
49        let mut walk = vec![start.clone()];
50        let mut current = start.clone();
51
52        for _ in 1..length {
53            let neighbors = graph.neighbors(&current)?;
54            if neighbors.is_empty() {
55                break; // No outgoing edges, stop walk
56            }
57
58            current = neighbors
59                .choose(&mut self.rng)
60                .ok_or(GraphError::AlgorithmError(
61                    "Failed to choose neighbor".to_string(),
62                ))?
63                .clone();
64            walk.push(current.clone());
65        }
66
67        Ok(RandomWalk { nodes: walk })
68    }
69
70    /// Generate a Node2Vec biased random walk
71    pub fn node2vec_walk<E, Ix>(
72        &mut self,
73        graph: &Graph<N, E, Ix>,
74        start: &N,
75        length: usize,
76        p: f64,
77        q: f64,
78    ) -> Result<RandomWalk<N>>
79    where
80        N: Clone + std::fmt::Debug,
81        E: EdgeWeight + Into<f64>,
82        Ix: petgraph::graph::IndexType,
83    {
84        if !graph.contains_node(start) {
85            return Err(GraphError::node_not_found("node"));
86        }
87
88        let mut walk = vec![start.clone()];
89        if length == 1 {
90            return Ok(RandomWalk { nodes: walk });
91        }
92
93        // First step is unbiased
94        let first_neighbors = graph.neighbors(start)?;
95        if first_neighbors.is_empty() {
96            return Ok(RandomWalk { nodes: walk });
97        }
98
99        let current = first_neighbors
100            .choose(&mut self.rng)
101            .ok_or(GraphError::AlgorithmError(
102                "Failed to choose first neighbor".to_string(),
103            ))?
104            .clone();
105        walk.push(current.clone());
106
107        // Subsequent steps use biased sampling
108        for _ in 2..length {
109            let current_neighbors = graph.neighbors(&current)?;
110            if current_neighbors.is_empty() {
111                break;
112            }
113
114            let prev = &walk[walk.len() - 2];
115            let mut weights = Vec::new();
116
117            for neighbor in &current_neighbors {
118                let weight = if neighbor == prev {
119                    // Return to previous node
120                    1.0 / p
121                } else if graph.has_edge(prev, neighbor) {
122                    // Neighbor is also connected to previous node
123                    1.0
124                } else {
125                    // New exploration
126                    1.0 / q
127                };
128                weights.push(weight);
129            }
130
131            // Weighted random selection
132            let total_weight: f64 = weights.iter().sum();
133            let mut random_value = self.rng.random::<f64>() * total_weight;
134            let mut selected_index = 0;
135
136            for (i, &weight) in weights.iter().enumerate() {
137                random_value -= weight;
138                if random_value <= 0.0 {
139                    selected_index = i;
140                    break;
141                }
142            }
143
144            let next_node = current_neighbors[selected_index].clone();
145            walk.push(next_node.clone());
146            // Update current for next iteration - this line was originally incorrect
147            // let _current = next_node;
148        }
149
150        Ok(RandomWalk { nodes: walk })
151    }
152
153    /// Generate multiple random walks from a starting node
154    pub fn generate_walks<E, Ix>(
155        &mut self,
156        graph: &Graph<N, E, Ix>,
157        start: &N,
158        num_walks: usize,
159        walk_length: usize,
160    ) -> Result<Vec<RandomWalk<N>>>
161    where
162        N: Clone + std::fmt::Debug,
163        E: EdgeWeight,
164        Ix: petgraph::graph::IndexType,
165    {
166        let mut walks = Vec::new();
167        for _ in 0..num_walks {
168            let walk = self.simple_random_walk(graph, start, walk_length)?;
169            walks.push(walk);
170        }
171        Ok(walks)
172    }
173}