rustfst/semirings/
union_weight.rs

1use std::borrow::Borrow;
2use std::cmp::Ordering;
3use std::fmt::Debug;
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::io::Write;
7use std::marker::PhantomData;
8
9use anyhow::Result;
10use nom::branch::alt;
11use nom::bytes::complete::tag;
12use nom::multi::{count, separated_list0};
13use nom::IResult;
14
15use crate::parsers::nom_utils::NomCustomError;
16use crate::parsers::parse_bin_i32;
17use crate::parsers::write_bin_i32;
18use crate::semirings::{
19    DivideType, ReverseBack, Semiring, SemiringProperties, SerializableSemiring,
20    WeaklyDivisibleSemiring, WeightQuantize,
21};
22
/// Policy type that tells a `UnionWeight` how to order and how to combine the
/// weights it stores.
///
/// As used by `UnionWeight` in this file, `compare` decides which of two
/// weights is emitted first when lists are merged, and `merge` is invoked to
/// collapse a new weight into the last stored one when `compare` reports that
/// the last stored weight does not precede it.
pub trait UnionWeightOption<W: Semiring>:
    Debug + Hash + Clone + PartialOrd + Eq + Sync + 'static
{
    /// Options used for the reversed weight type `W::ReverseWeight`; this is
    /// the option parameter of the `ReverseWeight` of a `UnionWeight<W, Self>`.
    type ReverseOptions: UnionWeightOption<W::ReverseWeight>;
    /// Ordering predicate: returns `true` when `w1` should be placed before `w2`.
    // NOTE(review): presumably a strict "less than" relation — confirm with implementors.
    fn compare(w1: &W, w2: &W) -> bool;
    /// Combines two weights into a single one, or returns an error when the
    /// implementation cannot merge them.
    fn merge(w1: &W, w2: &W) -> Result<W>;
}
30
/// Semiring that uses Times() and One() from `W`, and uses union and the empty
/// set for Plus() and Zero(), respectively. The type parameter `O` specifies
/// the union weight options (see `UnionWeightOption`).
#[derive(PartialOrd, PartialEq, Clone, Eq, Debug, Hash, Default)]
pub struct UnionWeight<W: Semiring, O: UnionWeightOption<W>> {
    /// Weights that make up the union; an empty list is the semiring zero.
    pub(crate) list: Vec<W>,
    /// Zero-sized marker tying the option type `O` to this weight.
    ghost: PhantomData<O>,
}
39
40impl<W, O> AsRef<UnionWeight<W, O>> for UnionWeight<W, O>
41where
42    W: Semiring,
43    O: UnionWeightOption<W>,
44{
45    fn as_ref(&self) -> &Self {
46        self
47    }
48}
49
50impl<W: Semiring, O: UnionWeightOption<W>> Semiring for UnionWeight<W, O> {
51    type Type = Vec<W>;
52    type ReverseWeight = UnionWeight<W::ReverseWeight, O::ReverseOptions>;
53
54    fn zero() -> Self {
55        Self {
56            list: vec![],
57            ghost: PhantomData,
58        }
59    }
60
61    fn one() -> Self {
62        Self {
63            list: vec![W::one()],
64            ghost: PhantomData,
65        }
66    }
67
68    fn new(value: Self::Type) -> Self {
69        Self {
70            list: value,
71            ghost: PhantomData,
72        }
73    }
74
75    fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
76        if self.is_zero() {
77            self.set_value(rhs.borrow().value().clone());
78        } else if rhs.borrow().is_zero() {
79            // Nothing
80        } else {
81            let mut sum: UnionWeight<W, O> = UnionWeight::zero();
82            let n1 = self.list.len();
83            let n2 = rhs.borrow().list.len();
84            let mut i1 = 0;
85            let mut i2 = 0;
86            while i1 < n1 && i2 < n2 {
87                let v1 = unsafe { self.list.get_unchecked(i1) };
88                let v2 = unsafe { rhs.borrow().list.get_unchecked(i2) };
89                if O::compare(v1, v2) {
90                    sum.push_back(v1.clone(), true)?;
91                    i1 += 1;
92                } else {
93                    sum.push_back(v2.clone(), true)?;
94                    i2 += 1;
95                }
96            }
97
98            for i in i1..n1 {
99                let v1 = unsafe { self.list.get_unchecked(i) };
100                sum.push_back(v1.clone(), true)?;
101            }
102
103            for i in i2..n2 {
104                let v2 = unsafe { rhs.borrow().list.get_unchecked(i) };
105                sum.push_back(v2.clone(), true)?;
106            }
107            //TODO: Remove this copy and do the modification inplace
108            self.set_value(sum.take_value());
109        }
110        Ok(())
111    }
112
113    fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
114        if self.is_zero() || rhs.borrow().is_zero() {
115            self.set_value(Self::zero().take_value());
116        } else {
117            let mut prod1: UnionWeight<W, O> = UnionWeight::zero();
118            for w1 in self.iter() {
119                let mut prod2: UnionWeight<W, O> = UnionWeight::zero();
120                for w2 in rhs.borrow().iter() {
121                    let p = w1.times(w2)?;
122                    prod2.push_back(p, true)?;
123                }
124                prod1.plus_assign(prod2)?;
125            }
126            self.set_value(prod1.take_value());
127        }
128        Ok(())
129    }
130
131    fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
132        if self.len() != rhs.borrow().len() {
133            return false;
134        }
135        let it1 = self.iter();
136        let it2 = rhs.borrow().iter();
137        for (w1, w2) in it1.zip(it2) {
138            if !w1.approx_equal(w2, delta) {
139                return false;
140            }
141        }
142        true
143    }
144
145    fn value(&self) -> &Self::Type {
146        &self.list
147    }
148
149    fn take_value(self) -> Self::Type {
150        self.list
151    }
152
153    fn set_value(&mut self, value: Self::Type) {
154        self.list = value;
155    }
156
157    fn reverse(&self) -> Result<Self::ReverseWeight> {
158        let mut rw = Self::ReverseWeight::zero();
159        for v in self.iter() {
160            rw.push_back(v.reverse()?, false)?;
161        }
162        rw.list.sort_by(|v1, v2| {
163            if O::ReverseOptions::compare(v1, v2) {
164                Ordering::Less
165            } else {
166                Ordering::Greater
167            }
168        });
169        Ok(rw)
170    }
171
172    fn properties() -> SemiringProperties {
173        W::properties()
174            & (SemiringProperties::LEFT_SEMIRING
175                | SemiringProperties::RIGHT_SEMIRING
176                | SemiringProperties::COMMUTATIVE
177                | SemiringProperties::IDEMPOTENT)
178    }
179}
180
181impl<W: Semiring, O: UnionWeightOption<W>> UnionWeight<W, O> {
182    fn push_back(&mut self, weight: W, sorted: bool) -> Result<()> {
183        if self.list.is_empty() {
184            self.list.push(weight);
185        } else if sorted {
186            let n = self.list.len();
187            let back = &mut self.list[n - 1];
188            if O::compare(back, &weight) {
189                self.list.push(weight);
190            } else {
191                *back = O::merge(back, &weight)?;
192            }
193        } else {
194            let first = &mut self.list[0];
195            if O::compare(first, &weight) {
196                self.list.push(weight);
197            } else {
198                let first_cloned = first.clone();
199                *first = weight;
200                self.list.push(first_cloned);
201            }
202        }
203        Ok(())
204    }
205
206    pub fn len(&self) -> usize {
207        self.list.len()
208    }
209
210    pub fn is_empty(&self) -> bool {
211        self.list.is_empty()
212    }
213
214    pub fn iter(&self) -> impl Iterator<Item = &W> {
215        self.list.iter()
216    }
217}
218
219impl<W, O> WeaklyDivisibleSemiring for UnionWeight<W, O>
220where
221    W: WeaklyDivisibleSemiring,
222    O: UnionWeightOption<W>,
223{
224    fn divide_assign(&mut self, rhs: &Self, divide_type: DivideType) -> Result<()> {
225        if self.is_zero() || rhs.is_zero() {
226            self.list.clear();
227        }
228        let mut quot = Self::zero();
229        if self.len() == 1 {
230            for v in rhs.list.iter().rev() {
231                quot.push_back(self.list[0].divide(v, divide_type)?, true)?;
232            }
233        } else if rhs.len() == 1 {
234            for v in self.list.iter() {
235                quot.push_back(v.divide(&rhs.list[0], divide_type)?, true)?;
236            }
237        } else {
238            bail!("Expected at least of the two parameters to have a single element");
239        }
240        self.set_value(quot.take_value());
241        Ok(())
242    }
243}
244
245impl<W, O> WeightQuantize for UnionWeight<W, O>
246where
247    W: WeightQuantize,
248    O: UnionWeightOption<W>,
249{
250    fn quantize_assign(&mut self, delta: f32) -> Result<()> {
251        let v: Vec<_> = self.list.drain(..).collect();
252        for mut e in v {
253            e.quantize_assign(delta)?;
254            self.push_back(e.quantize(delta)?, true)?;
255        }
256        Ok(())
257    }
258}
259
260impl<W, O> UnionWeight<W, O>
261where
262    W: SerializableSemiring,
263    O: UnionWeightOption<W>,
264{
265    fn parse_text_empty_set(i: &str) -> IResult<&str, Self> {
266        let (i, _) = tag("EmptySet")(i)?;
267        Ok((i, Self::zero()))
268    }
269
270    fn parse_text_non_empty_set(i: &str) -> IResult<&str, Self> {
271        let (i, weights) = separated_list0(tag(","), W::parse_text)(i)?;
272        Ok((i, Self::new(weights)))
273    }
274}
275
276impl<W, O> Display for UnionWeight<W, O>
277where
278    W: SerializableSemiring,
279    O: UnionWeightOption<W>,
280{
281    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
282        if self.is_empty() {
283            write!(f, "EmptySet")?;
284        } else {
285            for (idx, w) in self.list.iter().enumerate() {
286                if idx > 0 {
287                    write!(f, ",")?;
288                }
289                write!(f, "{}", w)?;
290            }
291        }
292        Ok(())
293    }
294}
295
296impl<W, O> SerializableSemiring for UnionWeight<W, O>
297where
298    W: SerializableSemiring,
299    O: UnionWeightOption<W>,
300{
301    fn weight_type() -> String {
302        format!("{}_union", W::weight_type())
303    }
304
305    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
306        let (i, n) = parse_bin_i32(i)?;
307        let (i, labels) = count(W::parse_binary, n as usize)(i)?;
308        Ok((i, Self::new(labels)))
309    }
310
311    fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
312        write_bin_i32(file, self.list.len() as i32)?;
313        for w in self.list.iter() {
314            w.write_binary(file)?;
315        }
316        Ok(())
317    }
318
319    fn parse_text(i: &str) -> IResult<&str, Self> {
320        let (i, res) = alt((Self::parse_text_empty_set, Self::parse_text_non_empty_set))(i)?;
321        Ok((i, res))
322    }
323}
324
/// Conversion from the reversed union weight back to `UnionWeight<W, O>`.
impl<W: Semiring, O: UnionWeightOption<W>> ReverseBack<UnionWeight<W, O>>
    for <UnionWeight<W, O> as Semiring>::ReverseWeight
{
    /// Reverses `self` once more and reinterprets the result as a
    /// `UnionWeight<W, O>`.
    ///
    /// # Errors
    /// Propagates any error returned by `reverse`.
    fn reverse_back(&self) -> Result<UnionWeight<W, O>> {
        // `reverse` yields a `UnionWeight` parameterized by the reverse of
        // `W::ReverseWeight` and the reverse of `O::ReverseOptions`, which is
        // not nominally `UnionWeight<W, O>`.
        let res = self.reverse()?;
        // TODO: Find a way to avoid this transmute. For the moment, it is necessary because of the compare function.
        // SAFETY: NOTE(review) — presumably relies on the doubly-reversed weight
        // type being the same type as (or layout-identical to) `W`; nothing in
        // this block enforces it. Confirm for every `W` used with `UnionWeight`.
        unsafe { Ok(std::mem::transmute(res)) }
    }
}
333}