Skip to main content

scirs2_graph/embeddings/
random_walk.rs

1//! Random walk generation for graph embeddings
2//!
3//! Provides random walk generators for both undirected and directed graphs,
4//! including simple uniform walks and Node2Vec biased walks.
5
6use super::types::RandomWalk;
7use crate::base::{DiGraph, EdgeWeight, Graph, Node};
8use crate::error::{GraphError, Result};
9use scirs2_core::random::rand_prelude::IndexedRandom;
10use scirs2_core::random::Rng;
11
12/// Random walk generator for graphs
13pub struct RandomWalkGenerator<N: Node> {
14    /// Random number generator
15    rng: scirs2_core::random::rngs::ThreadRng,
16    /// Phantom marker for node type
17    _phantom: std::marker::PhantomData<N>,
18}
19
20impl<N: Node> Default for RandomWalkGenerator<N> {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl<N: Node> RandomWalkGenerator<N> {
27    /// Create a new random walk generator
28    pub fn new() -> Self {
29        RandomWalkGenerator {
30            rng: scirs2_core::random::rng(),
31            _phantom: std::marker::PhantomData,
32        }
33    }
34
35    /// Generate a simple random walk from a starting node (undirected graph)
36    pub fn simple_random_walk<E, Ix>(
37        &mut self,
38        graph: &Graph<N, E, Ix>,
39        start: &N,
40        length: usize,
41    ) -> Result<RandomWalk<N>>
42    where
43        N: Clone + std::fmt::Debug,
44        E: EdgeWeight,
45        Ix: petgraph::graph::IndexType,
46    {
47        if !graph.contains_node(start) {
48            return Err(GraphError::node_not_found("node"));
49        }
50
51        let mut walk = vec![start.clone()];
52        let mut current = start.clone();
53
54        for _ in 1..length {
55            let neighbors = graph.neighbors(&current)?;
56            if neighbors.is_empty() {
57                break; // No outgoing edges, stop walk
58            }
59
60            current = neighbors
61                .choose(&mut self.rng)
62                .ok_or(GraphError::AlgorithmError(
63                    "Failed to choose neighbor".to_string(),
64                ))?
65                .clone();
66            walk.push(current.clone());
67        }
68
69        Ok(RandomWalk { nodes: walk })
70    }
71
72    /// Generate a simple random walk on a directed graph (follows outgoing edges)
73    pub fn simple_random_walk_digraph<E, Ix>(
74        &mut self,
75        graph: &DiGraph<N, E, Ix>,
76        start: &N,
77        length: usize,
78    ) -> Result<RandomWalk<N>>
79    where
80        N: Clone + std::fmt::Debug,
81        E: EdgeWeight,
82        Ix: petgraph::graph::IndexType,
83    {
84        if !graph.contains_node(start) {
85            return Err(GraphError::node_not_found("node"));
86        }
87
88        let mut walk = vec![start.clone()];
89        let mut current = start.clone();
90
91        for _ in 1..length {
92            let successors = graph.successors(&current)?;
93            if successors.is_empty() {
94                break; // No outgoing edges, stop walk
95            }
96
97            current = successors
98                .choose(&mut self.rng)
99                .ok_or(GraphError::AlgorithmError(
100                    "Failed to choose successor".to_string(),
101                ))?
102                .clone();
103            walk.push(current.clone());
104        }
105
106        Ok(RandomWalk { nodes: walk })
107    }
108
109    /// Generate a Node2Vec biased random walk (undirected graph)
110    ///
111    /// Uses biased second-order random walks controlled by parameters p and q:
112    /// - p: Return parameter. Higher p makes it less likely to return to the previous node.
113    /// - q: In-out parameter. Higher q biases towards nodes close to the previous node (BFS-like).
114    ///   Lower q biases towards unexplored nodes (DFS-like).
115    pub fn node2vec_walk<E, Ix>(
116        &mut self,
117        graph: &Graph<N, E, Ix>,
118        start: &N,
119        length: usize,
120        p: f64,
121        q: f64,
122    ) -> Result<RandomWalk<N>>
123    where
124        N: Clone + std::fmt::Debug,
125        E: EdgeWeight + Into<f64>,
126        Ix: petgraph::graph::IndexType,
127    {
128        if !graph.contains_node(start) {
129            return Err(GraphError::node_not_found("node"));
130        }
131
132        if p <= 0.0 || q <= 0.0 {
133            return Err(GraphError::InvalidParameter {
134                param: "p/q".to_string(),
135                value: format!("p={p}, q={q}"),
136                expected: "p > 0 and q > 0".to_string(),
137                context: "Node2Vec walk parameters".to_string(),
138            });
139        }
140
141        let mut walk = vec![start.clone()];
142        if length <= 1 {
143            return Ok(RandomWalk { nodes: walk });
144        }
145
146        // First step is unbiased
147        let first_neighbors = graph.neighbors(start)?;
148        if first_neighbors.is_empty() {
149            return Ok(RandomWalk { nodes: walk });
150        }
151
152        let mut current = first_neighbors
153            .choose(&mut self.rng)
154            .ok_or(GraphError::AlgorithmError(
155                "Failed to choose first neighbor".to_string(),
156            ))?
157            .clone();
158        walk.push(current.clone());
159
160        // Subsequent steps use biased sampling
161        for _ in 2..length {
162            let current_neighbors = graph.neighbors(&current)?;
163            if current_neighbors.is_empty() {
164                break;
165            }
166
167            let prev = &walk[walk.len() - 2];
168            let mut weights = Vec::new();
169
170            for neighbor in &current_neighbors {
171                let weight = if neighbor == prev {
172                    // Return to previous node
173                    1.0 / p
174                } else if graph.has_edge(prev, neighbor) {
175                    // Neighbor is also connected to previous node (BFS-like)
176                    1.0
177                } else {
178                    // New exploration (DFS-like)
179                    1.0 / q
180                };
181                weights.push(weight);
182            }
183
184            // Weighted random selection
185            let total_weight: f64 = weights.iter().sum();
186            if total_weight <= 0.0 {
187                break;
188            }
189
190            let mut random_value = self.rng.random::<f64>() * total_weight;
191            let mut selected_index = 0;
192
193            for (i, &weight) in weights.iter().enumerate() {
194                random_value -= weight;
195                if random_value <= 0.0 {
196                    selected_index = i;
197                    break;
198                }
199            }
200
201            let next_node = current_neighbors[selected_index].clone();
202            walk.push(next_node.clone());
203            // Update current for next iteration (FIXED: was previously not updating)
204            current = next_node;
205        }
206
207        Ok(RandomWalk { nodes: walk })
208    }
209
210    /// Generate a Node2Vec biased random walk on a directed graph
211    ///
212    /// Follows outgoing edges with the same p,q bias scheme as the undirected version.
213    pub fn node2vec_walk_digraph<E, Ix>(
214        &mut self,
215        graph: &DiGraph<N, E, Ix>,
216        start: &N,
217        length: usize,
218        p: f64,
219        q: f64,
220    ) -> Result<RandomWalk<N>>
221    where
222        N: Clone + std::fmt::Debug,
223        E: EdgeWeight + Into<f64>,
224        Ix: petgraph::graph::IndexType,
225    {
226        if !graph.contains_node(start) {
227            return Err(GraphError::node_not_found("node"));
228        }
229
230        if p <= 0.0 || q <= 0.0 {
231            return Err(GraphError::InvalidParameter {
232                param: "p/q".to_string(),
233                value: format!("p={p}, q={q}"),
234                expected: "p > 0 and q > 0".to_string(),
235                context: "Node2Vec walk parameters".to_string(),
236            });
237        }
238
239        let mut walk = vec![start.clone()];
240        if length <= 1 {
241            return Ok(RandomWalk { nodes: walk });
242        }
243
244        // First step is unbiased
245        let first_successors = graph.successors(start)?;
246        if first_successors.is_empty() {
247            return Ok(RandomWalk { nodes: walk });
248        }
249
250        let mut current = first_successors
251            .choose(&mut self.rng)
252            .ok_or(GraphError::AlgorithmError(
253                "Failed to choose first successor".to_string(),
254            ))?
255            .clone();
256        walk.push(current.clone());
257
258        // Subsequent steps use biased sampling
259        for _ in 2..length {
260            let current_successors = graph.successors(&current)?;
261            if current_successors.is_empty() {
262                break;
263            }
264
265            let prev = &walk[walk.len() - 2];
266            let mut weights = Vec::new();
267
268            for neighbor in &current_successors {
269                let weight = if neighbor == prev {
270                    1.0 / p
271                } else if graph.has_edge(prev, neighbor) {
272                    1.0
273                } else {
274                    1.0 / q
275                };
276                weights.push(weight);
277            }
278
279            let total_weight: f64 = weights.iter().sum();
280            if total_weight <= 0.0 {
281                break;
282            }
283
284            let mut random_value = self.rng.random::<f64>() * total_weight;
285            let mut selected_index = 0;
286
287            for (i, &weight) in weights.iter().enumerate() {
288                random_value -= weight;
289                if random_value <= 0.0 {
290                    selected_index = i;
291                    break;
292                }
293            }
294
295            let next_node = current_successors[selected_index].clone();
296            walk.push(next_node.clone());
297            current = next_node;
298        }
299
300        Ok(RandomWalk { nodes: walk })
301    }
302
303    /// Generate multiple random walks from a starting node
304    pub fn generate_walks<E, Ix>(
305        &mut self,
306        graph: &Graph<N, E, Ix>,
307        start: &N,
308        num_walks: usize,
309        walk_length: usize,
310    ) -> Result<Vec<RandomWalk<N>>>
311    where
312        N: Clone + std::fmt::Debug,
313        E: EdgeWeight,
314        Ix: petgraph::graph::IndexType,
315    {
316        let mut walks = Vec::new();
317        for _ in 0..num_walks {
318            let walk = self.simple_random_walk(graph, start, walk_length)?;
319            walks.push(walk);
320        }
321        Ok(walks)
322    }
323
324    /// Generate multiple random walks from a starting node on a directed graph
325    pub fn generate_walks_digraph<E, Ix>(
326        &mut self,
327        graph: &DiGraph<N, E, Ix>,
328        start: &N,
329        num_walks: usize,
330        walk_length: usize,
331    ) -> Result<Vec<RandomWalk<N>>>
332    where
333        N: Clone + std::fmt::Debug,
334        E: EdgeWeight,
335        Ix: petgraph::graph::IndexType,
336    {
337        let mut walks = Vec::new();
338        for _ in 0..num_walks {
339            let walk = self.simple_random_walk_digraph(graph, start, walk_length)?;
340            walks.push(walk);
341        }
342        Ok(walks)
343    }
344}