rustfst/algorithms/determinize/divisors.rs

1use std::fmt::Debug;
2
3use anyhow::Result;
4
5use crate::semirings::{
6    GallicWeight, GallicWeightLeft, GallicWeightMin, GallicWeightRestrict, StringWeightLeft,
7    StringWeightRestrict,
8};
9use crate::Semiring;
10
11pub trait CommonDivisor<W: Semiring>: PartialEq + Debug + Sync {
12    fn common_divisor(w1: &W, w2: &W) -> Result<W>;
13}
14
15#[derive(PartialEq, Debug)]
16pub struct DefaultCommonDivisor {}
17
18impl<W: Semiring> CommonDivisor<W> for DefaultCommonDivisor {
19    fn common_divisor(w1: &W, w2: &W) -> Result<W> {
20        w1.plus(w2)
21    }
22}
23
24#[derive(PartialEq, Debug)]
25pub struct LabelCommonDivisor {}
26
27macro_rules! impl_label_common_divisor {
28    ($string_semiring: ident) => {
29        impl CommonDivisor<$string_semiring> for LabelCommonDivisor {
30            fn common_divisor(
31                w1: &$string_semiring,
32                w2: &$string_semiring,
33            ) -> Result<$string_semiring> {
34                let mut iter1 = w1.iter();
35                let mut iter2 = w2.iter();
36                if w1.value.is_empty_list() || w2.value.is_empty_list() {
37                    Ok($string_semiring::one())
38                } else if w1.value.is_infinity() {
39                    Ok(iter2.next().unwrap().into())
40                } else if w2.value.is_infinity() {
41                    Ok(iter1.next().unwrap().into())
42                } else {
43                    let v1 = iter1.next().unwrap();
44                    let v2 = iter2.next().unwrap();
45                    if v1 == v2 {
46                        Ok(v1.into())
47                    } else {
48                        Ok($string_semiring::one())
49                    }
50                }
51            }
52        }
53    };
54}
55
56impl_label_common_divisor!(StringWeightLeft);
57impl_label_common_divisor!(StringWeightRestrict);
58
59#[derive(Debug, PartialEq)]
60pub struct GallicCommonDivisor {}
61
62macro_rules! impl_gallic_common_divisor {
63    ($gallic: ident) => {
64        impl<W: Semiring> CommonDivisor<$gallic<W>> for GallicCommonDivisor {
65            fn common_divisor(w1: &$gallic<W>, w2: &$gallic<W>) -> Result<$gallic<W>> {
66                let v1 = LabelCommonDivisor::common_divisor(w1.value1(), w2.value1())?;
67                let v2 = DefaultCommonDivisor::common_divisor(w1.value2(), w2.value2())?;
68                Ok((v1, v2).into())
69            }
70        }
71    };
72}
73
74impl_gallic_common_divisor!(GallicWeightLeft);
75impl_gallic_common_divisor!(GallicWeightRestrict);
76impl_gallic_common_divisor!(GallicWeightMin);
77
78impl<W: Semiring> CommonDivisor<GallicWeight<W>> for GallicCommonDivisor {
79    fn common_divisor(w1: &GallicWeight<W>, w2: &GallicWeight<W>) -> Result<GallicWeight<W>> {
80        let mut weight = GallicWeightRestrict::zero();
81        for w in w1.iter().chain(w2.iter()) {
82            weight = GallicCommonDivisor::common_divisor(&weight, w)?;
83        }
84        if weight.is_zero() {
85            Ok(GallicWeight::zero())
86        } else {
87            Ok(weight.into())
88        }
89    }
90}