rustfst/semirings/probability_weight.rs

1use std::borrow::Borrow;
2use std::f32;
3use std::hash::{Hash, Hasher};
4use std::io::Write;
5
6use anyhow::Result;
7use nom::number::complete::float;
8use nom::IResult;
9use ordered_float::OrderedFloat;
10
11use crate::parsers::nom_utils::NomCustomError;
12use crate::parsers::parse_bin_f32;
13use crate::parsers::write_bin_f32;
14use crate::semirings::utils_float::float_approx_equal;
15use crate::semirings::{
16    CompleteSemiring, DivideType, ReverseBack, Semiring, SemiringProperties, SerializableSemiring,
17    StarSemiring, WeaklyDivisibleSemiring, WeightQuantize,
18};
19use crate::KDELTA;
20
21/// Probability semiring: (x, +, 0.0, 1.0).
22#[derive(Clone, Debug, PartialOrd, Default, Copy, Eq)]
23pub struct ProbabilityWeight {
24    value: OrderedFloat<f32>,
25}
26
27impl Semiring for ProbabilityWeight {
28    type Type = f32;
29    type ReverseWeight = ProbabilityWeight;
30
31    fn zero() -> Self {
32        Self {
33            value: OrderedFloat(0.0),
34        }
35    }
36    fn one() -> Self {
37        Self {
38            value: OrderedFloat(1.0),
39        }
40    }
41
42    fn new(value: <Self as Semiring>::Type) -> Self {
43        ProbabilityWeight {
44            value: OrderedFloat(value),
45        }
46    }
47
48    fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
49        self.value.0 += rhs.borrow().value.0;
50        Ok(())
51    }
52
53    fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
54        self.value.0 *= rhs.borrow().value.0;
55        Ok(())
56    }
57
58    fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
59        float_approx_equal(self.value.0, rhs.borrow().value.0, delta)
60    }
61
62    fn value(&self) -> &Self::Type {
63        self.value.as_ref()
64    }
65
66    fn take_value(self) -> Self::Type {
67        self.value.into_inner()
68    }
69
70    fn set_value(&mut self, value: <Self as Semiring>::Type) {
71        self.value.0 = value
72    }
73
74    fn reverse(&self) -> Result<Self::ReverseWeight> {
75        Ok(*self)
76    }
77
78    fn properties() -> SemiringProperties {
79        SemiringProperties::LEFT_SEMIRING
80            | SemiringProperties::RIGHT_SEMIRING
81            | SemiringProperties::COMMUTATIVE
82    }
83}
84
85impl ReverseBack<ProbabilityWeight> for ProbabilityWeight {
86    fn reverse_back(&self) -> Result<ProbabilityWeight> {
87        unimplemented!()
88    }
89}
90
91impl AsRef<ProbabilityWeight> for ProbabilityWeight {
92    fn as_ref(&self) -> &ProbabilityWeight {
93        self
94    }
95}
96
97display_semiring!(ProbabilityWeight);
98
99impl CompleteSemiring for ProbabilityWeight {}
100
101impl SerializableSemiring for ProbabilityWeight {
102    fn weight_type() -> String {
103        "probability".to_string()
104    }
105
106    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
107        let (i, weight) = parse_bin_f32(i)?;
108        Ok((i, Self::new(weight)))
109    }
110
111    fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
112        write_bin_f32(file, *self.value())
113    }
114
115    fn parse_text(i: &str) -> IResult<&str, Self> {
116        let (i, f) = float(i)?;
117        Ok((i, Self::new(f)))
118    }
119}
120
121impl StarSemiring for ProbabilityWeight {
122    fn closure(&self) -> Self {
123        Self::new(1.0 / (1.0 - self.value.0))
124    }
125}
126
127impl WeaklyDivisibleSemiring for ProbabilityWeight {
128    fn divide_assign(&mut self, rhs: &Self, _divide_type: DivideType) -> Result<()> {
129        // May panic if rhs.value == 0.0
130        if rhs.value.0 == 0.0 {
131            bail!("Division by 0")
132        }
133        self.value.0 /= rhs.value.0;
134        Ok(())
135    }
136}
137
138impl_quantize_f32!(ProbabilityWeight);
139
140partial_eq_and_hash_f32!(ProbabilityWeight);
141
142test_semiring_serializable!(
143    tests_probability_weight_serializable,
144    ProbabilityWeight,
145    ProbabilityWeight::one() ProbabilityWeight::zero() ProbabilityWeight::new(0.3) ProbabilityWeight::new(0.5) ProbabilityWeight::new(0.0) ProbabilityWeight::new(1.0)
146);
147
148impl From<f32> for ProbabilityWeight {
149    fn from(f: f32) -> Self {
150        Self::new(f)
151    }
152}