rustfst/
fst_path.rs

1use std::collections::VecDeque;
2use std::hash::{Hash, Hasher};
3
4use anyhow::Result;
5
6use crate::fst_traits::Fst;
7use crate::semirings::Semiring;
8use crate::{Label, StateId, Trs, EPS_LABEL};
9
10/// Structure representing a path in a FST
11/// (list of input labels, list of output labels and total weight).
12#[derive(PartialEq, Debug, Clone, PartialOrd)]
13pub struct FstPath<W: Semiring> {
14    /// List of input labels.
15    pub ilabels: Vec<Label>,
16    /// List of output labels.
17    pub olabels: Vec<Label>,
18    /// Total weight of the path computed by multiplying the weight of each transition.
19    pub weight: W,
20}
21
22impl<W: Semiring> FstPath<W> {
23    /// Creates a new Path.
24    pub fn new(ilabels: Vec<Label>, olabels: Vec<Label>, weight: W) -> Self {
25        FstPath {
26            ilabels,
27            olabels,
28            weight,
29        }
30    }
31
32    /// Adds the content of an FST transition to the Path.
33    /// Labels are added at the end of the corresponding vectors and the weight
34    /// is multiplied by the total weight already stored in the Path.
35    pub fn add_to_path(&mut self, ilabel: Label, olabel: Label, weight: &W) -> Result<()> {
36        if ilabel != EPS_LABEL {
37            self.ilabels.push(ilabel);
38        }
39
40        if olabel != EPS_LABEL {
41            self.olabels.push(olabel);
42        }
43
44        self.weight.times_assign(weight)
45    }
46
47    /// Add a single weight to the Path by multiplying the weight by the total weight of the path.
48    pub fn add_weight(&mut self, weight: &W) -> Result<()> {
49        self.weight.times_assign(weight)
50    }
51
52    /// Append a Path to the current Path. Labels are appended and weights multiplied.
53    pub fn concat(&mut self, other: FstPath<W>) -> Result<()> {
54        self.ilabels.extend(other.ilabels);
55        self.olabels.extend(other.olabels);
56        self.weight.times_assign(other.weight)
57    }
58
59    pub fn is_empty(&self) -> bool {
60        self.ilabels.is_empty() && self.olabels.is_empty() && self.weight.is_one()
61    }
62}
63
64impl<W: Semiring> Default for FstPath<W> {
65    /// Creates an empty path with a weight one.
66    fn default() -> Self {
67        FstPath {
68            ilabels: vec![],
69            olabels: vec![],
70            weight: W::one(),
71        }
72    }
73}
74
75#[allow(clippy::derived_hash_with_manual_eq)]
76impl<W: Semiring + Hash + Eq> Hash for FstPath<W> {
77    fn hash<H: Hasher>(&self, state: &mut H) {
78        self.ilabels.hash(state);
79        self.olabels.hash(state);
80        self.weight.hash(state);
81    }
82}
83
84impl<W: Semiring + Hash + Eq> Eq for FstPath<W> {}
85
86struct BfsState<W: Semiring> {
87    state: StateId,
88    weight_curr: W,
89    next_ilabel_idx: StateId,
90    next_olabel_idx: StateId,
91}
92
93/// Check if a FstPath can be generated by the given Fst.
94///
95/// Be careful with this function, it will loop forever if the path is not present
96/// in the Fst and the Fst is cyclic.
97pub fn check_path_in_fst<W: Semiring, F: Fst<W>>(fst: &F, fst_path: &FstPath<W>) -> bool {
98    if let Some(start) = fst.start() {
99        let mut queue = VecDeque::new();
100        queue.push_back(BfsState {
101            state: start,
102            next_ilabel_idx: 0,
103            next_olabel_idx: 0,
104            weight_curr: W::one(),
105        });
106
107        while !queue.is_empty() {
108            let lol = queue.pop_front().unwrap();
109            let state = lol.state;
110            let next_ilabel_idx = lol.next_ilabel_idx as usize;
111            let next_olabel_idx = lol.next_olabel_idx as usize;
112            let weight_curr = lol.weight_curr;
113
114            if next_ilabel_idx >= fst_path.ilabels.len()
115                && next_olabel_idx >= fst_path.olabels.len()
116            {
117                // No more labels left
118                if let Some(final_weight) = unsafe { fst.final_weight_unchecked(state) } {
119                    if weight_curr.times(final_weight).unwrap() == fst_path.weight {
120                        return true;
121                    }
122                }
123            }
124
125            for tr in unsafe { fst.get_trs_unchecked(state) }.trs() {
126                let match_ilabel = next_ilabel_idx < fst_path.ilabels.len()
127                    && tr.ilabel == fst_path.ilabels[next_ilabel_idx];
128                let match_olabel = next_olabel_idx < fst_path.olabels.len()
129                    && tr.olabel == fst_path.olabels[next_ilabel_idx];
130                let (new_next_ilabel_idx, new_next_olabel_idx) =
131                    if tr.ilabel == EPS_LABEL && tr.olabel == EPS_LABEL {
132                        (next_ilabel_idx, next_olabel_idx)
133                    } else if tr.ilabel != EPS_LABEL && tr.olabel == EPS_LABEL {
134                        if match_ilabel {
135                            (next_ilabel_idx + 1, next_olabel_idx)
136                        } else {
137                            continue;
138                        }
139                    } else if tr.ilabel == EPS_LABEL && tr.olabel != EPS_LABEL {
140                        if match_olabel {
141                            (next_ilabel_idx, next_olabel_idx + 1)
142                        } else {
143                            continue;
144                        }
145                    } else if match_ilabel && match_olabel {
146                        (next_ilabel_idx + 1, next_olabel_idx + 1)
147                    } else {
148                        continue;
149                    };
150                queue.push_back(BfsState {
151                    state: tr.nextstate,
152                    next_ilabel_idx: new_next_ilabel_idx as Label,
153                    next_olabel_idx: new_next_olabel_idx as Label,
154                    weight_curr: weight_curr.times(&tr.weight).unwrap(),
155                })
156            }
157        }
158
159        false
160    } else {
161        fst_path.is_empty()
162    }
163}
164
/// Creates a Path containing the arguments.
///
/// The macro refers to `FstPath` and `Semiring` through fully qualified
/// `$crate::` paths, so only the macro itself needs to be in scope at the call
/// site. In the acceptor forms the label expressions are evaluated once, and
/// the resulting labels are used for both the input and the output side.
///
/// This macro has four forms:
///
/// - Create an unweighted acceptor path:
///
/// ```
/// # #[macro_use] extern crate rustfst; fn main() {
/// # use rustfst::semirings::{IntegerWeight, Semiring};
/// # use rustfst::FstPath;
/// let path : FstPath<IntegerWeight> = fst_path![1,2,3];
/// assert_eq!(path.ilabels, vec![1,2,3]);
/// assert_eq!(path.olabels, vec![1,2,3]);
/// assert_eq!(path.weight, IntegerWeight::one());
/// # }
/// ```
///
/// - Create an unweighted transducer path:
///
/// ```
/// # #[macro_use] extern crate rustfst; fn main() {
/// # use rustfst::semirings::{IntegerWeight, Semiring};
/// # use rustfst::FstPath;
/// let path : FstPath<IntegerWeight> = fst_path![1,2,3 => 1,2,4];
/// assert_eq!(path.ilabels, vec![1,2,3]);
/// assert_eq!(path.olabels, vec![1,2,4]);
/// assert_eq!(path.weight, IntegerWeight::one());
/// # }
/// ```
///
/// - Create a weighted acceptor path:
///
/// ```
/// # #[macro_use] extern crate rustfst; fn main() {
/// # use rustfst::semirings::{IntegerWeight, Semiring};
/// # use rustfst::FstPath;
/// let path : FstPath<IntegerWeight> = fst_path![1,2,3; 18];
/// assert_eq!(path.ilabels, vec![1,2,3]);
/// assert_eq!(path.olabels, vec![1,2,3]);
/// assert_eq!(path.weight, IntegerWeight::new(18));
/// # }
/// ```
///
/// - Create a weighted transducer path:
///
/// ```
/// # #[macro_use] extern crate rustfst; fn main() {
/// # use rustfst::semirings::{IntegerWeight, Semiring};
/// # use rustfst::FstPath;
/// let path : FstPath<IntegerWeight> = fst_path![1,2,3 => 1,2,4; 18];
/// assert_eq!(path.ilabels, vec![1,2,3]);
/// assert_eq!(path.olabels, vec![1,2,4]);
/// assert_eq!(path.weight, IntegerWeight::new(18));
/// # }
/// ```
///
#[macro_export]
macro_rules! fst_path {
    ( $( $x:expr ),*) => {
        {
            fn semiring_one<W: $crate::semirings::Semiring>() -> W {
                W::one()
            }
            // Evaluate the label expressions once and share them between the
            // input and output sides of the acceptor path.
            let labels = vec![$($x),*];
            $crate::FstPath::new(labels.clone(), labels, semiring_one())
        }
    };
    ( $( $x:expr ),* => $( $y:expr ),* ) => {
        {
            fn semiring_one<W: $crate::semirings::Semiring>() -> W {
                W::one()
            }
            $crate::FstPath::new(
                vec![$($x),*],
                vec![$($y),*],
                semiring_one()
            )
        }
    };
    ( $( $x:expr ),* ; $weight:expr) => {
        {
            fn semiring_new<W: $crate::semirings::Semiring>(v: W::Type) -> W {
                W::new(v)
            }
            // Evaluate the label expressions once and share them between the
            // input and output sides of the acceptor path.
            let labels = vec![$($x),*];
            $crate::FstPath::new(labels.clone(), labels, semiring_new($weight))
        }
    };
    ( $( $x:expr ),* => $( $y:expr ),* ; $weight:expr) => {
        {
            fn semiring_new<W: $crate::semirings::Semiring>(v: W::Type) -> W {
                W::new(v)
            }
            $crate::FstPath::new(
                vec![$($x),*],
                vec![$($y),*],
                semiring_new($weight)
            )
        }
    };
}
272
#[cfg(test)]
mod test {
    use super::*;
    use crate::fst_impls::VectorFst;
    use crate::fst_traits::MutableFst;
    use crate::semirings::TropicalWeight;

