rustworkx_core/traversal/
random_walk.rs1use petgraph::{Direction::Outgoing, visit::IntoNeighborsDirected};
16use std::hash::Hash;
17
18use hashbrown::HashMap;
19
20use rand::prelude::*;
21use rand::rngs::SysRng;
22use rand_pcg::Pcg64;
23
24pub fn generate_random_path<G>(
56 graph: G,
57 source: G::NodeId,
58 length: usize,
59 seed: Option<u64>,
60) -> Vec<G::NodeId>
61where
62 G: IntoNeighborsDirected,
63 G::NodeId: Eq + Hash,
64{
65 let mut rng: Pcg64 = match seed {
66 Some(seed) => Pcg64::seed_from_u64(seed),
67 None => Pcg64::try_from_rng(&mut SysRng).unwrap(),
68 };
69
70 let mut degrees: HashMap<G::NodeId, usize> = HashMap::new();
71 let mut get_degree_lazy = |u: G::NodeId| {
72 *degrees
73 .entry(u)
74 .or_insert_with(|| graph.neighbors_directed(u, Outgoing).count())
75 };
76
77 let mut path = Vec::with_capacity(length + 1);
78 let mut current_node = source;
79 path.push(source);
80 for _ in 0..length {
81 let degree = get_degree_lazy(current_node);
82 if degree == 0 {
83 return path;
84 }
85 let idx = rng.random_range(..degree);
86 let neighbor = graph
87 .neighbors_directed(current_node, Outgoing)
88 .nth(idx)
89 .unwrap();
90 path.push(neighbor);
91 current_node = neighbor;
92 }
93 path
94}
95
#[cfg(test)]
mod tests {
    use super::generate_random_path;
    use hashbrown::HashMap;
    use petgraph::graph::{DiGraph, UnGraph};

    #[test]
    fn test_zero_length_path_is_source() {
        let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(2, 1);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());

        // No steps are taken, so only the starting node is returned.
        assert_eq!(generate_random_path(&graph, a, 0, None), vec![a]);
    }

    #[test]
    fn test_degree_zero_shorter_path() {
        let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(2, 1);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());

        // `b` has no outgoing edges, so the walk ends there well before `length`.
        assert_eq!(generate_random_path(&graph, a, 10, None), vec![a, b]);
    }

    #[test]
    fn test_alternating_path_digraph() {
        let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(2, 2);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());
        graph.add_edge(b, a, ());

        // Each node has exactly one outgoing neighbor, so the walk is forced.
        assert_eq!(generate_random_path(&graph, a, 3, None), vec![a, b, a, b]);
    }

    #[test]
    fn test_alternating_path_graph() {
        let mut graph: UnGraph<(), ()> = UnGraph::with_capacity(2, 1);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());

        // A single undirected edge is traversable in both directions.
        assert_eq!(generate_random_path(&graph, a, 3, None), vec![a, b, a, b]);
    }

    #[test]
    fn test_same_seed_same_path() {
        let mut graph: UnGraph<(), ()> = UnGraph::with_capacity(4, 5);
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        graph.extend_with_edges([(a, b), (b, c), (c, d), (d, a), (a, c)]);

        let first = generate_random_path(&graph, a, 50, Some(42));
        let second = generate_random_path(&graph, a, 50, Some(42));
        // No dead ends, so every step is taken.
        assert_eq!(first.len(), 51);
        assert_eq!(first, second);
    }

    #[test]
    fn test_path_visit_frequency() {
        // Edges a-b, b-c, c-d, c-e, c-f, e-f, f-g. The graph is connected and
        // non-bipartite (triangle c-e-f). The walk's visit frequencies should
        // therefore approach the stationary distribution
        // deg(v) / (2 * |E|) = deg(v) / 14.
        let mut graph: UnGraph<(), ()> = UnGraph::with_capacity(7, 7);

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        let g = graph.add_node(());

        graph.extend_with_edges([(a, b), (b, c), (c, d), (e, f), (c, e), (c, f), (f, g)]);

        let path_length = 5_000;
        let path = generate_random_path(&graph, a, path_length, Some(5));
        // Every node has at least one neighbor, so the walk never stops early.
        assert_eq!(path.len(), path_length + 1);

        let mut counts: HashMap<_, usize> = HashMap::new();
        for &node in &path {
            *counts.entry(node).or_default() += 1;
        }

        let tol = 1e-2;
        for (node, degree) in [
            (a, 1.),
            (b, 2.),
            (c, 4.),
            (d, 1.),
            (e, 2.),
            (f, 3.),
            (g, 1.),
        ] {
            let observed = counts[&node] as f64 / path.len() as f64;
            let expected = degree / 14.;
            assert!(
                (observed - expected).abs() < tol,
                "node {node:?}: observed frequency {observed}, expected {expected}"
            );
        }
    }
}