Skip to main content

rustfst/semirings/
log_weight.rs

1use std::borrow::Borrow;
2use std::f32;
3use std::hash::{Hash, Hasher};
4use std::io::Write;
5
6use anyhow::Result;
7use nom::branch::alt;
8use nom::bytes::complete::tag_no_case;
9use nom::combinator::map;
10use nom::number::complete::float;
11use nom::IResult;
12use ordered_float::OrderedFloat;
13
14use crate::parsers::nom_utils::NomCustomError;
15use crate::parsers::parse_bin_f32;
16use crate::parsers::write_bin_f32;
17use crate::semirings::utils_float::float_approx_equal;
18use crate::semirings::{
19    CompleteSemiring, DivideType, ReverseBack, Semiring, SemiringProperties, SerializableSemiring,
20    StarSemiring, WeaklyDivisibleSemiring, WeightQuantize,
21};
22use crate::KDELTA;
23
24/// Log semiring: (log(e^-x + e^-y), +, inf, 0).
25#[derive(Clone, Debug, PartialOrd, Default, Copy, Eq)]
26pub struct LogWeight {
27    value: OrderedFloat<f32>,
28}
29
/// Returns `ln(1 + e^-x)`, the correction term used by the log-semiring sum.
///
/// Going through `ln_1p` keeps precision when `e^-x` is tiny (large `x`).
fn ln_pos_exp(x: f32) -> f32 {
    let exp_neg = (-x).exp();
    exp_neg.ln_1p()
}
33
34impl Semiring for LogWeight {
35    type Type = f32;
36    type ReverseWeight = LogWeight;
37
38    fn zero() -> Self {
39        Self {
40            value: OrderedFloat(f32::INFINITY),
41        }
42    }
43    fn one() -> Self {
44        Self {
45            value: OrderedFloat(0.0),
46        }
47    }
48
49    fn new(value: <Self as Semiring>::Type) -> Self {
50        LogWeight {
51            value: OrderedFloat(value),
52        }
53    }
54
55    fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
56        let f1 = self.value();
57        let f2 = rhs.borrow().value();
58        self.value.0 = if f1.eq(&f32::INFINITY) {
59            *f2
60        } else if f2.eq(&f32::INFINITY) {
61            *f1
62        } else if f1 > f2 {
63            f2 - ln_pos_exp(f1 - f2)
64        } else {
65            f1 - ln_pos_exp(f2 - f1)
66        };
67        Ok(())
68    }
69
70    fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
71        let f1 = self.value();
72        let f2 = rhs.borrow().value();
73        if f1.eq(&f32::INFINITY) {
74        } else if f2.eq(&f32::INFINITY) {
75            self.value.0 = *f2;
76        } else {
77            self.value.0 += f2;
78        }
79        Ok(())
80    }
81
82    fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
83        float_approx_equal(self.value.0, rhs.borrow().value.0, delta)
84    }
85
86    fn value(&self) -> &Self::Type {
87        self.value.as_ref()
88    }
89
90    fn take_value(self) -> Self::Type {
91        self.value.into_inner()
92    }
93
94    fn set_value(&mut self, value: <Self as Semiring>::Type) {
95        self.value.0 = value
96    }
97
98    fn reverse(&self) -> Result<Self::ReverseWeight> {
99        Ok(*self)
100    }
101
102    fn properties() -> SemiringProperties {
103        SemiringProperties::LEFT_SEMIRING
104            | SemiringProperties::RIGHT_SEMIRING
105            | SemiringProperties::COMMUTATIVE
106    }
107}
108
109impl ReverseBack<LogWeight> for LogWeight {
110    fn reverse_back(&self) -> Result<LogWeight> {
111        Ok(*self)
112    }
113}
114
115impl AsRef<LogWeight> for LogWeight {
116    fn as_ref(&self) -> &LogWeight {
117        self
118    }
119}
120
121display_semiring!(LogWeight);
122
123impl CompleteSemiring for LogWeight {}
124
125impl StarSemiring for LogWeight {
126    fn closure(&self) -> Self {
127        if self.value.0 >= 0.0 && self.value.0 < 1.0 {
128            Self::new((1.0 - self.value.0).ln())
129        } else {
130            Self::new(f32::NEG_INFINITY)
131        }
132    }
133}
134
135impl WeaklyDivisibleSemiring for LogWeight {
136    fn divide_assign(&mut self, rhs: &Self, _divide_type: DivideType) -> Result<()> {
137        self.value.0 -= rhs.value.0;
138        Ok(())
139    }
140}
141
142impl_quantize_f32!(LogWeight);
143
144partial_eq_and_hash_f32!(LogWeight);
145
146impl SerializableSemiring for LogWeight {
147    fn weight_type() -> String {
148        "log".to_string()
149    }
150
151    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
152        let (i, weight) = parse_bin_f32(i)?;
153        Ok((i, Self::new(weight)))
154    }
155
156    fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
157        write_bin_f32(file, *self.value())
158    }
159
160    fn parse_text(i: &str) -> IResult<&str, Self> {
161        // FIXME: nom 7 does not fully parse "infinity", therefore it is done manually here until
162        // the PR https://github.com/rust-bakery/nom/pull/1673 is merged.
163        let (i, f) = alt((map(tag_no_case("infinity"), |_| f32::INFINITY), float))(i)?;
164        Ok((i, Self::new(f)))
165    }
166}
167
168test_semiring_serializable!(
169    tests_log_weight_serializable,
170    LogWeight,
171    LogWeight::new(0.3) LogWeight::new(0.5) LogWeight::new(0.0) LogWeight::new(-1.2)
172);
173
174impl From<f32> for LogWeight {
175    fn from(f: f32) -> Self {
176        LogWeight::new(f)
177    }
178}