1use fxhash::{FxHashMap, FxHashSet};
2use std::ops::ControlFlow;
3use uni_common::core::id::{Eid, Vid};
4
5use super::nfa::{NfaStateId, PathMode, PathSelector};
6
/// One predecessor hop stored in the record pool of [`PredecessorDag`].
///
/// All records that share a `(dst vertex, dst state, depth)` key form a singly
/// linked list threaded through `next`; the list head is kept in
/// `PredecessorDag::pred_head`.
#[derive(Debug, Clone)]
pub struct PredRec {
    /// Vertex the hop starts from.
    pub src_vid: Vid,
    /// NFA state associated with `src_vid` for this hop.
    pub src_state: NfaStateId,
    /// Edge traversed by the hop.
    pub eid: Eid,
    /// Pool index of the next record in the same list, or `-1` at the end of the list.
    pub next: i32,
}
20
/// Backward index of predecessor hops, used to reconstruct paths by walking from
/// a target back to a source.
///
/// A "depth" is the number of edges from the source: `enumerate_paths` treats a
/// depth-`d` entry as the last hop of a `d`-edge path.
pub struct PredecessorDag {
    /// Backing storage for every predecessor record; per-key lists link through
    /// `PredRec::next` indices into this vector.
    pred_pool: Vec<PredRec>,

    /// `(dst vertex, dst state, depth)` -> pool index of the most recently added
    /// record for that key (the head of its list).
    pred_head: FxHashMap<(Vid, NfaStateId, u32), i32>,

    /// Smallest depth ever passed to `add_predecessor` for each `(vertex, state)`.
    first_depth: FxHashMap<(Vid, NfaStateId), u32>,

