toolbox_rs/
dfs.rs

1use bitvec::vec::BitVec;
2use log::{debug, info};
3use std::time::Instant;
4
5use crate::graph::{EdgeID, Graph, INVALID_NODE_ID, NodeID};
6
7pub struct DFS {
8    sources: Vec<NodeID>,
9    target_set: BitVec,
10    parents: Vec<NodeID>,
11    target: NodeID,
12    stack: Vec<usize>,
13    empty_target_set: bool,
14}
15
16impl DFS {
17    // TODO: Also pass Graph instance
18    pub fn new(source_list: &[NodeID], target_list: &[NodeID], number_of_nodes: usize) -> Self {
19        let mut temp = Self {
20            sources: source_list.to_vec(),
21            target_set: BitVec::with_capacity(number_of_nodes),
22            parents: Vec::new(),
23            target: INVALID_NODE_ID,
24            stack: Vec::new(),
25            empty_target_set: target_list.is_empty(),
26        };
27
28        // initialize bit vector storing which nodes are targets
29        temp.target_set.resize(number_of_nodes, false);
30        for i in target_list {
31            temp.target_set.set(*i, true);
32        }
33
34        temp.populate_sources(number_of_nodes);
35        temp
36    }
37
38    fn populate_sources(&mut self, number_of_nodes: usize) {
39        self.parents.resize(number_of_nodes, INVALID_NODE_ID);
40        for s in &self.sources {
41            self.parents[*s] = *s;
42        }
43    }
44
45    pub fn run<T, G: Graph<T>>(&mut self, graph: &G) -> bool {
46        self.run_with_filter(graph, |_graph, _edge| false)
47    }
48
49    /// explore the graph in a DFS
50    /// returns true if a path between s and t was found or no target was given
51    pub fn run_with_filter<T, F, G: Graph<T>>(&mut self, graph: &G, filter: F) -> bool
52    where
53        F: Fn(&G, EdgeID) -> bool,
54    {
55        let start = Instant::now();
56        // reset queue w/o allocating
57        self.stack.clear();
58        self.stack.extend(self.sources.iter().copied());
59
60        // reset parents vector
61        self.parents.fill(INVALID_NODE_ID);
62        for s in &self.sources {
63            self.parents[*s] = *s;
64        }
65
66        while let Some(node) = self.stack.pop() {
67            let node_is_source = self.parents[node] == node;
68            // sources have themselves as parents
69            for edge in graph.edge_range(node) {
70                if filter(graph, edge) {
71                    continue;
72                }
73                let target = graph.target(edge);
74                if self.parents[target] != INVALID_NODE_ID
75                    || (node_is_source && self.parents[target] == target)
76                {
77                    // we already have seen this node and can ignore it, or
78                    // edge is fully contained within source set and we can ignore it, too
79                    continue;
80                }
81                self.parents[target] = node;
82                unsafe {
83                    // unsafe is used for performance, here, as the graph is consistent by construction
84                    if *self.target_set.get_unchecked(target) {
85                        self.target = target;
86                        debug!("setting target {}", self.target);
87                        // check if we have found our target if it exists
88                        let duration = start.elapsed();
89                        info!("D/DFS took: {duration:?} (done)");
90                        return true;
91                    }
92                }
93                self.stack.push(target);
94            }
95        }
96
97        let duration = start.elapsed();
98        info!("DFS took: {duration:?} (done)");
99
100        // return true only if target set was empty
101        self.empty_target_set
102    }
103
104    // path unpacking, by searching for the first target that was found
105    // and by then unwinding the path of nodes to it.
106    pub fn fetch_node_path(&self) -> Vec<NodeID> {
107        self.fetch_node_path_from_node(self.target)
108    }
109
110    // path unpacking, by searching for the first target that was found
111    // and by then unwinding the path of nodes to it.
112    // TODO: needs test to check what happens when t is unknown, or unvisited. Can this be removed?
113    pub fn fetch_node_path_from_node(&self, t: NodeID) -> Vec<NodeID> {
114        let mut id = t;
115        let mut path = Vec::new();
116        while id != self.parents[id] {
117            path.push(id);
118            id = self.parents[id];
119        }
120        path.push(id);
121        path.reverse();
122        path
123    }
124
125    // TODO: the reverse might be unnecessary to some applications
126    pub fn fetch_edge_path<T>(&self, graph: &impl Graph<T>) -> Vec<EdgeID> {
127        // path unpacking
128        let mut id = self.target;
129        let mut path = Vec::new();
130        while id != self.parents[id] {
131            let edge_id = graph.find_edge(self.parents[id], id).unwrap();
132            path.push(edge_id);
133            id = self.parents[id];
134        }
135
136        path.reverse();
137        path
138    }
139
140    //TODO: Add test covering this iterator
141    pub fn path_iter(&self) -> PathIter {
142        PathIter::new(self)
143    }
144}
145
146pub struct PathIter<'a> {
147    dfs: &'a DFS,
148    id: usize,
149}
150
151impl PathIter<'_> {
152    pub fn new(dfs: &DFS) -> PathIter {
153        debug!("init: {}", dfs.target);
154        PathIter {
155            dfs,
156            id: dfs.target,
157        }
158    }
159}
160
161impl Iterator for PathIter<'_> {
162    type Item = NodeID;
163    fn next(&mut self) -> Option<NodeID> {
164        if self.id == INVALID_NODE_ID {
165            // INVALID_NODE_ID is the indicator that unpacking is done or not possible
166            return None;
167        }
168
169        // path unpacking step
170        let result = self.id;
171        self.id = self.dfs.parents[self.id];
172        if result == self.dfs.parents[result] {
173            self.id = INVALID_NODE_ID;
174        }
175        Some(result)
176    }
177}
178
#[cfg(test)]
mod tests {
    use crate::edge::InputEdge;
    use crate::graph::Graph;
    use crate::{dfs::DFS, static_graph::StaticGraph};

    /// Small directed graph shared by every test below:
    /// 0->1, 1->2, 4->2, 2->3, 0->4, 4->5, 5->3, 1->5.
    fn sample_graph() -> StaticGraph<i32> {
        let edges = vec![
            InputEdge::new(0, 1, 3),
            InputEdge::new(1, 2, 3),
            InputEdge::new(4, 2, 1),
            InputEdge::new(2, 3, 6),
            InputEdge::new(0, 4, 2),
            InputEdge::new(4, 5, 2),
            InputEdge::new(5, 3, 7),
            InputEdge::new(1, 5, 2),
        ];
        StaticGraph::new(edges)
    }

    #[test]
    fn s_t_query_fetch_node_string() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_node_path(), vec![0, 4, 5]);

        // the iterator walks the same path backwards, target first
        let backwards: Vec<usize> = dfs.path_iter().collect();
        assert_eq!(backwards, vec![5, 4, 0]);
    }

    #[test]
    fn s_t_query_edge_list() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_edge_path(&graph), vec![1, 6]);
    }

    #[test]
    fn s_all_query() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_node_path_from_node(3), vec![0, 4, 5, 3]);

        // no target was given, so there is no target path to iterate
        let backwards: Vec<usize> = dfs.path_iter().collect();
        assert!(backwards.is_empty());
    }

    #[test]
    fn multi_s_all_query() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0, 1], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        // path unpacking ends at whichever source discovered the chain
        assert_eq!(dfs.fetch_node_path_from_node(3), vec![1, 5, 3]);

        let backwards: Vec<usize> = dfs.path_iter().collect();
        assert!(backwards.is_empty());
    }
}