scirs2_graph/embeddings/
random_walk.rs1use super::types::RandomWalk;
4use crate::base::{EdgeWeight, Graph, Node};
5use crate::error::{GraphError, Result};
6use scirs2_core::random::rand_prelude::IndexedRandom;
7use scirs2_core::random::Rng;
8
9pub struct RandomWalkGenerator<N: Node> {
11 rng: scirs2_core::random::rngs::ThreadRng,
13 _phantom: std::marker::PhantomData<N>,
15}
16
17impl<N: Node> Default for RandomWalkGenerator<N> {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl<N: Node> RandomWalkGenerator<N> {
24 pub fn new() -> Self {
26 RandomWalkGenerator {
27 rng: scirs2_core::random::rng(),
28 _phantom: std::marker::PhantomData,
29 }
30 }
31
32 pub fn simple_random_walk<E, Ix>(
34 &mut self,
35 graph: &Graph<N, E, Ix>,
36 start: &N,
37 length: usize,
38 ) -> Result<RandomWalk<N>>
39 where
40 N: Clone + std::fmt::Debug,
41 E: EdgeWeight,
42 Ix: petgraph::graph::IndexType,
43 {
44 if !graph.contains_node(start) {
45 return Err(GraphError::node_not_found("node"));
46 }
47
48 let mut walk = vec![start.clone()];
49 let mut current = start.clone();
50
51 for _ in 1..length {
52 let neighbors = graph.neighbors(¤t)?;
53 if neighbors.is_empty() {
54 break; }
56
57 current = neighbors
58 .choose(&mut self.rng)
59 .ok_or(GraphError::AlgorithmError(
60 "Failed to choose neighbor".to_string(),
61 ))?
62 .clone();
63 walk.push(current.clone());
64 }
65
66 Ok(RandomWalk { nodes: walk })
67 }
68
69 pub fn node2vec_walk<E, Ix>(
71 &mut self,
72 graph: &Graph<N, E, Ix>,
73 start: &N,
74 length: usize,
75 p: f64,
76 q: f64,
77 ) -> Result<RandomWalk<N>>
78 where
79 N: Clone + std::fmt::Debug,
80 E: EdgeWeight + Into<f64>,
81 Ix: petgraph::graph::IndexType,
82 {
83 if !graph.contains_node(start) {
84 return Err(GraphError::node_not_found("node"));
85 }
86
87 let mut walk = vec![start.clone()];
88 if length == 1 {
89 return Ok(RandomWalk { nodes: walk });
90 }
91
92 let first_neighbors = graph.neighbors(start)?;
94 if first_neighbors.is_empty() {
95 return Ok(RandomWalk { nodes: walk });
96 }
97
98 let current = first_neighbors
99 .choose(&mut self.rng)
100 .ok_or(GraphError::AlgorithmError(
101 "Failed to choose first neighbor".to_string(),
102 ))?
103 .clone();
104 walk.push(current.clone());
105
106 for _ in 2..length {
108 let current_neighbors = graph.neighbors(¤t)?;
109 if current_neighbors.is_empty() {
110 break;
111 }
112
113 let prev = &walk[walk.len() - 2];
114 let mut weights = Vec::new();
115
116 for neighbor in ¤t_neighbors {
117 let weight = if neighbor == prev {
118 1.0 / p
120 } else if graph.has_edge(prev, neighbor) {
121 1.0
123 } else {
124 1.0 / q
126 };
127 weights.push(weight);
128 }
129
130 let total_weight: f64 = weights.iter().sum();
132 let mut random_value = self.rng.random::<f64>() * total_weight;
133 let mut selected_index = 0;
134
135 for (i, &weight) in weights.iter().enumerate() {
136 random_value -= weight;
137 if random_value <= 0.0 {
138 selected_index = i;
139 break;
140 }
141 }
142
143 let next_node = current_neighbors[selected_index].clone();
144 walk.push(next_node.clone());
145 }
148
149 Ok(RandomWalk { nodes: walk })
150 }
151
152 pub fn generate_walks<E, Ix>(
154 &mut self,
155 graph: &Graph<N, E, Ix>,
156 start: &N,
157 num_walks: usize,
158 walk_length: usize,
159 ) -> Result<Vec<RandomWalk<N>>>
160 where
161 N: Clone + std::fmt::Debug,
162 E: EdgeWeight,
163 Ix: petgraph::graph::IndexType,
164 {
165 let mut walks = Vec::new();
166 for _ in 0..num_walks {
167 let walk = self.simple_random_walk(graph, start, walk_length)?;
168 walks.push(walk);
169 }
170 Ok(walks)
171 }
172}