    /// Selector deciding whether predecessors deeper than `first_depth` are kept
    /// (see `is_layered`).
    selector: PathSelector,
}
40
41impl PredecessorDag {
42 pub fn new(selector: PathSelector) -> Self {
44 Self {
45 pred_pool: Vec::new(),
46 pred_head: FxHashMap::default(),
47 first_depth: FxHashMap::default(),
48 selector,
49 }
50 }
51
52 pub fn is_layered(&self) -> bool {
54 matches!(self.selector, PathSelector::All | PathSelector::Any)
55 }
56
57 #[allow(clippy::too_many_arguments)]
63 pub fn add_predecessor(
64 &mut self,
65 dst: Vid,
66 dst_state: NfaStateId,
67 src: Vid,
68 src_state: NfaStateId,
69 eid: Eid,
70 depth: u32,
71 ) {
72 let first = self.first_depth.entry((dst, dst_state)).or_insert(depth);
74 if depth < *first {
75 *first = depth;
76 }
77
78 if !self.is_layered() && depth > *self.first_depth.get(&(dst, dst_state)).unwrap() {
80 return;
81 }
82
83 let key = (dst, dst_state, depth);
85 let current_head = self.pred_head.get(&key).copied().unwrap_or(-1);
86
87 let new_idx = self.pred_pool.len() as i32;
89 self.pred_pool.push(PredRec {
90 src_vid: src,
91 src_state,
92 eid,
93 next: current_head,
94 });
95
96 self.pred_head.insert(key, new_idx);
98 }
99
100 #[allow(clippy::too_many_arguments)]
106 pub fn enumerate_paths<F>(
107 &self,
108 source: Vid,
109 target: Vid,
110 accepting_state: NfaStateId,
111 min_depth: u32,
112 max_depth: u32,
113 mode: &PathMode,
114 yield_path: &mut F,
115 ) where
116 F: FnMut(&[Vid], &[Eid]) -> ControlFlow<()>,
117 {
118 for depth in min_depth..=max_depth {
119 if depth == 0 {
121 if source == target && yield_path(&[source], &[]).is_break() {
122 return;
123 }
124 continue;
125 }
126
127 if !self
128 .pred_head
129 .contains_key(&(target, accepting_state, depth))
130 {
131 continue;
132 }
133
134 let mut nodes = Vec::with_capacity(depth as usize + 1);
135 let mut edges = Vec::with_capacity(depth as usize);
136 let mut node_set = FxHashSet::default();
137 let mut edge_set = FxHashSet::default();
138
139 nodes.push(target);
141 if matches!(mode, PathMode::Acyclic | PathMode::Simple) {
142 node_set.insert(target);
143 }
144
145 if self
146 .dfs_backward(
147 source,
148 target,
149 accepting_state,
150 depth,
151 &mut nodes,
152 &mut edges,
153 &mut node_set,
154 &mut edge_set,
155 mode,
156 yield_path,
157 )
158 .is_break()
159 {
160 return;
161 }
162 }
163 }
164
165 pub fn has_trail_valid_path(
169 &self,
170 source: Vid,
171 target: Vid,
172 accepting_state: NfaStateId,
173 min_depth: u32,
174 max_depth: u32,
175 ) -> bool {
176 let mut found = false;
177 self.enumerate_paths(
178 source,
179 target,
180 accepting_state,
181 min_depth,
182 max_depth,
183 &PathMode::Trail,
184 &mut |_nodes, _edges| {
185 found = true;
186 ControlFlow::Break(())
187 },
188 );
189 found
190 }
191
192 #[allow(clippy::too_many_arguments)]
194 fn dfs_backward<F>(
195 &self,
196 source: Vid,
197 current_vid: Vid,
198 current_state: NfaStateId,
199 remaining_depth: u32,
200 nodes: &mut Vec<Vid>,
201 edges: &mut Vec<Eid>,
202 node_set: &mut FxHashSet<Vid>,
203 edge_set: &mut FxHashSet<Eid>,
204 mode: &PathMode,
205 yield_path: &mut F,
206 ) -> ControlFlow<()>
207 where
208 F: FnMut(&[Vid], &[Eid]) -> ControlFlow<()>,
209 {
210 if remaining_depth == 0 {
211 if current_vid == source {
212 let fwd_nodes: Vec<Vid> = nodes.iter().rev().copied().collect();
214 let fwd_edges: Vec<Eid> = edges.iter().rev().copied().collect();
215 return yield_path(&fwd_nodes, &fwd_edges);
216 }
217 return ControlFlow::Continue(());
218 }
219
220 let key = (current_vid, current_state, remaining_depth);
221 let Some(&head) = self.pred_head.get(&key) else {
222 return ControlFlow::Continue(());
223 };
224
225 let mut idx = head;
226 while idx >= 0 {
227 let pred = &self.pred_pool[idx as usize];
228
229 let skip = match mode {
231 PathMode::Walk => false,
232 PathMode::Trail => edge_set.contains(&pred.eid),
233 PathMode::Acyclic => node_set.contains(&pred.src_vid),
234 PathMode::Simple => {
235 node_set.contains(&pred.src_vid)
237 && !(remaining_depth == 1 && pred.src_vid == source)
238 }
239 };
240
241 if skip {
242 idx = pred.next;
243 continue;
244 }
245
246 nodes.push(pred.src_vid);
248 edges.push(pred.eid);
249
250 if matches!(mode, PathMode::Trail) {
251 edge_set.insert(pred.eid);
252 }
253 if matches!(mode, PathMode::Acyclic | PathMode::Simple) {
254 node_set.insert(pred.src_vid);
255 }
256
257 let result = self.dfs_backward(
259 source,
260 pred.src_vid,
261 pred.src_state,
262 remaining_depth - 1,
263 nodes,
264 edges,
265 node_set,
266 edge_set,
267 mode,
268 yield_path,
269 );
270
271 nodes.pop();
273 edges.pop();
274
275 if matches!(mode, PathMode::Trail) {
276 edge_set.remove(&pred.eid);
277 }
278 if matches!(mode, PathMode::Acyclic | PathMode::Simple) {
279 node_set.remove(&pred.src_vid);
280 }
281
282 if result.is_break() {
283 return ControlFlow::Break(());
284 }
285
286 idx = pred.next;
287 }
288
289 ControlFlow::Continue(())
290 }
291
292 pub fn pool_len(&self) -> usize {
294 self.pred_pool.len()
295 }
296
297 pub fn first_depth_of(&self, vid: Vid, state: NfaStateId) -> Option<u32> {
299 self.first_depth.get(&(vid, state)).copied()
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    fn vid(n: u64) -> Vid {
        Vid::new(n)
    }

    fn eid(n: u64) -> Eid {
        Eid::new(n)
    }

    /// Runs `enumerate_paths` to completion and collects every yielded path.
    fn collect_paths(
        dag: &PredecessorDag,
        source: Vid,
        target: Vid,
        accepting_state: NfaStateId,
        min_depth: u32,
        max_depth: u32,
        mode: &PathMode,
    ) -> Vec<(Vec<Vid>, Vec<Eid>)> {
        let mut paths = Vec::new();
        dag.enumerate_paths(
            source,
            target,
            accepting_state,
            min_depth,
            max_depth,
            mode,
            &mut |nodes, edges| {
                paths.push((nodes.to_vec(), edges.to_vec()));
                ControlFlow::Continue(())
            },
        );
        paths
    }

    #[test]
    fn test_pred_dag_add_single() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(2), 1, vid(1), 0, eid(10), 1);
        assert_eq!(dag.pool_len(), 1);
        assert!(dag.pred_head.contains_key(&(vid(2), 1, 1)));
    }

    #[test]
    fn test_pred_dag_add_chain() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);
        assert_eq!(dag.pool_len(), 2);
        assert!(dag.pred_head.contains_key(&(vid(1), 1, 1)));
        assert!(dag.pred_head.contains_key(&(vid(2), 2, 2)));
    }

    /// Two predecessors of the same key are chained into one list.
    #[test]
    fn test_pred_dag_multiple_preds() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 1);
        assert_eq!(dag.pool_len(), 2);

        let head = dag.pred_head[&(vid(2), 1, 1)];
        assert!(head >= 0);
        let first = &dag.pred_pool[head as usize];
        assert!(first.next >= 0);
        // The second record terminates the list.
        assert_eq!(dag.pred_pool[first.next as usize].next, -1);
    }

    #[test]
    fn test_pred_dag_first_depth() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 3);
        assert_eq!(dag.first_depth_of(vid(2), 1), Some(3));

        dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 2);
        assert_eq!(dag.first_depth_of(vid(2), 1), Some(2));

        // A deeper arrival does not raise the recorded minimum.
        dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 5);
        assert_eq!(dag.first_depth_of(vid(2), 1), Some(2));
    }

    /// Layered selectors keep predecessors at every depth.
    #[test]
    fn test_pred_dag_layered_stores_all() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        assert!(dag.is_layered());

        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
        dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
        assert_eq!(dag.pool_len(), 2);
        assert!(dag.pred_head.contains_key(&(vid(2), 1, 2)));
        assert!(dag.pred_head.contains_key(&(vid(2), 1, 3)));
    }

    /// Shortest selectors drop predecessors deeper than the first depth.
    #[test]
    fn test_pred_dag_shortest_skips() {
        let mut dag = PredecessorDag::new(PathSelector::AnyShortest);
        assert!(!dag.is_layered());

        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
        assert_eq!(dag.pool_len(), 1);

        dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
        assert_eq!(dag.pool_len(), 1);

        // Same (minimal) depth is still accepted.
        dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 2);
        assert_eq!(dag.pool_len(), 2);
    }

    #[test]
    fn test_pred_dag_selector_switch() {
        let build = |selector: PathSelector| -> usize {
            let mut dag = PredecessorDag::new(selector);
            dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 2);
            dag.add_predecessor(vid(2), 1, vid(1), 0, eid(11), 3);
            dag.add_predecessor(vid(2), 1, vid(3), 0, eid(12), 4);
            dag.pool_len()
        };

        assert_eq!(build(PathSelector::All), 3);
        assert_eq!(build(PathSelector::AnyShortest), 1);
    }

    /// 0 -> 1 -> 2 yields exactly one walk.
    #[test]
    fn test_pred_dag_linear_walk() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);

        let paths = collect_paths(&dag, vid(0), vid(2), 2, 2, 2, &PathMode::Walk);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2)]);
        assert_eq!(paths[0].1, vec![eid(10), eid(11)]);
    }

    #[test]
    fn test_pred_dag_diamond_walk() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
        dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
        dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);

        let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Walk);
        assert_eq!(paths.len(), 2);

        let node_lists: Vec<_> = paths.iter().map(|(n, _)| n.clone()).collect();
        assert!(node_lists.contains(&vec![vid(0), vid(1), vid(3)]));
        assert!(node_lists.contains(&vec![vid(0), vid(2), vid(3)]));
    }

    #[test]
    fn test_pred_dag_multiple_depths() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        // Direct edge 0 -> 2 at depth 1.
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(10), 1);
        // Two-hop route 0 -> 1 -> 2 at depth 2.
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(11), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(12), 2);

        let paths1 = collect_paths(&dag, vid(0), vid(2), 1, 1, 1, &PathMode::Walk);
        assert_eq!(paths1.len(), 1);
        assert_eq!(paths1[0].0, vec![vid(0), vid(2)]);

        let paths2 = collect_paths(&dag, vid(0), vid(2), 2, 2, 2, &PathMode::Walk);
        assert_eq!(paths2.len(), 1);
        assert_eq!(paths2[0].0, vec![vid(0), vid(1), vid(2)]);
    }

    #[test]
    fn test_pred_dag_fan_out() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
        dag.add_predecessor(vid(3), 1, vid(0), 0, eid(12), 1);
        dag.add_predecessor(vid(4), 2, vid(1), 1, eid(13), 2);
        dag.add_predecessor(vid(4), 2, vid(2), 1, eid(14), 2);
        dag.add_predecessor(vid(4), 2, vid(3), 1, eid(15), 2);

        let paths = collect_paths(&dag, vid(0), vid(4), 2, 2, 2, &PathMode::Walk);
        assert_eq!(paths.len(), 3);
    }

    /// 0 -e1-> 1 -e2-> 0 -e1-> 1 reuses e1: a walk, but not a trail.
    #[test]
    fn test_pred_dag_trail_no_repeat() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(1), 3, vid(0), 2, eid(1), 3);

        let walk_paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Walk);
        assert_eq!(walk_paths.len(), 1);
        assert_eq!(walk_paths[0].1, vec![eid(1), eid(2), eid(1)]);

        let trail_paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Trail);
        assert_eq!(trail_paths.len(), 0);
    }

    #[test]
    fn test_pred_dag_trail_allows_node_repeat() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(1), 3, vid(2), 2, eid(3), 3);

        let paths = collect_paths(&dag, vid(0), vid(1), 3, 3, 3, &PathMode::Trail);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2), vid(1)]);
        assert_eq!(paths[0].1, vec![eid(1), eid(2), eid(3)]);
    }

    #[test]
    fn test_pred_dag_trail_diamond() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
        dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
        dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);

        let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Trail);
        assert_eq!(paths.len(), 2);
    }

    #[test]
    fn test_pred_dag_trail_cycle_2_hop() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        // 0 -e1-> 1 -e2-> 0
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);

        let paths = collect_paths(&dag, vid(0), vid(0), 2, 2, 2, &PathMode::Trail);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(0)]);
        assert_eq!(paths[0].1, vec![eid(1), eid(2)]);
    }

    /// Acyclic mode rejects a path that returns to its start vertex.
    #[test]
    fn test_pred_dag_acyclic_filter() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(0), 3, vid(2), 2, eid(3), 3);

        let walk_paths = collect_paths(&dag, vid(0), vid(0), 3, 3, 3, &PathMode::Walk);
        assert_eq!(walk_paths.len(), 1);

        let acyclic_paths = collect_paths(&dag, vid(0), vid(0), 3, 3, 3, &PathMode::Acyclic);
        assert_eq!(acyclic_paths.len(), 0);
    }

    #[test]
    fn test_pred_dag_acyclic_diamond() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 1, vid(0), 0, eid(11), 1);
        dag.add_predecessor(vid(3), 2, vid(1), 1, eid(12), 2);
        dag.add_predecessor(vid(3), 2, vid(2), 1, eid(13), 2);

        let paths = collect_paths(&dag, vid(0), vid(3), 2, 2, 2, &PathMode::Acyclic);
        assert_eq!(paths.len(), 2);
    }

    /// Simple mode permits a closed path that returns to its start vertex.
    #[test]
    fn test_pred_dag_simple_allows_closed_cycle() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        // 0 -> 1 -> 2 -> 0
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(0), 3, vid(2), 2, eid(3), 3);

        let paths = collect_paths(&dag, vid(0), vid(0), 3, 3, 3, &PathMode::Simple);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2), vid(0)]);
        assert_eq!(paths[0].1, vec![eid(1), eid(2), eid(3)]);
    }

    /// Simple mode rejects a repeated intermediate vertex.
    #[test]
    fn test_pred_dag_simple_rejects_inner_repeat() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        // 0 -> 1 -> 2 -> 1 -> 3
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(1), 3, vid(2), 2, eid(3), 3);
        dag.add_predecessor(vid(3), 4, vid(1), 3, eid(4), 4);

        let walk_paths = collect_paths(&dag, vid(0), vid(3), 4, 4, 4, &PathMode::Walk);
        assert_eq!(walk_paths.len(), 1);

        let simple_paths = collect_paths(&dag, vid(0), vid(3), 4, 4, 4, &PathMode::Simple);
        assert!(simple_paths.is_empty());
    }

    #[test]
    fn test_has_trail_valid_true() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(11), 2);

        assert!(dag.has_trail_valid_path(vid(0), vid(2), 2, 2, 2));
    }

    #[test]
    fn test_has_trail_valid_false() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(0), 2, vid(1), 1, eid(2), 2);
        dag.add_predecessor(vid(1), 3, vid(0), 2, eid(1), 3);

        assert!(!dag.has_trail_valid_path(vid(0), vid(1), 3, 3, 3));
    }

    #[test]
    fn test_has_trail_valid_one_of_many() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        // Route A: 0 -> 1 -> 2
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(1), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(2), 2);
        // Route B: 0 -> 3 -> 2
        dag.add_predecessor(vid(3), 1, vid(0), 0, eid(3), 1);
        dag.add_predecessor(vid(2), 2, vid(3), 1, eid(4), 2);

        assert!(dag.has_trail_valid_path(vid(0), vid(2), 2, 2, 2));
    }

    /// Enumeration stops as soon as the callback returns `Break`.
    #[test]
    fn test_pred_dag_early_stop() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        for i in 1..=10u64 {
            dag.add_predecessor(Vid::new(i), 1, vid(0), 0, Eid::new(i), 1);
            dag.add_predecessor(vid(99), 2, Vid::new(i), 1, Eid::new(100 + i), 2);
        }

        let mut count = 0;
        dag.enumerate_paths(
            vid(0),
            vid(99),
            2,
            2,
            2,
            &PathMode::Walk,
            &mut |_nodes, _edges| {
                count += 1;
                if count >= 3 {
                    ControlFlow::Break(())
                } else {
                    ControlFlow::Continue(())
                }
            },
        );
        assert_eq!(count, 3);
    }

    #[test]
    fn test_pred_dag_empty_enumerate() {
        let dag = PredecessorDag::new(PathSelector::All);
        let paths = collect_paths(&dag, vid(0), vid(1), 0, 1, 5, &PathMode::Walk);
        assert!(paths.is_empty());
    }

    #[test]
    fn test_pred_dag_zero_length() {
        let dag = PredecessorDag::new(PathSelector::All);
        let paths = collect_paths(&dag, vid(5), vid(5), 0, 0, 0, &PathMode::Walk);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(5)]);
        assert!(paths[0].1.is_empty());
    }

    /// Yielded paths are in forward (source-to-target) order.
    #[test]
    fn test_pred_dag_path_order() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(10), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(20), 2);
        dag.add_predecessor(vid(3), 3, vid(2), 2, eid(30), 3);

        let paths = collect_paths(&dag, vid(0), vid(3), 3, 3, 3, &PathMode::Walk);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].0, vec![vid(0), vid(1), vid(2), vid(3)]);
        assert_eq!(paths[0].1, vec![eid(10), eid(20), eid(30)]);
    }

    #[test]
    fn test_pred_dag_eid_in_path() {
        let mut dag = PredecessorDag::new(PathSelector::All);
        dag.add_predecessor(vid(1), 1, vid(0), 0, eid(100), 1);
        dag.add_predecessor(vid(2), 2, vid(1), 1, eid(200), 2);
        dag.add_predecessor(vid(3), 3, vid(2), 2, eid(300), 3);

        let paths = collect_paths(&dag, vid(0), vid(3), 3, 3, 3, &PathMode::Walk);
        assert_eq!(paths.len(), 1);

        let (nodes, edges) = &paths[0];
        assert_eq!(nodes.len(), edges.len() + 1);
        assert_eq!(edges[0], eid(100));
        assert_eq!(edges[1], eid(200));
        assert_eq!(edges[2], eid(300));
    }
}