rustfst/algorithms/
isomorphic.rs

1use std::cmp::Ordering;
2use std::collections::VecDeque;
3
4use anyhow::Result;
5
6use crate::fst_traits::ExpandedFst;
7use crate::semirings::Semiring;
8use crate::{StateId, Tr, Trs, KDELTA};
9use std::marker::PhantomData;
10
11struct Isomorphism<'a, W: Semiring, F1: ExpandedFst<W>, F2: ExpandedFst<W>> {
12    fst_1: &'a F1,
13    fst_2: &'a F2,
14    state_pairs: Vec<Option<StateId>>,
15    queue: VecDeque<(StateId, StateId)>,
16    w: PhantomData<W>,
17    delta: f32,
18    non_det: bool,
19}
20
21/// Compare trs in the order input label, output label, weight and nextstate.
22pub fn tr_compare<W: Semiring>(tr_1: &Tr<W>, tr_2: &Tr<W>) -> Ordering {
23    if tr_1.ilabel < tr_2.ilabel {
24        return Ordering::Less;
25    }
26    if tr_1.ilabel > tr_2.ilabel {
27        return Ordering::Greater;
28    }
29    if tr_1.olabel < tr_2.olabel {
30        return Ordering::Less;
31    }
32    if tr_1.olabel > tr_2.olabel {
33        return Ordering::Greater;
34    }
35    if tr_1.weight < tr_2.weight {
36        return Ordering::Less;
37    }
38    if tr_1.weight > tr_2.weight {
39        return Ordering::Greater;
40    }
41    if tr_1.nextstate < tr_2.nextstate {
42        return Ordering::Less;
43    }
44    if tr_1.nextstate > tr_2.nextstate {
45        return Ordering::Greater;
46    }
47    Ordering::Equal
48}
49
50impl<'a, W: Semiring, F1: ExpandedFst<W>, F2: ExpandedFst<W>> Isomorphism<'a, W, F1, F2> {
51    fn new(fst_1: &'a F1, fst_2: &'a F2, delta: f32) -> Self {
52        Self {
53            fst_1,
54            fst_2,
55            state_pairs: vec![None; fst_1.num_states()],
56            queue: VecDeque::new(),
57            w: PhantomData,
58            delta,
59            non_det: false,
60        }
61    }
62
63    // Maintains state correspondences and queue.
64    fn pair_state(&mut self, s1: StateId, s2: StateId) -> bool {
65        if self.state_pairs[s1 as usize] == Some(s2) {
66            return true; // already seen this pair
67        } else if self.state_pairs[s1 as usize].is_some() {
68            return false; // s1 already paired with another s2
69        }
70        self.state_pairs[s1 as usize] = Some(s2);
71        self.queue.push_back((s1, s2));
72        true
73    }
74
75    fn ismorphic_state(&mut self, s1: StateId, s2: StateId) -> Result<bool> {
76        let fw1 = self.fst_1.final_weight(s1)?;
77        let fw2 = self.fst_2.final_weight(s2)?;
78        let fw_equal = match (fw1, fw2) {
79            (Some(w1), Some(w2)) => w1.approx_equal(w2, self.delta),
80            (Some(_), None) => false,
81            (None, Some(_)) => false,
82            (None, None) => true,
83        };
84        if !fw_equal {
85            return Ok(false);
86        }
87
88        let ntrs1 = self.fst_1.num_trs(s1)?;
89        let ntrs2 = self.fst_2.num_trs(s2)?;
90
91        if ntrs1 != ntrs2 {
92            return Ok(false);
93        }
94
95        let trs1_owner = self.fst_1.get_trs(s1)?;
96        let mut trs1: Vec<_> = trs1_owner.trs().iter().collect();
97        let trs2_owner = self.fst_2.get_trs(s2)?;
98        let mut trs2: Vec<_> = trs2_owner.trs().iter().collect();
99
100        trs1.sort_by(|a, b| tr_compare(a, b));
101        trs2.sort_by(|a, b| tr_compare(a, b));
102
103        for i in 0..trs1.len() {
104            let arc1 = trs1[i];
105            let arc2 = trs2[i];
106            if arc1.ilabel != arc2.ilabel {
107                return Ok(false);
108            }
109            if arc1.olabel != arc2.olabel {
110                return Ok(false);
111            }
112            if !(arc1.weight.approx_equal(&arc2.weight, self.delta)) {
113                return Ok(false);
114            }
115            if !(self.pair_state(arc1.nextstate, arc2.nextstate)) {
116                return Ok(false);
117            }
118            if i > 0 {
119                let arc0 = trs1[i - 1];
120                if arc1.ilabel == arc0.ilabel
121                    && arc1.olabel == arc0.olabel
122                    && arc1.weight.approx_equal(&arc0.weight, self.delta)
123                {
124                    // Any subsequent matching failure maybe a false negative
125                    // since we only consider one permutation when pairing destination
126                    // states of nondeterministic transitions.
127                    self.non_det = true;
128                }
129            }
130        }
131        Ok(true)
132    }
133
134    fn isomorphic(&mut self) -> Result<bool> {
135        // Both FSTs don't have a start state => both don't recognize anything
136        if self.fst_1.start().is_none() && self.fst_2.start().is_none() {
137            return Ok(true);
138        }
139
140        // Only one FST has a start state => false
141        if self.fst_1.start().is_none() || self.fst_2.start().is_none() {
142            return Ok(false);
143        }
144
145        self.pair_state(self.fst_1.start().unwrap(), self.fst_2.start().unwrap());
146
147        while !self.queue.is_empty() {
148            let (s1, s2) = self.queue.pop_front().unwrap();
149            if !self.ismorphic_state(s1, s2)? {
150                if self.non_det {
151                    bail!("Isomorphic: Non-determinism as an unweighted automaton. state1 = {} state2 = {}", s1, s2)
152                }
153                return Ok(false);
154            }
155        }
156
157        Ok(true)
158    }
159}
160
/// Configuration for isomorphic comparison.
///
/// Holds the tolerance used when two weights are compared for approximate
/// equality.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IsomorphicConfig {
    /// Comparison delta applied to weights.
    delta: f32,
}
165
166impl Default for IsomorphicConfig {
167    fn default() -> Self {
168        Self { delta: KDELTA }
169    }
170}
171
172impl IsomorphicConfig {
173    pub fn new(delta: f32) -> Self {
174        Self { delta }
175    }
176}
177
178/// Determine if two transducers with a certain required determinism
179/// have the same states, irrespective of numbering, and the same transitions with
180/// the same labels and weights, irrespective of ordering.
181///
182/// In other words, Isomorphic(A, B) is true if and only if the states of A can
183/// be renumbered and the transitions leaving each state reordered so that Equal(A, B) is true.
184pub fn isomorphic<W, F1, F2>(fst_1: &F1, fst_2: &F2) -> Result<bool>
185where
186    W: Semiring,
187    F1: ExpandedFst<W>,
188    F2: ExpandedFst<W>,
189{
190    isomorphic_with_config(fst_1, fst_2, IsomorphicConfig::default())
191}
192
193/// Determine, with configurable comparison delta, if two transducers with a
194/// certain required determinism have the same states, irrespective of
195/// numbering, and the same transitions with the same labels and
196/// weights, irrespective of ordering.
197///
198/// In other words, Isomorphic(A, B) is true if and only if the states of A can
199/// be renumbered and the transitions leaving each state reordered so that Equal(A, B) is true.
200pub fn isomorphic_with_config<W, F1, F2>(
201    fst_1: &F1,
202    fst_2: &F2,
203    config: IsomorphicConfig,
204) -> Result<bool>
205where
206    W: Semiring,
207    F1: ExpandedFst<W>,
208    F2: ExpandedFst<W>,
209{
210    let mut iso = Isomorphism::new(fst_1, fst_2, config.delta);
211    iso.isomorphic()
212}
213
#[cfg(test)]
mod test {

