rustfst/semirings/gallic_weight.rs

1use std::borrow::Borrow;
2use std::cmp::Ordering;
3use std::fmt::{Display, Formatter};
4use std::io::Write;
5use std::marker::PhantomData;
6
7use anyhow::Result;
8use nom::IResult;
9
10use crate::parsers::nom_utils::NomCustomError;
11use crate::semirings::Semiring;
12#[cfg(test)]
13use crate::semirings::TropicalWeight;
14use crate::semirings::{
15    DivideType, SemiringProperties, SerializableSemiring, StringWeightLeft, StringWeightRestrict,
16    StringWeightRight, UnionWeight, UnionWeightOption, WeaklyDivisibleSemiring, WeightQuantize,
17};
18use crate::semirings::{ProductWeight, ReverseBack};
19use crate::Label;
20
21/// Product of StringWeightLeft and an arbitrary weight.
22#[derive(PartialOrd, PartialEq, Eq, Clone, Hash, Debug)]
23pub struct GallicWeightLeft<W>(ProductWeight<StringWeightLeft, W>)
24where
25    W: Semiring;
26
27/// Product of StringWeightRight and an arbitrary weight.
28#[derive(PartialOrd, PartialEq, Eq, Clone, Hash, Debug)]
29pub struct GallicWeightRight<W>(ProductWeight<StringWeightRight, W>)
30where
31    W: Semiring;
32
33/// Product of StringWeighRestrict and an arbitrary weight.
34#[derive(PartialOrd, PartialEq, Eq, Clone, Hash, Debug)]
35pub struct GallicWeightRestrict<W>(ProductWeight<StringWeightRestrict, W>)
36where
37    W: Semiring;
38
39/// Product of StringWeightRestrict and an arbitrary weight.
40#[derive(PartialOrd, PartialEq, Eq, Clone, Hash, Debug)]
41pub struct GallicWeightMin<W>(ProductWeight<StringWeightRestrict, W>)
42where
43    W: Semiring;
44
45fn natural_less<W: Semiring>(w1: &W, w2: &W) -> Result<bool> {
46    Ok((&w1.plus(w2)? == w1) && (w1 != w2))
47}
48
/// Selects which Gallic weight variant an expansion of `gallic_weight!` implements.
///
/// The variant decides how `plus` combines two weights. It also decides the name written
/// by `weight_type` (`left_gallic`, `right_gallic`, `restricted_gallic`, `min_gallic`).
// Every variant intentionally repeats the `Gallic` prefix, mirroring the type names
// `GallicWeightLeft`, `GallicWeightRight`, ... . Clippy's `enum_variant_names` lint would
// otherwise flag this, so it is silenced here.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GallicType {
    /// String component is a left string weight; `plus` is the product-weight `plus`.
    GallicLeft,
    /// String component is a right string weight; `plus` is the product-weight `plus`.
    GallicRight,
    /// String component is a restricted string weight; `plus` is the product-weight `plus`.
    GallicRestrict,
    /// Restricted string component; `plus` keeps the operand whose second component is
    /// naturally less.
    GallicMin,
}
56
/// Generates the trait impls shared by the four single-pair Gallic weight newtypes.
///
/// * `$semiring` - the newtype to implement, e.g. `GallicWeightLeft<W>`. It must wrap a
///   `ProductWeight<$string_weight, W>` in field `.0`.
/// * `$string_weight` - the string weight used as the first product component.
/// * `$gallic_type` - a `GallicType` variant that selects the `plus` rule and the
///   serialized `weight_type` name.
/// * `$reverse_semiring` - the type used as `Semiring::ReverseWeight`.
macro_rules! gallic_weight {
    ($semiring: ty, $string_weight: ty, $gallic_type: expr, $reverse_semiring: ty) => {
        impl<W> std::fmt::Display for $semiring
        where
            W: SerializableSemiring,
        {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Name `Display` explicitly so the trait chosen does not depend on which
                // formatting traits are imported where the macro is expanded.
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        impl<W> AsRef<$semiring> for $semiring
        where
            W: Semiring,
        {
            fn as_ref(&self) -> &Self {
                self
            }
        }

        impl<W: Semiring> ReverseBack<$semiring> for <$semiring as Semiring>::ReverseWeight {
            fn reverse_back(&self) -> Result<$semiring> {
                Ok(<$semiring>::new(self.0.reverse_back()?))
            }
        }

        impl<W> Semiring for $semiring
        where
            W: Semiring,
        {
            type Type = ProductWeight<$string_weight, W>;
            type ReverseWeight = $reverse_semiring;

            fn zero() -> Self {
                Self(ProductWeight::zero())
            }

            fn one() -> Self {
                Self(ProductWeight::one())
            }

            fn new(value: Self::Type) -> Self {
                Self(value)
            }

            fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
                let rhs: &Self = rhs.borrow();
                match $gallic_type {
                    // Left, right and restricted variants use the product-weight `plus`.
                    GallicType::GallicLeft
                    | GallicType::GallicRight
                    | GallicType::GallicRestrict => self.0.plus_assign(&rhs.0)?,
                    // Min variant: keep `self` only if its second component is naturally
                    // less than `rhs`'s. Otherwise (including ties) become a copy of `rhs`.
                    GallicType::GallicMin => {
                        if !natural_less(self.value2(), rhs.value2())? {
                            self.set_value(rhs.value().clone());
                        }
                    }
                }
                Ok(())
            }

            fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
                self.0.times_assign(&rhs.borrow().0)
            }

            fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
                self.0.approx_equal(&rhs.borrow().0, delta)
            }

            fn value(&self) -> &Self::Type {
                &self.0
            }

            fn take_value(self) -> Self::Type {
                self.0
            }

            fn set_value(&mut self, value: Self::Type) {
                self.0 = value;
            }

            fn reverse(&self) -> Result<Self::ReverseWeight> {
                Ok(Self::ReverseWeight::new(self.0.reverse()?))
            }

            fn properties() -> SemiringProperties {
                ProductWeight::<$string_weight, W>::properties()
            }
        }

        impl<W> $semiring
        where
            W: Semiring,
        {
            /// Returns the string (first) component of the pair.
            pub fn value1(&self) -> &$string_weight {
                self.0.value1()
            }

            /// Returns the `W` (second) component of the pair.
            pub fn value2(&self) -> &W {
                self.0.value2()
            }

            /// Replaces the string (first) component of the pair.
            pub fn set_value1(&mut self, new_weight: $string_weight) {
                self.0.set_value1(new_weight);
            }

            /// Replaces the `W` (second) component of the pair.
            pub fn set_value2(&mut self, new_weight: W) {
                self.0.set_value2(new_weight)
            }
        }

        impl<W> From<($string_weight, W)> for $semiring
        where
            W: Semiring,
        {
            fn from(w: ($string_weight, W)) -> Self {
                Self::new(w.into())
            }
        }

        impl<W> From<(Vec<Label>, W)> for $semiring
        where
            W: Semiring,
        {
            fn from(w: (Vec<Label>, W)) -> Self {
                let (w1, w2) = w;
                Self::new((w1.into(), w2).into())
            }
        }

        impl<W> From<(Label, W)> for $semiring
        where
            W: Semiring,
        {
            fn from(w: (Label, W)) -> Self {
                let (w1, w2) = w;
                Self::new((w1.into(), w2).into())
            }
        }

        impl<W> WeaklyDivisibleSemiring for $semiring
        where
            W: WeaklyDivisibleSemiring,
        {
            /// Divides each component of the pair by the matching component of `rhs`,
            /// using the same `divide_type` for both.
            fn divide_assign(&mut self, rhs: &Self, divide_type: DivideType) -> Result<()> {
                self.0
                    .weight
                    .0
                    .divide_assign(&rhs.0.weight.0, divide_type)?;
                self.0
                    .weight
                    .1
                    .divide_assign(&rhs.0.weight.1, divide_type)?;
                Ok(())
            }
        }

        impl<W> WeightQuantize for $semiring
        where
            W: WeightQuantize,
        {
            fn quantize_assign(&mut self, delta: f32) -> Result<()> {
                self.0.quantize_assign(delta)
            }
        }

        impl<W: SerializableSemiring> SerializableSemiring for $semiring {
            fn weight_type() -> String {
                match $gallic_type {
                    GallicType::GallicLeft => "left_gallic".to_string(),
                    GallicType::GallicRight => "right_gallic".to_string(),
                    GallicType::GallicRestrict => "restricted_gallic".to_string(),
                    GallicType::GallicMin => "min_gallic".to_string(),
                }
            }

            fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
                let (i, w) = ProductWeight::<$string_weight, W>::parse_binary(i)?;
                Ok((i, Self(w)))
            }

            fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
                self.0.write_binary(file)
            }

            fn parse_text(i: &str) -> IResult<&str, Self> {
                let (i, w) = ProductWeight::<$string_weight, W>::parse_text(i)?;
                Ok((i, Self(w)))
            }
        }
    };
}
247
248gallic_weight!(
249    GallicWeightLeft<W>,
250    StringWeightLeft,
251    GallicType::GallicLeft,
252    GallicWeightRight<W::ReverseWeight>
253);
254
255gallic_weight!(
256    GallicWeightRight<W>,
257    StringWeightRight,
258    GallicType::GallicRight,
259    GallicWeightLeft<W::ReverseWeight>
260);
261
262gallic_weight!(
263    GallicWeightRestrict<W>,
264    StringWeightRestrict,
265    GallicType::GallicRestrict,
266    GallicWeightRestrict<W::ReverseWeight>
267);
268
269gallic_weight!(
270    GallicWeightMin<W>,
271    StringWeightRestrict,
272    GallicType::GallicMin,
273    GallicWeightMin<W::ReverseWeight>
274);
/// Zero-sized marker type supplying the ordering and merge rules for the union members
/// of [`GallicWeight`].
#[derive(Debug, Hash, Clone, PartialEq, PartialOrd, Eq)]
pub struct GallicUnionWeightOption<W> {
    ghost: PhantomData<W>,
}
279
280impl<W: Semiring> UnionWeightOption<GallicWeightRestrict<W>>
281    for GallicUnionWeightOption<GallicWeightRestrict<W>>
282{
283    type ReverseOptions = GallicUnionWeightOption<GallicWeightRestrict<W::ReverseWeight>>;
284
285    fn compare(w1: &GallicWeightRestrict<W>, w2: &GallicWeightRestrict<W>) -> bool {
286        let s1 = w1.0.value1();
287        let s2 = w2.0.value1();
288        let n1 = s1.len_labels();
289        let n2 = s2.len_labels();
290
291        match n1.cmp(&n2) {
292            Ordering::Less => true,
293            Ordering::Greater => false,
294            Ordering::Equal => {
295                if n1 == 0 {
296                    return false;
297                }
298                let v1 = s1.value.unwrap_labels();
299                let v2 = s2.value.unwrap_labels();
300                for i in 0..n1 {
301                    let l1 = v1[i];
302                    let l2 = v2[i];
303                    if l1 < l2 {
304                        return true;
305                    }
306                    if l1 > l2 {
307                        return false;
308                    }
309                }
310                false
311            }
312        }
313    }
314
315    fn merge(
316        w1: &GallicWeightRestrict<W>,
317        w2: &GallicWeightRestrict<W>,
318    ) -> Result<GallicWeightRestrict<W>> {
319        let p = ProductWeight::new((w1.0.value1().clone(), w1.0.value2().plus(w2.0.value2())?));
320        Ok(GallicWeightRestrict(p))
321    }
322}
323
324/// UnionWeight of GallicWeightRestrict.
325#[derive(Debug, PartialOrd, PartialEq, Clone, Hash, Eq)]
326pub struct GallicWeight<W>(
327    pub UnionWeight<GallicWeightRestrict<W>, GallicUnionWeightOption<GallicWeightRestrict<W>>>,
328)
329where
330    W: Semiring;
331
332impl<W> Display for GallicWeight<W>
333where
334    W: SerializableSemiring,
335{
336    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
337        self.0.fmt(f)
338    }
339}
340
341impl<W> AsRef<GallicWeight<W>> for GallicWeight<W>
342where
343    W: Semiring,
344{
345    fn as_ref(&self) -> &Self {
346        self
347    }
348}
349
350impl<W: Semiring> Semiring for GallicWeight<W> {
351    type Type = Vec<GallicWeightRestrict<W>>;
352    type ReverseWeight = GallicWeight<W::ReverseWeight>;
353
354    fn zero() -> Self {
355        Self(UnionWeight::zero())
356    }
357
358    fn one() -> Self {
359        Self(UnionWeight::one())
360    }
361
362    fn new(value: Self::Type) -> Self {
363        Self(UnionWeight::new(value))
364    }
365
366    fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
367        self.0.plus_assign(&rhs.borrow().0)
368    }
369
370    fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
371        self.0.times_assign(&rhs.borrow().0)
372    }
373
374    fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
375        self.0.approx_equal(&rhs.borrow().0, delta)
376    }
377
378    fn value(&self) -> &Self::Type {
379        self.0.value()
380    }
381
382    fn take_value(self) -> Self::Type {
383        self.0.take_value()
384    }
385
386    fn set_value(&mut self, value: Self::Type) {
387        self.0.set_value(value)
388    }
389
390    fn reverse(&self) -> Result<Self::ReverseWeight> {
391        Ok(GallicWeight(self.0.reverse()?))
392    }
393
394    fn properties() -> SemiringProperties {
395        UnionWeight::<GallicWeightRestrict<W>, GallicUnionWeightOption<GallicWeightRestrict<W>>>::properties()
396    }
397}
398
399impl<W: Semiring> ReverseBack<GallicWeight<W>> for <GallicWeight<W> as Semiring>::ReverseWeight {
400    fn reverse_back(&self) -> Result<GallicWeight<W>> {
401        Ok(GallicWeight(self.0.reverse_back()?))
402    }
403}
404
405impl<W: Semiring> GallicWeight<W> {
406    pub fn len(&self) -> usize {
407        self.0.len()
408    }
409
410    pub fn is_empty(&self) -> bool {
411        self.0.is_empty()
412    }
413
414    pub fn iter(&self) -> impl Iterator<Item = &GallicWeightRestrict<W>> {
415        self.0.iter()
416    }
417}
418
419impl<W> From<(StringWeightRestrict, W)> for GallicWeight<W>
420where
421    W: Semiring,
422{
423    fn from(w: (StringWeightRestrict, W)) -> Self {
424        let (w1, w2) = w;
425        let mut gw = GallicWeightRestrict::one();
426        gw.set_value1(w1);
427        gw.set_value2(w2);
428        Self::new(vec![gw])
429    }
430}
431
432impl<W> From<GallicWeightRestrict<W>> for GallicWeight<W>
433where
434    W: Semiring,
435{
436    fn from(w: GallicWeightRestrict<W>) -> Self {
437        Self::new(vec![w])
438    }
439}
440
441impl<W> From<(Vec<Label>, W)> for GallicWeight<W>
442where
443    W: Semiring,
444{
445    fn from(w: (Vec<Label>, W)) -> Self {
446        let (w1, w2) = w;
447        let a: StringWeightRestrict = w1.into();
448        (a, w2).into()
449    }
450}
451
452impl<W> From<(Label, W)> for GallicWeight<W>
453where
454    W: Semiring,
455{
456    fn from(w: (Label, W)) -> Self {
457        let (w1, w2) = w;
458        (vec![w1], w2).into()
459    }
460}
461
462impl<W> WeaklyDivisibleSemiring for GallicWeight<W>
463where
464    W: WeaklyDivisibleSemiring,
465{
466    fn divide_assign(&mut self, rhs: &Self, divide_type: DivideType) -> Result<()> {
467        self.0.divide_assign(&rhs.0, divide_type)?;
468        Ok(())
469    }
470}
471
472impl<W> WeightQuantize for GallicWeight<W>
473where
474    W: WeightQuantize,
475{
476    fn quantize_assign(&mut self, delta: f32) -> Result<()> {
477        self.0.quantize_assign(delta)
478    }
479}
480
481impl<W: SerializableSemiring> SerializableSemiring for GallicWeight<W> {
482    fn weight_type() -> String {
483        "gallic".to_string()
484    }
485
486    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
487        let (i, w) = UnionWeight::<
488            GallicWeightRestrict<W>,
489            GallicUnionWeightOption<GallicWeightRestrict<W>>,
490        >::parse_binary(i)?;
491        Ok((i, Self(w)))
492    }
493
494    fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
495        self.0.write_binary(file)
496    }
497
498    fn parse_text(i: &str) -> IResult<&str, Self> {
499        let (i, w) = UnionWeight::<
500            GallicWeightRestrict<W>,
501            GallicUnionWeightOption<GallicWeightRestrict<W>>,
502        >::parse_text(i)?;
503        Ok((i, Self(w)))
504    }
505}
506
507test_semiring_serializable!(
508    test_gallic_weight_left_serializable,
509    GallicWeightLeft<TropicalWeight>,
510    GallicWeightLeft::one()
511    GallicWeightLeft::zero()
512    GallicWeightLeft::from((vec![1,2],TropicalWeight::new(0.3)))
513);
514
515test_semiring_serializable!(
516    test_gallic_weight_right_serializable,
517    GallicWeightRight<TropicalWeight>,
518    GallicWeightRight::one()
519    GallicWeightRight::zero()
520    GallicWeightRight::from((vec![1,2],TropicalWeight::new(0.3)))
521);
522
523test_semiring_serializable!(
524    test_gallic_weight_restrict_serializable,
525    GallicWeightRestrict<TropicalWeight>,
526    GallicWeightRestrict::one()
527    GallicWeightRestrict::zero()
528    GallicWeightRestrict::from((vec![1,2],TropicalWeight::new(0.3)))
529);
530
531test_semiring_serializable!(
532    test_gallic_weight_min_serializable,
533    GallicWeightMin<TropicalWeight>,
534    GallicWeightMin::one()
535    GallicWeightMin::zero()
536    GallicWeightMin::from((vec![1,2],TropicalWeight::new(0.3)))
537);
538
539test_semiring_serializable!(
540    test_gallic_weight_serializable,
541    GallicWeight<TropicalWeight>,
542    GallicWeight::one()
543    GallicWeight::zero()
544    GallicWeight::from((vec![1,2],TropicalWeight::new(0.3)))
545);