    #[test]
    fn test_check_path_in_fst() -> Result<()> {
        // Transitions are (source, ilabel, olabel, weight, nextstate):
        //   0 --1:2/1.2--> 1 --2:3/0.3--> 2
        //   0 --4:6/1.1--> 1 --6:7/0.5--> 2
        //   0 --10:12/3.0---------------> 2
        // State 2 is final with weight 3.2.
        let mut fst = VectorFst::<TropicalWeight>::new();
        fst.add_states(3);
        fst.set_start(0)?;
        fst.emplace_tr(0, 1, 2, 1.2, 1)?;
        fst.emplace_tr(0, 4, 6, 1.1, 1)?;
        fst.emplace_tr(1, 2, 3, 0.3, 2)?;
        fst.emplace_tr(1, 6, 7, 0.5, 2)?;
        fst.emplace_tr(0, 10, 12, 3.0, 2)?;
        fst.set_final(2, 3.2)?;

        // Each case is (input labels, output labels, total weight, expected).
        let cases = vec![
            (vec![], vec![], TropicalWeight::one(), false),
            (vec![1], vec![2], TropicalWeight::new(1.2), false),
            (vec![1, 2], vec![2, 3], TropicalWeight::new(1.5), false),
            (vec![1, 2], vec![2, 3], TropicalWeight::new(4.7), true),
            (vec![10], vec![10], TropicalWeight::new(3.0), false),
            (vec![12], vec![12], TropicalWeight::new(6.2), false),
            (vec![10], vec![10], TropicalWeight::new(6.2), false),
            (vec![10], vec![12], TropicalWeight::new(6.2), true),
        ];

        for (case_idx, (ilabels, olabels, weight, expected)) in cases.into_iter().enumerate() {
            let path = FstPath::new(ilabels, olabels, weight);
            assert_eq!(
                check_path_in_fst(&fst, &path),
                expected,
                "unexpected result for case {}",
                case_idx
            );
        }

        Ok(())
    }
}