1use std::borrow::Borrow;
2use std::cmp::Ordering;
3use std::fmt::Debug;
4use std::fmt::{Display, Formatter};
5use std::hash::Hash;
6use std::io::Write;
7use std::marker::PhantomData;
8
9use anyhow::Result;
10use nom::branch::alt;
11use nom::bytes::complete::tag;
12use nom::multi::{count, separated_list0};
13use nom::IResult;
14
15use crate::parsers::nom_utils::NomCustomError;
16use crate::parsers::parse_bin_i32;
17use crate::parsers::write_bin_i32;
18use crate::semirings::{
19 DivideType, ReverseBack, Semiring, SemiringProperties, SerializableSemiring,
20 WeaklyDivisibleSemiring, WeightQuantize,
21};
22
23pub trait UnionWeightOption<W: Semiring>:
24 Debug + Hash + Clone + PartialOrd + Eq + Sync + 'static
25{
26 type ReverseOptions: UnionWeightOption<W::ReverseWeight>;
27 fn compare(w1: &W, w2: &W) -> bool;
28 fn merge(w1: &W, w2: &W) -> Result<W>;
29}
30
31#[derive(PartialOrd, PartialEq, Clone, Eq, Debug, Hash, Default)]
35pub struct UnionWeight<W: Semiring, O: UnionWeightOption<W>> {
36 pub(crate) list: Vec<W>,
37 ghost: PhantomData<O>,
38}
39
40impl<W, O> AsRef<UnionWeight<W, O>> for UnionWeight<W, O>
41where
42 W: Semiring,
43 O: UnionWeightOption<W>,
44{
45 fn as_ref(&self) -> &Self {
46 self
47 }
48}
49
50impl<W: Semiring, O: UnionWeightOption<W>> Semiring for UnionWeight<W, O> {
51 type Type = Vec<W>;
52 type ReverseWeight = UnionWeight<W::ReverseWeight, O::ReverseOptions>;
53
54 fn zero() -> Self {
55 Self {
56 list: vec![],
57 ghost: PhantomData,
58 }
59 }
60
61 fn one() -> Self {
62 Self {
63 list: vec![W::one()],
64 ghost: PhantomData,
65 }
66 }
67
68 fn new(value: Self::Type) -> Self {
69 Self {
70 list: value,
71 ghost: PhantomData,
72 }
73 }
74
75 fn plus_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
76 if self.is_zero() {
77 self.set_value(rhs.borrow().value().clone());
78 } else if rhs.borrow().is_zero() {
79 } else {
81 let mut sum: UnionWeight<W, O> = UnionWeight::zero();
82 let n1 = self.list.len();
83 let n2 = rhs.borrow().list.len();
84 let mut i1 = 0;
85 let mut i2 = 0;
86 while i1 < n1 && i2 < n2 {
87 let v1 = unsafe { self.list.get_unchecked(i1) };
88 let v2 = unsafe { rhs.borrow().list.get_unchecked(i2) };
89 if O::compare(v1, v2) {
90 sum.push_back(v1.clone(), true)?;
91 i1 += 1;
92 } else {
93 sum.push_back(v2.clone(), true)?;
94 i2 += 1;
95 }
96 }
97
98 for i in i1..n1 {
99 let v1 = unsafe { self.list.get_unchecked(i) };
100 sum.push_back(v1.clone(), true)?;
101 }
102
103 for i in i2..n2 {
104 let v2 = unsafe { rhs.borrow().list.get_unchecked(i) };
105 sum.push_back(v2.clone(), true)?;
106 }
107 self.set_value(sum.take_value());
109 }
110 Ok(())
111 }
112
113 fn times_assign<P: Borrow<Self>>(&mut self, rhs: P) -> Result<()> {
114 if self.is_zero() || rhs.borrow().is_zero() {
115 self.set_value(Self::zero().take_value());
116 } else {
117 let mut prod1: UnionWeight<W, O> = UnionWeight::zero();
118 for w1 in self.iter() {
119 let mut prod2: UnionWeight<W, O> = UnionWeight::zero();
120 for w2 in rhs.borrow().iter() {
121 let p = w1.times(w2)?;
122 prod2.push_back(p, true)?;
123 }
124 prod1.plus_assign(prod2)?;
125 }
126 self.set_value(prod1.take_value());
127 }
128 Ok(())
129 }
130
131 fn approx_equal<P: Borrow<Self>>(&self, rhs: P, delta: f32) -> bool {
132 if self.len() != rhs.borrow().len() {
133 return false;
134 }
135 let it1 = self.iter();
136 let it2 = rhs.borrow().iter();
137 for (w1, w2) in it1.zip(it2) {
138 if !w1.approx_equal(w2, delta) {
139 return false;
140 }
141 }
142 true
143 }
144
145 fn value(&self) -> &Self::Type {
146 &self.list
147 }
148
149 fn take_value(self) -> Self::Type {
150 self.list
151 }
152
153 fn set_value(&mut self, value: Self::Type) {
154 self.list = value;
155 }
156
157 fn reverse(&self) -> Result<Self::ReverseWeight> {
158 let mut rw = Self::ReverseWeight::zero();
159 for v in self.iter() {
160 rw.push_back(v.reverse()?, false)?;
161 }
162 rw.list.sort_by(|v1, v2| {
163 if O::ReverseOptions::compare(v1, v2) {
164 Ordering::Less
165 } else {
166 Ordering::Greater
167 }
168 });
169 Ok(rw)
170 }
171
172 fn properties() -> SemiringProperties {
173 W::properties()
174 & (SemiringProperties::LEFT_SEMIRING
175 | SemiringProperties::RIGHT_SEMIRING
176 | SemiringProperties::COMMUTATIVE
177 | SemiringProperties::IDEMPOTENT)
178 }
179}
180
181impl<W: Semiring, O: UnionWeightOption<W>> UnionWeight<W, O> {
182 fn push_back(&mut self, weight: W, sorted: bool) -> Result<()> {
183 if self.list.is_empty() {
184 self.list.push(weight);
185 } else if sorted {
186 let n = self.list.len();
187 let back = &mut self.list[n - 1];
188 if O::compare(back, &weight) {
189 self.list.push(weight);
190 } else {
191 *back = O::merge(back, &weight)?;
192 }
193 } else {
194 let first = &mut self.list[0];
195 if O::compare(first, &weight) {
196 self.list.push(weight);
197 } else {
198 let first_cloned = first.clone();
199 *first = weight;
200 self.list.push(first_cloned);
201 }
202 }
203 Ok(())
204 }
205
206 pub fn len(&self) -> usize {
207 self.list.len()
208 }
209
210 pub fn is_empty(&self) -> bool {
211 self.list.is_empty()
212 }
213
214 pub fn iter(&self) -> impl Iterator<Item = &W> {
215 self.list.iter()
216 }
217}
218
219impl<W, O> WeaklyDivisibleSemiring for UnionWeight<W, O>
220where
221 W: WeaklyDivisibleSemiring,
222 O: UnionWeightOption<W>,
223{
224 fn divide_assign(&mut self, rhs: &Self, divide_type: DivideType) -> Result<()> {
225 if self.is_zero() || rhs.is_zero() {
226 self.list.clear();
227 }
228 let mut quot = Self::zero();
229 if self.len() == 1 {
230 for v in rhs.list.iter().rev() {
231 quot.push_back(self.list[0].divide(v, divide_type)?, true)?;
232 }
233 } else if rhs.len() == 1 {
234 for v in self.list.iter() {
235 quot.push_back(v.divide(&rhs.list[0], divide_type)?, true)?;
236 }
237 } else {
238 bail!("Expected at least of the two parameters to have a single element");
239 }
240 self.set_value(quot.take_value());
241 Ok(())
242 }
243}
244
245impl<W, O> WeightQuantize for UnionWeight<W, O>
246where
247 W: WeightQuantize,
248 O: UnionWeightOption<W>,
249{
250 fn quantize_assign(&mut self, delta: f32) -> Result<()> {
251 let v: Vec<_> = self.list.drain(..).collect();
252 for mut e in v {
253 e.quantize_assign(delta)?;
254 self.push_back(e.quantize(delta)?, true)?;
255 }
256 Ok(())
257 }
258}
259
260impl<W, O> UnionWeight<W, O>
261where
262 W: SerializableSemiring,
263 O: UnionWeightOption<W>,
264{
265 fn parse_text_empty_set(i: &str) -> IResult<&str, Self> {
266 let (i, _) = tag("EmptySet")(i)?;
267 Ok((i, Self::zero()))
268 }
269
270 fn parse_text_non_empty_set(i: &str) -> IResult<&str, Self> {
271 let (i, weights) = separated_list0(tag(","), W::parse_text)(i)?;
272 Ok((i, Self::new(weights)))
273 }
274}
275
276impl<W, O> Display for UnionWeight<W, O>
277where
278 W: SerializableSemiring,
279 O: UnionWeightOption<W>,
280{
281 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
282 if self.is_empty() {
283 write!(f, "EmptySet")?;
284 } else {
285 for (idx, w) in self.list.iter().enumerate() {
286 if idx > 0 {
287 write!(f, ",")?;
288 }
289 write!(f, "{}", w)?;
290 }
291 }
292 Ok(())
293 }
294}
295
296impl<W, O> SerializableSemiring for UnionWeight<W, O>
297where
298 W: SerializableSemiring,
299 O: UnionWeightOption<W>,
300{
301 fn weight_type() -> String {
302 format!("{}_union", W::weight_type())
303 }
304
305 fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
306 let (i, n) = parse_bin_i32(i)?;
307 let (i, labels) = count(W::parse_binary, n as usize)(i)?;
308 Ok((i, Self::new(labels)))
309 }
310
311 fn write_binary<F: Write>(&self, file: &mut F) -> Result<()> {
312 write_bin_i32(file, self.list.len() as i32)?;
313 for w in self.list.iter() {
314 w.write_binary(file)?;
315 }
316 Ok(())
317 }
318
319 fn parse_text(i: &str) -> IResult<&str, Self> {
320 let (i, res) = alt((Self::parse_text_empty_set, Self::parse_text_non_empty_set))(i)?;
321 Ok((i, res))
322 }
323}
324
325impl<W: Semiring, O: UnionWeightOption<W>> ReverseBack<UnionWeight<W, O>>
326 for <UnionWeight<W, O> as Semiring>::ReverseWeight
327{
328 fn reverse_back(&self) -> Result<UnionWeight<W, O>> {
329 let res = self.reverse()?;
330 unsafe { Ok(std::mem::transmute(res)) }
332 }
333}