1use bitvec::vec::BitVec;
2use log::{debug, info};
3use std::time::Instant;
4
5use crate::graph::{EdgeID, Graph, INVALID_NODE_ID, NodeID};
6
7pub struct DFS {
8 sources: Vec<NodeID>,
9 target_set: BitVec,
10 parents: Vec<NodeID>,
11 target: NodeID,
12 stack: Vec<usize>,
13 empty_target_set: bool,
14}
15
16impl DFS {
17 pub fn new(source_list: &[NodeID], target_list: &[NodeID], number_of_nodes: usize) -> Self {
19 let mut temp = Self {
20 sources: source_list.to_vec(),
21 target_set: BitVec::with_capacity(number_of_nodes),
22 parents: Vec::new(),
23 target: INVALID_NODE_ID,
24 stack: Vec::new(),
25 empty_target_set: target_list.is_empty(),
26 };
27
28 temp.target_set.resize(number_of_nodes, false);
30 for i in target_list {
31 temp.target_set.set(*i, true);
32 }
33
34 temp.populate_sources(number_of_nodes);
35 temp
36 }
37
38 fn populate_sources(&mut self, number_of_nodes: usize) {
39 self.parents.resize(number_of_nodes, INVALID_NODE_ID);
40 for s in &self.sources {
41 self.parents[*s] = *s;
42 }
43 }
44
45 pub fn run<T, G: Graph<T>>(&mut self, graph: &G) -> bool {
46 self.run_with_filter(graph, |_graph, _edge| false)
47 }
48
49 pub fn run_with_filter<T, F, G: Graph<T>>(&mut self, graph: &G, filter: F) -> bool
52 where
53 F: Fn(&G, EdgeID) -> bool,
54 {
55 let start = Instant::now();
56 self.stack.clear();
58 self.stack.extend(self.sources.iter().copied());
59
60 self.parents.fill(INVALID_NODE_ID);
62 for s in &self.sources {
63 self.parents[*s] = *s;
64 }
65
66 while let Some(node) = self.stack.pop() {
67 let node_is_source = self.parents[node] == node;
68 for edge in graph.edge_range(node) {
70 if filter(graph, edge) {
71 continue;
72 }
73 let target = graph.target(edge);
74 if self.parents[target] != INVALID_NODE_ID
75 || (node_is_source && self.parents[target] == target)
76 {
77 continue;
80 }
81 self.parents[target] = node;
82 unsafe {
83 if *self.target_set.get_unchecked(target) {
85 self.target = target;
86 debug!("setting target {}", self.target);
87 let duration = start.elapsed();
89 info!("D/DFS took: {duration:?} (done)");
90 return true;
91 }
92 }
93 self.stack.push(target);
94 }
95 }
96
97 let duration = start.elapsed();
98 info!("DFS took: {duration:?} (done)");
99
100 self.empty_target_set
102 }
103
104 pub fn fetch_node_path(&self) -> Vec<NodeID> {
107 self.fetch_node_path_from_node(self.target)
108 }
109
110 pub fn fetch_node_path_from_node(&self, t: NodeID) -> Vec<NodeID> {
114 let mut id = t;
115 let mut path = Vec::new();
116 while id != self.parents[id] {
117 path.push(id);
118 id = self.parents[id];
119 }
120 path.push(id);
121 path.reverse();
122 path
123 }
124
125 pub fn fetch_edge_path<T>(&self, graph: &impl Graph<T>) -> Vec<EdgeID> {
127 let mut id = self.target;
129 let mut path = Vec::new();
130 while id != self.parents[id] {
131 let edge_id = graph.find_edge(self.parents[id], id).unwrap();
132 path.push(edge_id);
133 id = self.parents[id];
134 }
135
136 path.reverse();
137 path
138 }
139
140 pub fn path_iter(&self) -> PathIter {
142 PathIter::new(self)
143 }
144}
145
146pub struct PathIter<'a> {
147 dfs: &'a DFS,
148 id: usize,
149}
150
151impl PathIter<'_> {
152 pub fn new(dfs: &DFS) -> PathIter {
153 debug!("init: {}", dfs.target);
154 PathIter {
155 dfs,
156 id: dfs.target,
157 }
158 }
159}
160
161impl Iterator for PathIter<'_> {
162 type Item = NodeID;
163 fn next(&mut self) -> Option<NodeID> {
164 if self.id == INVALID_NODE_ID {
165 return None;
167 }
168
169 let result = self.id;
171 self.id = self.dfs.parents[self.id];
172 if result == self.dfs.parents[result] {
173 self.id = INVALID_NODE_ID;
174 }
175 Some(result)
176 }
177}
178
#[cfg(test)]
mod tests {
    use crate::dfs::DFS;
    use crate::edge::InputEdge;
    use crate::graph::Graph;
    use crate::static_graph::StaticGraph;

    /// Builds the graph shared by every test:
    /// 0→1, 1→2, 4→2, 2→3, 0→4, 4→5, 5→3, 1→5.
    fn sample_graph() -> StaticGraph<i32> {
        let edges = vec![
            InputEdge::new(0, 1, 3),
            InputEdge::new(1, 2, 3),
            InputEdge::new(4, 2, 1),
            InputEdge::new(2, 3, 6),
            InputEdge::new(0, 4, 2),
            InputEdge::new(4, 5, 2),
            InputEdge::new(5, 3, 7),
            InputEdge::new(1, 5, 2),
        ];
        StaticGraph::<i32>::new(edges)
    }

    #[test]
    fn s_t_query_fetch_node_string() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_node_path(), vec![0, 4, 5]);

        let reversed: Vec<usize> = dfs.path_iter().collect();
        assert_eq!(reversed, vec![5, 4, 0]);
    }

    #[test]
    fn s_t_query_edge_list() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[5], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_edge_path(&graph), vec![1, 6]);
    }

    #[test]
    fn s_all_query() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_node_path_from_node(3), vec![0, 4, 5, 3]);

        // No target list, so no target is recorded and the iterator is empty.
        assert_eq!(dfs.path_iter().count(), 0);
    }

    #[test]
    fn multi_s_all_query() {
        let graph = sample_graph();
        let mut dfs = DFS::new(&[0, 1], &[], graph.number_of_nodes());
        assert!(dfs.run(&graph));

        assert_eq!(dfs.fetch_node_path_from_node(3), vec![1, 5, 3]);

        assert_eq!(dfs.path_iter().count(), 0);
    }
}