1use std::collections::{HashMap, HashSet, VecDeque};
9use uni_common::core::id::{Eid, Vid};
10use uni_store::storage::adjacency_manager::AdjacencyManager;
11use uni_store::storage::direction::Direction;
12
13pub struct DirectTraversal<'a> {
18 am: &'a AdjacencyManager,
19 edge_types: Vec<u32>,
20}
21
22impl<'a> DirectTraversal<'a> {
23 pub fn new(am: &'a AdjacencyManager, edge_types: Vec<u32>) -> Self {
25 Self { am, edge_types }
26 }
27
28 pub fn neighbors(&self, vid: Vid, direction: Direction) -> Vec<(Vid, Eid)> {
30 let mut result = Vec::new();
31
32 for &edge_type in &self.edge_types {
33 let neighbors = self.am.get_neighbors(vid, edge_type, direction);
34 result.extend(neighbors);
35 }
36
37 result
38 }
39
40 pub fn bfs(&self, source: Vid, direction: Direction) -> BfsIterator<'_> {
42 BfsIterator::new(self, source, direction)
43 }
44
45 pub fn shortest_path(&self, source: Vid, target: Vid, direction: Direction) -> Option<Path> {
49 self.shortest_path_with_hops(source, target, direction, 0, u32::MAX)
50 }
51
52 pub fn shortest_path_with_hops(
64 &self,
65 source: Vid,
66 target: Vid,
67 direction: Direction,
68 min_hops: u32,
69 max_hops: u32,
70 ) -> Option<Path> {
71 if source == target {
73 if min_hops == 0 {
74 return Some(Path {
75 vertices: vec![source],
76 edges: Vec::new(),
77 });
78 } else {
79 return None;
81 }
82 }
83
84 if min_hops > max_hops {
86 return None;
87 }
88
89 let mut visited: HashMap<Vid, (Vid, Eid, u32)> = HashMap::default(); let mut frontier: VecDeque<(Vid, u32)> = VecDeque::new(); frontier.push_back((source, 0));
95 visited.insert(source, (source, Eid::new(0), 0)); while let Some((current, depth)) = frontier.pop_front() {
98 if depth >= max_hops {
100 continue;
101 }
102
103 for (neighbor, eid) in self.neighbors(current, direction) {
104 if visited.contains_key(&neighbor) {
105 continue;
106 }
107
108 let new_depth = depth + 1;
109 visited.insert(neighbor, (current, eid, new_depth));
110
111 if neighbor == target {
112 if new_depth >= min_hops && new_depth <= max_hops {
114 return Some(self.reconstruct_path_from_visited(source, target, &visited));
115 } else if new_depth < min_hops {
116 return None;
123 } else {
124 return None;
126 }
127 }
128
129 frontier.push_back((neighbor, new_depth));
130 }
131 }
132
133 None
134 }
135
136 fn reconstruct_path_from_visited(
138 &self,
139 source: Vid,
140 target: Vid,
141 visited: &HashMap<Vid, (Vid, Eid, u32)>,
142 ) -> Path {
143 let mut vertices = vec![target];
144 let mut edges = Vec::new();
145 let mut current = target;
146
147 while current != source {
148 if let Some(&(parent, eid, _)) = visited.get(¤t) {
149 edges.push(eid);
150 vertices.push(parent);
151 current = parent;
152 } else {
153 break;
154 }
155 }
156
157 vertices.reverse();
158 edges.reverse();
159
160 Path { vertices, edges }
161 }
162
163 pub fn all_shortest_paths_with_hops(
175 &self,
176 source: Vid,
177 target: Vid,
178 direction: Direction,
179 min_hops: u32,
180 max_hops: u32,
181 ) -> Vec<Path> {
182 if source == target {
184 if min_hops == 0 {
185 return vec![Path {
186 vertices: vec![source],
187 edges: Vec::new(),
188 }];
189 } else {
190 return Vec::new();
191 }
192 }
193
194 if min_hops > max_hops {
196 return Vec::new();
197 }
198
199 let dist_from_source = self.bfs_distances(source, direction, max_hops);
202
203 let shortest_dist = match dist_from_source.get(&target) {
205 Some(&d) if d >= min_hops && d <= max_hops => d,
206 Some(&d) if d < min_hops => return Vec::new(), _ => return Vec::new(), };
209
210 let mut all_paths = Vec::new();
212 let mut current_path = vec![source];
213 let mut current_edges = Vec::new();
214 let mut visited = HashSet::new();
215 visited.insert(source);
216
217 self.enumerate_shortest_paths(
218 source,
219 target,
220 direction,
221 shortest_dist,
222 0,
223 &dist_from_source,
224 &mut current_path,
225 &mut current_edges,
226 &mut visited,
227 &mut all_paths,
228 );
229
230 all_paths
231 }
232
233 fn bfs_distances(
235 &self,
236 source: Vid,
237 direction: Direction,
238 max_depth: u32,
239 ) -> HashMap<Vid, u32> {
240 let mut distances: HashMap<Vid, u32> = HashMap::default();
241 let mut frontier: VecDeque<(Vid, u32)> = VecDeque::new();
242
243 frontier.push_back((source, 0));
244 distances.insert(source, 0);
245
246 while let Some((current, depth)) = frontier.pop_front() {
247 if depth >= max_depth {
248 continue;
249 }
250
251 for (neighbor, _eid) in self.neighbors(current, direction) {
252 if let std::collections::hash_map::Entry::Vacant(e) = distances.entry(neighbor) {
253 let new_depth = depth + 1;
254 e.insert(new_depth);
255 frontier.push_back((neighbor, new_depth));
256 }
257 }
258 }
259
260 distances
261 }
262
263 #[allow(clippy::too_many_arguments)]
265 fn enumerate_shortest_paths(
266 &self,
267 current: Vid,
268 target: Vid,
269 direction: Direction,
270 target_dist: u32,
271 current_dist: u32,
272 dist_from_source: &HashMap<Vid, u32>,
273 current_path: &mut Vec<Vid>,
274 current_edges: &mut Vec<Eid>,
275 visited: &mut HashSet<Vid>,
276 all_paths: &mut Vec<Path>,
277 ) {
278 if current == target && current_dist == target_dist {
280 all_paths.push(Path {
281 vertices: current_path.clone(),
282 edges: current_edges.clone(),
283 });
284 return;
285 }
286
287 if current_dist >= target_dist {
289 return;
290 }
291
292 for (neighbor, eid) in self.neighbors(current, direction) {
294 if let Some(&neighbor_dist) = dist_from_source.get(&neighbor)
297 && neighbor_dist == current_dist + 1
298 && !visited.contains(&neighbor)
299 {
300 visited.insert(neighbor);
301 current_path.push(neighbor);
302 current_edges.push(eid);
303
304 self.enumerate_shortest_paths(
305 neighbor,
306 target,
307 direction,
308 target_dist,
309 current_dist + 1,
310 dist_from_source,
311 current_path,
312 current_edges,
313 visited,
314 all_paths,
315 );
316
317 current_path.pop();
318 current_edges.pop();
319 visited.remove(&neighbor);
320 }
321 }
322 }
323}
324
325pub struct BfsIterator<'a> {
327 traversal: &'a DirectTraversal<'a>,
328 frontier: VecDeque<(Vid, u32)>,
329 visited: HashSet<Vid>,
330 direction: Direction,
331}
332
333impl<'a> BfsIterator<'a> {
334 fn new(traversal: &'a DirectTraversal<'a>, source: Vid, direction: Direction) -> Self {
335 let mut frontier = VecDeque::new();
336 let mut visited = HashSet::default();
337
338 frontier.push_back((source, 0));
339 visited.insert(source);
340
341 Self {
342 traversal,
343 frontier,
344 visited,
345 direction,
346 }
347 }
348}
349
350impl Iterator for BfsIterator<'_> {
351 type Item = (Vid, u32);
352
353 fn next(&mut self) -> Option<Self::Item> {
354 let (current, distance) = self.frontier.pop_front()?;
355
356 for (neighbor, _eid) in self.traversal.neighbors(current, self.direction) {
358 if self.visited.insert(neighbor) {
359 self.frontier.push_back((neighbor, distance + 1));
360 }
361 }
362
363 Some((current, distance))
364 }
365}
366
367#[derive(Debug, Clone)]
369pub struct Path {
370 pub vertices: Vec<Vid>,
372 pub edges: Vec<Eid>,
374}
375
376impl Path {
377 pub fn len(&self) -> usize {
379 self.edges.len()
380 }
381
382 pub fn is_empty(&self) -> bool {
384 self.edges.is_empty()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-hop path reports length 2 and is not empty.
    #[test]
    fn test_path_length() {
        let two_hops = Path {
            vertices: vec![Vid::new(0), Vid::new(1), Vid::new(2)],
            edges: vec![Eid::new(0), Eid::new(1)],
        };

        assert_eq!(two_hops.len(), 2);
        assert!(!two_hops.is_empty());
    }

    /// A lone vertex is a zero-hop, empty path.
    #[test]
    fn test_path_empty() {
        let lone = Path {
            vertices: vec![Vid::new(0)],
            edges: Vec::new(),
        };

        assert_eq!(lone.len(), 0);
        assert!(lone.is_empty());
    }

    /// A single edge yields a one-hop, non-empty path.
    #[test]
    fn test_path_single_hop() {
        let one_hop = Path {
            vertices: vec![Vid::new(0), Vid::new(1)],
            edges: vec![Eid::new(0)],
        };

        assert_eq!(one_hop.len(), 1);
        assert!(!one_hop.is_empty());
    }

    /// Inclusive hop bounds: `min <= len <= max`.
    #[test]
    fn test_hop_constraint_validation() {
        let within = |len: u32, min: u32, max: u32| (min..=max).contains(&len);

        assert!(within(3, 1, 5));
        assert!(within(0, 0, 5));
        assert!(!within(0, 1, 5));
        assert!(!within(6, 1, 5));
        assert!(within(5, 5, 5));
        // Inverted bounds (min > max) accept nothing.
        assert!(!within(3, 5, 2));
    }

    /// Boundary cases: unbounded maximum, zero hops, and an exact-length window.
    #[test]
    fn test_hop_constraint_edge_cases() {
        let within = |len: u32, min: u32, max: u32| min <= max && (min..=max).contains(&len);

        assert!(within(1000, 1, u32::MAX));
        assert!(within(0, 0, 10));
        assert!(within(1, 1, 1));
        assert!(!within(2, 1, 1));
        assert!(!within(0, 1, 1));
    }
}