scirs2_graph/embeddings/
random_walk.rs1use super::types::RandomWalk;
4use crate::base::{EdgeWeight, Graph, Node};
5use crate::error::{GraphError, Result};
6use rand::prelude::IndexedRandom;
7use rand::seq::SliceRandom;
8use rand::Rng;
9
10pub struct RandomWalkGenerator<N: Node> {
12 rng: rand::rngs::ThreadRng,
14 _phantom: std::marker::PhantomData<N>,
16}
17
18impl<N: Node> Default for RandomWalkGenerator<N> {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl<N: Node> RandomWalkGenerator<N> {
25 pub fn new() -> Self {
27 RandomWalkGenerator {
28 rng: rand::rng(),
29 _phantom: std::marker::PhantomData,
30 }
31 }
32
33 pub fn simple_random_walk<E, Ix>(
35 &mut self,
36 graph: &Graph<N, E, Ix>,
37 start: &N,
38 length: usize,
39 ) -> Result<RandomWalk<N>>
40 where
41 N: Clone + std::fmt::Debug,
42 E: EdgeWeight,
43 Ix: petgraph::graph::IndexType,
44 {
45 if !graph.contains_node(start) {
46 return Err(GraphError::node_not_found("node"));
47 }
48
49 let mut walk = vec![start.clone()];
50 let mut current = start.clone();
51
52 for _ in 1..length {
53 let neighbors = graph.neighbors(¤t)?;
54 if neighbors.is_empty() {
55 break; }
57
58 current = neighbors
59 .choose(&mut self.rng)
60 .ok_or(GraphError::AlgorithmError(
61 "Failed to choose neighbor".to_string(),
62 ))?
63 .clone();
64 walk.push(current.clone());
65 }
66
67 Ok(RandomWalk { nodes: walk })
68 }
69
70 pub fn node2vec_walk<E, Ix>(
72 &mut self,
73 graph: &Graph<N, E, Ix>,
74 start: &N,
75 length: usize,
76 p: f64,
77 q: f64,
78 ) -> Result<RandomWalk<N>>
79 where
80 N: Clone + std::fmt::Debug,
81 E: EdgeWeight + Into<f64>,
82 Ix: petgraph::graph::IndexType,
83 {
84 if !graph.contains_node(start) {
85 return Err(GraphError::node_not_found("node"));
86 }
87
88 let mut walk = vec![start.clone()];
89 if length == 1 {
90 return Ok(RandomWalk { nodes: walk });
91 }
92
93 let first_neighbors = graph.neighbors(start)?;
95 if first_neighbors.is_empty() {
96 return Ok(RandomWalk { nodes: walk });
97 }
98
99 let current = first_neighbors
100 .choose(&mut self.rng)
101 .ok_or(GraphError::AlgorithmError(
102 "Failed to choose first neighbor".to_string(),
103 ))?
104 .clone();
105 walk.push(current.clone());
106
107 for _ in 2..length {
109 let current_neighbors = graph.neighbors(¤t)?;
110 if current_neighbors.is_empty() {
111 break;
112 }
113
114 let prev = &walk[walk.len() - 2];
115 let mut weights = Vec::new();
116
117 for neighbor in ¤t_neighbors {
118 let weight = if neighbor == prev {
119 1.0 / p
121 } else if graph.has_edge(prev, neighbor) {
122 1.0
124 } else {
125 1.0 / q
127 };
128 weights.push(weight);
129 }
130
131 let total_weight: f64 = weights.iter().sum();
133 let mut random_value = self.rng.random::<f64>() * total_weight;
134 let mut selected_index = 0;
135
136 for (i, &weight) in weights.iter().enumerate() {
137 random_value -= weight;
138 if random_value <= 0.0 {
139 selected_index = i;
140 break;
141 }
142 }
143
144 let next_node = current_neighbors[selected_index].clone();
145 walk.push(next_node.clone());
146 }
149
150 Ok(RandomWalk { nodes: walk })
151 }
152
153 pub fn generate_walks<E, Ix>(
155 &mut self,
156 graph: &Graph<N, E, Ix>,
157 start: &N,
158 num_walks: usize,
159 walk_length: usize,
160 ) -> Result<Vec<RandomWalk<N>>>
161 where
162 N: Clone + std::fmt::Debug,
163 E: EdgeWeight,
164 Ix: petgraph::graph::IndexType,
165 {
166 let mut walks = Vec::new();
167 for _ in 0..num_walks {
168 let walk = self.simple_random_walk(graph, start, walk_length)?;
169 walks.push(walk);
170 }
171 Ok(walks)
172 }
173}