Skip to main content

rustfst/algorithms/compose/
interval_reach_visitor.rs

1use crate::algorithms::compose::{IntInterval, IntervalSet};
2use crate::algorithms::dfs_visit::Visitor;
3use crate::fst_traits::Fst;
4use crate::semirings::Semiring;
5use crate::{StateId, Tr};
6use std::cmp::Ordering;
7
/// Sentinel index meaning "no index assigned": used as the default entry of
/// `state2index` and compared against the visitor's running `index` counter.
///
/// Declared as a `const` rather than a `static`: the value is a plain immutable
/// integer that needs neither a fixed address nor interior mutability.
const UNASSIGNED: usize = usize::MAX;
9
/// DFS visitor that builds one [`IntervalSet`] per visited state.
///
/// States with a non-zero final weight receive an index when they are discovered;
/// as states finish, each state's interval set is merged into the set of the state
/// it was reached from, so a state's set accumulates the sets of the states visited
/// through its transitions.
pub struct IntervalReachVisitor<'a, F> {
    /// FST being traversed; queried for final weights and transition counts.
    fst: &'a F,
    /// Interval set of each state, indexed by state id.
    pub(crate) isets: Vec<IntervalSet>,
    /// Index assigned to each state, indexed by state id; holds `UNASSIGNED`
    /// for states that were not given an index.
    pub(crate) state2index: Vec<usize>,
    /// Next index to hand out to a final state (initialised to 1 by `new`).
    /// The value `UNASSIGNED` switches the visitor to reading indices from
    /// `state2index` instead of assigning new ones.
    index: usize,
}
16
17impl<'a, F> IntervalReachVisitor<'a, F> {
18    pub fn new(fst: &'a F) -> Self {
19        Self {
20            fst,
21            isets: vec![],
22            state2index: vec![],
23            index: 1,
24        }
25    }
26}
27
28impl<'a, W: Semiring, F: Fst<W>> Visitor<'a, W, F> for IntervalReachVisitor<'a, F> {
29    /// Invoked before DFS visit.
30    fn init_visit(&mut self, _fst: &'a F) {}
31
32    /// Invoked when state discovered (2nd arg is DFS tree root).
33    fn init_state(&mut self, s: StateId, _root: StateId) -> bool {
34        while self.isets.len() <= (s as usize) {
35            self.isets.push(IntervalSet::default());
36        }
37        while self.state2index.len() <= (s as usize) {
38            self.state2index.push(UNASSIGNED);
39        }
40        if let Some(final_weight) = self.fst.final_weight(s).unwrap() {
41            if !final_weight.is_zero() {
42                let interval_set = &mut self.isets[s as usize];
43                if self.index == UNASSIGNED {
44                    if self.fst.num_trs(s).unwrap() > 0 {
45                        panic!("IntervalReachVisitor: state2index map must be empty for this FST")
46                    }
47                    let index = self.state2index[s as usize];
48                    if index == UNASSIGNED {
49                        panic!("IntervalReachVisitor: state2index map incomplete")
50                    }
51                    interval_set.push(IntInterval::new(index, index + 1));
52                } else {
53                    interval_set.push(IntInterval::new(self.index, self.index + 1));
54                    self.state2index[s as usize] = self.index;
55                    self.index += 1;
56                }
57            }
58        }
59        true
60    }
61
62    /// Invoked when tree transition to white/undiscovered state examined.
63    fn tree_tr(&mut self, _s: StateId, _tr: &Tr<W>) -> bool {
64        true
65    }
66
67    /// Invoked when back transition to grey/unfinished state examined.
68    fn back_tr(&mut self, _s: StateId, _tr: &Tr<W>) -> bool {
69        panic!("Cyclic input")
70    }
71
72    /// Invoked when forward or cross transition to black/finished state examined.
73    fn forward_or_cross_tr(&mut self, s: StateId, tr: &Tr<W>) -> bool {
74        union_vec_isets_unordered(&mut self.isets, s as usize, tr.nextstate as usize);
75        true
76    }
77
78    /// Invoked when state finished ('s' is tree root, 'parent' is kNoStateId,
79    /// and '_tr' is nullptr).
80    fn finish_state(&mut self, s: StateId, parent: Option<StateId>, _tr: Option<&Tr<W>>) {
81        if self.index != UNASSIGNED
82            && self.fst.is_final(s).unwrap()
83            && !self.fst.final_weight(s).unwrap().unwrap().is_zero()
84        {
85            let intervals = &mut self.isets[s as usize].intervals.intervals;
86            intervals[0].end = self.index;
87        }
88        self.isets[s as usize].normalize();
89        if let Some(p) = parent {
90            union_vec_isets_unordered(&mut self.isets, p as usize, s as usize);
91        }
92    }
93
94    /// Invoked after DFS visit.
95    fn finish_visit(&mut self) {}
96}
97
98// Perform the union of two IntervalSet stored in a vec. Utils to fix issue with borrow checker.
99fn union_vec_isets_unordered(isets: &mut [IntervalSet], i: usize, j: usize) {
100    debug_assert_ne!(i, j);
101    match i.cmp(&j) {
102        Ordering::Less => {
103            let (v_0_isupm1, v_isup1_end) = isets.split_at_mut(j);
104            v_0_isupm1[i].union(v_isup1_end[0].clone());
105        }
106        Ordering::Greater => {
107            let (v_0_jsupm1, v_jsup1_end) = isets.split_at_mut(i);
108            v_jsup1_end[0].union(v_0_jsupm1[j].clone());
109        }
110        Ordering::Equal => {
111            panic!("Unreachable code")
112        }
113    }
114}