    use super::*;

    use crate::fst_impls::VectorFst;
    use crate::fst_traits::{MutableFst, SerializableFst};
    use crate::semirings::{LogWeight, Semiring};
    use crate::Tr;

    /// Parses a `VectorFst<LogWeight>` from its text representation.
    fn parse_fst(text: &str) -> Result<VectorFst<LogWeight>> {
        VectorFst::<LogWeight>::from_text_string(text)
    }

    /// An FST is isomorphic to its clone, but not once a transition is added.
    #[test]
    fn test_isomorphic_1() -> Result<()> {
        let fst_1 = parse_fst(
            "0\t1\t12\t25\n\
             1\n",
        )?;

        let mut fst_2 = fst_1.clone();
        assert!(isomorphic(&fst_1, &fst_2)?);

        let extra_tr = Tr::new(33, 45, LogWeight::new(0.3), 1);
        fst_2.add_tr(0, extra_tr)?;
        assert!(!isomorphic(&fst_1, &fst_2)?);

        Ok(())
    }

    /// Swapping the numbering of the two states keeps the FSTs isomorphic.
    #[test]
    fn test_isomorphic_2() -> Result<()> {
        let fst_1 = parse_fst(
            "0\t1\t12\t25\n\
             1\n",
        )?;
        let fst_2 = parse_fst(
            "1\t0\t12\t25\n\
             0\n",
        )?;

        assert!(isomorphic(&fst_1, &fst_2)?);

        Ok(())
    }
}