rustfst/semirings/
probability_weight.rs1use std::borrow::Borrow;
2use std::f32;
3use std::hash::{Hash, Hasher};
4use std::io::Write;
5
6use anyhow::Result;
7use nom::number::complete::float;
8use nom::IResult;
9use ordered_float::OrderedFloat;
10
11use crate::parsers::nom_utils::NomCustomError;
12use crate::parsers::parse_bin_f32;
13use crate::parsers::write_bin_f32;
14use crate::semirings::utils_float::float_approx_equal;
15use crate::semirings::{
16 CompleteSemiring, DivideType, ReverseBack, Semiring, SemiringProperties, SerializableSemiring,
17 StarSemiring, WeaklyDivisibleSemiring, WeightQuantize,
18};
19use crate::KDELTA;
20
21#[derive(Clone, Debug, PartialOrd, Default, Copy, Eq)]
23pub struct ProbabilityWeight {
24 value: OrderedFloat<f32>,
25}
26
27impl Semiring for ProbabilityWeight {
28 type Type = f32;
29 type ReverseWeight = ProbabilityWeight;
30
31 fn zero() -> Self {
32 Self {
33 value: OrderedFloat(0.0),
34 }
35 }
36 fn one() -> Self {
37 Self {
38 value: OrderedFloat(1.0),
39 }
40 }
41
42 fn new(value: <Self as Semiring>::Type) -> Self {
43 ProbabilityWeight {
44 value: OrderedFloat(value),
45 }
46 }
47
48 fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
49 self.value.0 += rhs.borrow().value.0;
50 Ok(())
51 }
52
53 fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
54 self.value.0 *= rhs.borrow().value.0;
55 Ok(())
56 }
57
58 fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
59 float_approx_equal(self.value.0, rhs.borrow().value.0, delta)
60 }
61
62 fn value(&self) -> &Self::Type {
63 self.value.as_ref()
64 }
65
66 fn take_value(self) -> Self::Type {
67 self.value.into_inner()
68 }
69
70 fn set_value(&mut self, value: <Self as Semiring>::Type) {
71 self.value.0 = value
72 }
73
74 fn reverse(&self) -> Result<Self::ReverseWeight> {
75 Ok(*self)
76 }
77
78 fn properties() -> SemiringProperties {
79 SemiringProperties::LEFT_SEMIRING
80 | SemiringProperties::RIGHT_SEMIRING
81 | SemiringProperties::COMMUTATIVE
82 }
83}
84
85impl ReverseBack<ProbabilityWeight> for ProbabilityWeight {
86 fn reverse_back(&self) -> Result<ProbabilityWeight> {
87 unimplemented!()
88 }
89}
90
91impl AsRef<ProbabilityWeight> for ProbabilityWeight {
92 fn as_ref(&self) -> &ProbabilityWeight {
93 self
94 }
95}
96
// Crate macro; judging by its name it presumably implements `Display` for the weight.
// TODO confirm against the macro definition.
display_semiring!(ProbabilityWeight);

// Empty impl: nothing is overridden here, so this relies on `CompleteSemiring` having only
// provided (default) items, if any. TODO confirm against the trait definition.
impl CompleteSemiring for ProbabilityWeight {}
100
101impl SerializableSemiring for ProbabilityWeight {
102 fn weight_type() -> String {
103 "probability".to_string()
104 }
105
106 fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
107 let (i, weight) = parse_bin_f32(i)?;
108 Ok((i, Self::new(weight)))
109 }
110
111 fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
112 write_bin_f32(file, *self.value())
113 }
114
115 fn parse_text(i: &str) -> IResult<&str, Self> {
116 let (i, f) = float(i)?;
117 Ok((i, Self::new(f)))
118 }
119}
120
121impl StarSemiring for ProbabilityWeight {
122 fn closure(&self) -> Self {
123 Self::new(1.0 / (1.0 - self.value.0))
124 }
125}
126
127impl WeaklyDivisibleSemiring for ProbabilityWeight {
128 fn divide_assign(&mut self, rhs: &Self, _divide_type: DivideType) -> Result<()> {
129 if rhs.value.0 == 0.0 {
131 bail!("Division by 0")
132 }
133 self.value.0 /= rhs.value.0;
134 Ok(())
135 }
136}
137
// Crate macro. Presumably implements `WeightQuantize` (imported above, but not otherwise
// implemented in this file) for an f32-backed weight. TODO confirm against the macro.
impl_quantize_f32!(ProbabilityWeight);

// Crate macro. The struct derives `Eq` but not `PartialEq`, so this presumably supplies
// `PartialEq` and `Hash`. TODO confirm how float equality is defined there.
partial_eq_and_hash_f32!(ProbabilityWeight);
141
// Generates serialization tests named `tests_probability_weight_serializable` for
// `ProbabilityWeight`. It presumably round-trips each listed weight, including the 0.0 and
// 1.0 boundaries. TODO confirm against the macro definition. Its argument grammar is not
// visible here, so the weight list is kept verbatim.
test_semiring_serializable!(
    tests_probability_weight_serializable,
    ProbabilityWeight,
    ProbabilityWeight::one() ProbabilityWeight::zero() ProbabilityWeight::new(0.3) ProbabilityWeight::new(0.5) ProbabilityWeight::new(0.0) ProbabilityWeight::new(1.0)
);
147
148impl From<f32> for ProbabilityWeight {
149 fn from(f: f32) -> Self {
150 Self::new(f)
151 }
152}