scirs2_graph/embeddings/random_walk.rs

1//! Random walk generation for graph embeddings
2
3use super::types::RandomWalk;
4use crate::base::{EdgeWeight, Graph, Node};
5use crate::error::{GraphError, Result};
6use scirs2_core::random::rand_prelude::IndexedRandom;
7use scirs2_core::random::Rng;
8
9/// Random walk generator for graphs
10pub struct RandomWalkGenerator<N: Node> {
11    /// Random number generator
12    rng: scirs2_core::random::rngs::ThreadRng,
13    /// Phantom marker for node type
14    _phantom: std::marker::PhantomData<N>,
15}
16
17impl<N: Node> Default for RandomWalkGenerator<N> {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl<N: Node> RandomWalkGenerator<N> {
24    /// Create a new random walk generator
25    pub fn new() -> Self {
26        RandomWalkGenerator {
27            rng: scirs2_core::random::rng(),
28            _phantom: std::marker::PhantomData,
29        }
30    }
31
32    /// Generate a simple random walk from a starting node
33    pub fn simple_random_walk<E, Ix>(
34        &mut self,
35        graph: &Graph<N, E, Ix>,
36        start: &N,
37        length: usize,
38    ) -> Result<RandomWalk<N>>
39    where
40        N: Clone + std::fmt::Debug,
41        E: EdgeWeight,
42        Ix: petgraph::graph::IndexType,
43    {
44        if !graph.contains_node(start) {
45            return Err(GraphError::node_not_found("node"));
46        }
47
48        let mut walk = vec![start.clone()];
49        let mut current = start.clone();
50
51        for _ in 1..length {
52            let neighbors = graph.neighbors(&current)?;
53            if neighbors.is_empty() {
54                break; // No outgoing edges, stop walk
55            }
56
57            current = neighbors
58                .choose(&mut self.rng)
59                .ok_or(GraphError::AlgorithmError(
60                    "Failed to choose neighbor".to_string(),
61                ))?
62                .clone();
63            walk.push(current.clone());
64        }
65
66        Ok(RandomWalk { nodes: walk })
67    }
68
69    /// Generate a Node2Vec biased random walk
70    pub fn node2vec_walk<E, Ix>(
71        &mut self,
72        graph: &Graph<N, E, Ix>,
73        start: &N,
74        length: usize,
75        p: f64,
76        q: f64,
77    ) -> Result<RandomWalk<N>>
78    where
79        N: Clone + std::fmt::Debug,
80        E: EdgeWeight + Into<f64>,
81        Ix: petgraph::graph::IndexType,
82    {
83        if !graph.contains_node(start) {
84            return Err(GraphError::node_not_found("node"));
85        }
86
87        let mut walk = vec![start.clone()];
88        if length == 1 {
89            return Ok(RandomWalk { nodes: walk });
90        }
91
92        // First step is unbiased
93        let first_neighbors = graph.neighbors(start)?;
94        if first_neighbors.is_empty() {
95            return Ok(RandomWalk { nodes: walk });
96        }
97
98        let current = first_neighbors
99            .choose(&mut self.rng)
100            .ok_or(GraphError::AlgorithmError(
101                "Failed to choose first neighbor".to_string(),
102            ))?
103            .clone();
104        walk.push(current.clone());
105
106        // Subsequent steps use biased sampling
107        for _ in 2..length {
108            let current_neighbors = graph.neighbors(&current)?;
109            if current_neighbors.is_empty() {
110                break;
111            }
112
113            let prev = &walk[walk.len() - 2];
114            let mut weights = Vec::new();
115
116            for neighbor in &current_neighbors {
117                let weight = if neighbor == prev {
118                    // Return to previous node
119                    1.0 / p
120                } else if graph.has_edge(prev, neighbor) {
121                    // Neighbor is also connected to previous node
122                    1.0
123                } else {
124                    // New exploration
125                    1.0 / q
126                };
127                weights.push(weight);
128            }
129
130            // Weighted random selection
131            let total_weight: f64 = weights.iter().sum();
132            let mut random_value = self.rng.random::<f64>() * total_weight;
133            let mut selected_index = 0;
134
135            for (i, &weight) in weights.iter().enumerate() {
136                random_value -= weight;
137                if random_value <= 0.0 {
138                    selected_index = i;
139                    break;
140                }
141            }
142
143            let next_node = current_neighbors[selected_index].clone();
144            walk.push(next_node.clone());
145            // Update current for next iteration - this line was originally incorrect
146            // let _current = next_node;
147        }
148
149        Ok(RandomWalk { nodes: walk })
150    }
151
152    /// Generate multiple random walks from a starting node
153    pub fn generate_walks<E, Ix>(
154        &mut self,
155        graph: &Graph<N, E, Ix>,
156        start: &N,
157        num_walks: usize,
158        walk_length: usize,
159    ) -> Result<Vec<RandomWalk<N>>>
160    where
161        N: Clone + std::fmt::Debug,
162        E: EdgeWeight,
163        Ix: petgraph::graph::IndexType,
164    {
165        let mut walks = Vec::new();
166        for _ in 0..num_walks {
167            let walk = self.simple_random_walk(graph, start, walk_length)?;
168            walks.push(walk);
169        }
170        Ok(walks)
171    }
172}