Skip to main content

rustworkx_core/traversal/
random_walk.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13//! Module for graph random walk algorithms.
14
15use petgraph::{Direction::Outgoing, visit::IntoNeighborsDirected};
16use std::hash::Hash;
17
18use hashbrown::HashMap;
19
20use rand::prelude::*;
21use rand::rngs::SysRng;
22use rand_pcg::Pcg64;
23
24/// Return a random path (or random walk) on the graph.
25///
26/// The next node to visit is selected uniformly at random from the outgoing
27/// neighbors. If a node has no outgoing neighbor, the path will stop early.
28/// The graph may be directed or not.
29///
30/// # Arguments:
31///
32/// * `graph` - Graph on which the random walk is done.
33/// * `source` - Starting node of the path.
34/// * `length` - Maximum length of the path.
35/// * `seed` - seed of the random number generator that chooses the next node.
36///
37/// # Returns
38///
39/// A vector of the visited nodes including the initial node `source`.
40///
41/// # Example
42///
43/// ```rust
44/// use petgraph::graph::DiGraph;
45/// use rustworkx_core::traversal::generate_random_path;
46///
47/// let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(3, 3);
48/// let a = graph.add_node(());
49/// let b = graph.add_node(());
50/// let c = graph.add_node(());
51/// graph.extend_with_edges([(a, b), (b, c), (c, a)]);
52/// let path = generate_random_path(&graph, a, 3, Some(5));
53/// assert_eq!(path, vec![a, b, c, a]);
54/// ```
55pub fn generate_random_path<G>(
56    graph: G,
57    source: G::NodeId,
58    length: usize,
59    seed: Option<u64>,
60) -> Vec<G::NodeId>
61where
62    G: IntoNeighborsDirected,
63    G::NodeId: Eq + Hash,
64{
65    let mut rng: Pcg64 = match seed {
66        Some(seed) => Pcg64::seed_from_u64(seed),
67        None => Pcg64::try_from_rng(&mut SysRng).unwrap(),
68    };
69
70    let mut degrees: HashMap<G::NodeId, usize> = HashMap::new();
71    let mut get_degree_lazy = |u: G::NodeId| {
72        *degrees
73            .entry(u)
74            .or_insert_with(|| graph.neighbors_directed(u, Outgoing).count())
75    };
76
77    let mut path = Vec::with_capacity(length + 1);
78    let mut current_node = source;
79    path.push(source);
80    for _ in 0..length {
81        let degree = get_degree_lazy(current_node);
82        if degree == 0 {
83            return path;
84        }
85        let idx = rng.random_range(..degree);
86        let neighbor = graph
87            .neighbors_directed(current_node, Outgoing)
88            .nth(idx)
89            .unwrap();
90        path.push(neighbor);
91        current_node = neighbor;
92    }
93    path
94}
95
#[cfg(test)]
mod tests {
    use crate::traversal::generate_random_path;
    use hashbrown::HashMap;
    use petgraph::graph::{DiGraph, UnGraph};

    /// A walk that reaches a node without outgoing edges stops there, even if
    /// fewer than `length` steps were taken.
    #[test]
    fn test_degree_zero_shorter_path() {
        let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(2, 1);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());

        // Node b has no neighbor and the random walk terminates early.
        assert_eq!(generate_random_path(&graph, a, 10, None), vec![a, b]);
    }

    /// On a directed 2-cycle every node has exactly one outgoing neighbor, so
    /// the walk is the same whatever the seed.
    #[test]
    fn test_alternating_path_digraph() {
        let mut graph: DiGraph<(), ()> = DiGraph::with_capacity(2, 2);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());
        graph.add_edge(b, a, ());

        assert_eq!(generate_random_path(&graph, a, 3, None), vec![a, b, a, b]);
    }

    /// A single undirected edge can be walked in both directions, so the walk
    /// alternates between its two endpoints.
    #[test]
    fn test_alternating_path_graph() {
        let mut graph: UnGraph<(), ()> = UnGraph::with_capacity(2, 1);
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, ());

        assert_eq!(generate_random_path(&graph, a, 3, None), vec![a, b, a, b]);
    }

    /// The fraction of visits to each node of a long seeded walk is compared
    /// with the expected frequency derived from the node degrees.
    #[test]
    fn test_path_visit_frequency() {
        // a -- b -- c -- d
        //         / |
        //      e -- f -- g
        let mut graph: UnGraph<(), ()> = UnGraph::with_capacity(7, 7);

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        let g = graph.add_node(());

        graph.extend_with_edges([(a, b), (b, c), (c, d), (e, f), (c, e), (c, f), (f, g)]);

        let path_length = 5_000;
        let mut frequencies: HashMap<_, f64> = HashMap::new();
        for node in generate_random_path(&graph, a, path_length, Some(5)) {
            *frequencies.entry(node).or_insert(0.) += 1.;
        }
        // The path holds `path_length + 1` nodes because it includes the source.
        for frequency in frequencies.values_mut() {
            *frequency /= path_length as f64 + 1.;
        }

        // The expected visit frequency of a node is
        // degree / (2 * number of edges); the graph has 7 edges, hence 14.
        let tol = 1e-2;
        let degrees = [(a, 1.), (b, 2.), (c, 4.), (d, 1.), (e, 2.), (f, 3.), (g, 1.)];
        for (node, degree) in degrees {
            let expected = degree / 14.;
            // A node that was never visited has an observed frequency of zero.
            let observed = frequencies.get(&node).copied().unwrap_or(0.);
            assert!(
                (observed - expected).abs() < tol,
                "node {node:?}: observed frequency {observed}, expected {expected}"
            );
        }
    }
}