1use std::fmt;
2use std::iter::FromIterator;
3use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
4
5use tract_num_traits::Zero;
6
7use crate::internal::*;
8
9pub trait Factoid: fmt::Debug + Clone + PartialEq + Default + Hash {
11 type Concrete: fmt::Debug;
12
13 fn concretize(&self) -> Option<Self::Concrete>;
15
16 fn is_concrete(&self) -> bool {
18 self.concretize().is_some()
19 }
20
21 fn unify(&self, other: &Self) -> TractResult<Self>;
23
24 fn unify_with(&mut self, other: &Self) -> TractResult<bool> {
29 let new = self.unify(other)?;
30 let mut changed = false;
31 if &new != self {
32 changed = true;
33 *self = new;
34 }
35 Ok(changed)
36 }
37
38 fn unify_with_mut(&mut self, other: &mut Self) -> TractResult<bool> {
43 let new = self.unify(other)?;
44 let mut changed = false;
45 if &new != self {
46 changed = true;
47 *self = new.clone();
48 }
49 if &new != other {
50 changed = true;
51 *other = new;
52 }
53 Ok(changed)
54 }
55
56 fn unify_all(facts: &mut [&mut Self]) -> TractResult<bool> {
61 let mut overall_changed = false;
62 loop {
63 let mut changed = false;
64 for i in 0..facts.len() - 1 {
65 for j in i + 1..facts.len() {
66 let (left, right) = facts.split_at_mut(j);
67 let c = left[i].unify_with(right[0])?;
68 changed = changed || c;
69 overall_changed = changed || c;
70 }
71 }
72 if !changed {
73 return Ok(overall_changed);
74 }
75 }
76 }
77}
78
/// A factoid that is either fully known (`Only`) or completely unknown (`Any`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum GenericFactoid<T: fmt::Debug + Clone + PartialEq + Hash> {
    /// The value is known.
    Only(T),
    /// Nothing is known about the value.
    Any,
}
85
86#[allow(clippy::derivable_impls)]
88impl<T: fmt::Debug + Clone + PartialEq + Hash> Default for GenericFactoid<T> {
89 fn default() -> Self {
90 GenericFactoid::Any
91 }
92}
93
94impl<T: Copy + Clone + fmt::Debug + PartialEq + Hash> Copy for GenericFactoid<T> {}
95
96impl<T: fmt::Debug + Clone + PartialEq + Hash> Factoid for GenericFactoid<T> {
97 type Concrete = T;
98
99 fn concretize(&self) -> Option<T> {
101 match self {
102 GenericFactoid::Any => None,
103 GenericFactoid::Only(m) => Some(m.clone()),
104 }
105 }
106
107 fn unify(&self, other: &Self) -> TractResult<Self> {
109 let fact = match (self, other) {
110 (_, GenericFactoid::Any) => self.clone(),
111 (GenericFactoid::Any, _) => other.clone(),
112 _ if self == other => self.clone(),
113 _ => bail!("Impossible to unify {:?} with {:?}.", self, other),
114 };
115
116 Ok(fact)
117 }
118}
119
120impl<T: fmt::Debug + Clone + PartialEq + Hash> From<T> for GenericFactoid<T> {
121 fn from(t: T) -> Self {
122 GenericFactoid::Only(t)
123 }
124}
125
126impl<T: fmt::Display + fmt::Debug + Clone + PartialEq + Hash> fmt::Display for GenericFactoid<T> {
127 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
128 match self {
129 GenericFactoid::Any => write!(formatter, "?"),
130 GenericFactoid::Only(u) => write!(formatter, "{u}"),
131 }
132 }
133}
134
135impl<T: fmt::Debug + Clone + PartialEq + Hash> fmt::Debug for GenericFactoid<T> {
136 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
137 match self {
138 GenericFactoid::Any => write!(formatter, "?"),
139 GenericFactoid::Only(u) => write!(formatter, "{u:?}"),
140 }
141 }
142}
143
/// Partial knowledge of a datum type.
pub type TypeFactoid = GenericFactoid<DatumType>;
146
/// Partial knowledge of a tensor shape.
///
/// `dims` holds what is known about the leading dimensions. When `open` is
/// `false` the shape has exactly `dims.len()` dimensions; when `open` is
/// `true` further dimensions may follow the listed ones.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFactoid {
    // Whether dimensions beyond those in `dims` may exist.
    pub(super) open: bool,
    // Factoids for the listed dimensions, in order.
    pub(super) dims: TVec<GenericFactoid<TDim>>,
}
161
162impl ShapeFactoid {
163 pub fn open(dims: TVec<DimFact>) -> ShapeFactoid {
165 ShapeFactoid { open: true, dims }
166 }
167
168 pub fn is_open(&self) -> bool {
169 self.open
170 }
171
172 pub fn closed(dims: TVec<DimFact>) -> ShapeFactoid {
174 ShapeFactoid { open: false, dims }
175 }
176
177 pub fn rank(&self) -> IntFactoid {
178 if self.open {
179 GenericFactoid::Any
180 } else {
181 GenericFactoid::Only(self.dims.len() as i64)
182 }
183 }
184
185 pub fn ensure_rank_at_least(&mut self, n: usize) -> bool {
186 let mut changed = false;
187 while self.dims.len() <= n {
188 self.dims.push(GenericFactoid::Any);
189 changed = true;
190 }
191 changed
192 }
193
194 pub fn dim(&self, i: usize) -> Option<DimFact> {
195 self.dims().nth(i).cloned()
196 }
197
198 pub fn set_dim(&mut self, i: usize, d: TDim) -> bool {
199 let fact = GenericFactoid::Only(d.clone());
200 if self.dim(i).as_ref() == Some(&fact) {
201 return false;
202 }
203 self.dims[i] = GenericFactoid::Only(d);
204 true
205 }
206
207 pub fn dims(&self) -> impl Iterator<Item = &DimFact> {
208 self.dims.iter()
209 }
210
211 pub fn as_concrete_finite(&self) -> TractResult<Option<TVec<usize>>> {
212 if self.open {
213 return Ok(None);
214 }
215 Ok(self.dims.iter().map(|d| d.concretize().and_then(|d| d.to_usize().ok())).collect())
216 }
217
218 pub fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
219 let rank_compatible =
220 if self.is_open() { self.dims.len() <= t.rank() } else { self.dims.len() == t.rank() };
221 if !rank_compatible {
222 return Ok(false);
223 }
224
225 for i in 0..t.rank() {
226 let dim = self.dims.get(i).and_then(|el| el.concretize());
227 if let Some(dim) = dim.and_then(|dim| {
228 dim.eval(symbols.unwrap_or(&SymbolValues::default())).to_usize().ok()
229 }) {
230 if dim != t.shape()[i] {
231 return Ok(false);
232 }
233 }
234 }
235 Ok(true)
236 }
237}
238
239impl Factoid for ShapeFactoid {
240 type Concrete = TVec<TDim>;
241
242 fn concretize(self: &ShapeFactoid) -> Option<TVec<TDim>> {
244 if self.open {
245 return None;
246 }
247
248 let dims: TVec<_> = self.dims().filter_map(|d| d.concretize()).collect();
249
250 if dims.len() < self.dims.len() {
251 None
252 } else {
253 Some(dims)
254 }
255 }
256
257 fn unify(&self, other: &Self) -> TractResult<Self> {
259 let (x, y) = (self, other);
260
261 use tract_itertools::EitherOrBoth::{Both, Left, Right};
262 use tract_itertools::Itertools;
263
264 let xi = x.dims();
265 let yi = y.dims();
266
267 let dimensions: TVec<_> = xi
268 .zip_longest(yi)
269 .map(|r| match r {
270 Both(a, b) => a.unify(b),
271 Left(d) if y.open => Ok(d.clone()),
272 Right(d) if x.open => Ok(d.clone()),
273
274 Left(_) | Right(_) => bail!(
275 "Impossible to unify closed shapes of different rank (found {:?} and {:?}).",
276 x,
277 y
278 ),
279 })
280 .collect::<TractResult<_>>()
281 .with_context(|| format!("Unifying shapes {x:?} and {y:?}"))?;
282
283 if x.open && y.open {
284 Ok(ShapeFactoid::open(dimensions))
285 } else {
286 Ok(ShapeFactoid::closed(dimensions))
287 }
288 }
289}
290
291impl Default for ShapeFactoid {
292 fn default() -> ShapeFactoid {
294 ShapeFactoid::open(tvec![])
295 }
296}
297
298impl FromIterator<TDim> for ShapeFactoid {
299 fn from_iter<I: IntoIterator<Item = TDim>>(iter: I) -> ShapeFactoid {
301 ShapeFactoid::closed(iter.into_iter().map(GenericFactoid::Only).collect())
302 }
303}
304
305impl FromIterator<usize> for ShapeFactoid {
306 fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> ShapeFactoid {
308 ShapeFactoid::closed(iter.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
309 }
310}
311
312impl<D: ToDim, I: IntoIterator<Item = D>> From<I> for ShapeFactoid {
313 fn from(it: I) -> ShapeFactoid {
314 ShapeFactoid::closed(it.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
315 }
316}
317
318impl fmt::Debug for ShapeFactoid {
319 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
320 for (ix, d) in self.dims.iter().enumerate() {
321 if ix != 0 {
322 write!(formatter, ",")?
323 }
324 write!(formatter, "{d}")?;
325 }
326 if self.open {
327 if self.dims.len() == 0 {
328 write!(formatter, "..")?;
329 } else {
330 write!(formatter, ",..")?;
331 }
332 }
333 Ok(())
334 }
335}
336
/// Partial knowledge of a single dimension.
pub type DimFact = GenericFactoid<TDim>;

/// Partial knowledge of a tensor value.
pub type ValueFact = GenericFactoid<Arc<Tensor>>;

/// Partial knowledge of an integer, e.g. the rank returned by `ShapeFactoid::rank`.
pub type IntFactoid = GenericFactoid<i64>;
343
344impl<T> Zero for GenericFactoid<T>
345where
346 T: Add<T, Output = T> + Zero + PartialEq + Clone + ::std::fmt::Debug + Hash,
347{
348 fn zero() -> GenericFactoid<T> {
349 GenericFactoid::Only(T::zero())
350 }
351 fn is_zero(&self) -> bool {
352 match self {
353 GenericFactoid::Only(t) => t.is_zero(),
354 _ => false,
355 }
356 }
357}
358
359impl<T> Neg for GenericFactoid<T>
360where
361 T: Neg<Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
362{
363 type Output = GenericFactoid<T>;
364 fn neg(self) -> GenericFactoid<T> {
365 match self {
366 GenericFactoid::Only(t) => GenericFactoid::Only(t.neg()),
367 any => any,
368 }
369 }
370}
371
372impl<T, I> Add<I> for GenericFactoid<T>
373where
374 T: Add<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
375 I: Into<GenericFactoid<T>>,
376{
377 type Output = GenericFactoid<T>;
378 fn add(self, rhs: I) -> Self::Output {
379 match (self.concretize(), rhs.into().concretize()) {
380 (Some(a), Some(b)) => GenericFactoid::Only(a + b),
381 _ => GenericFactoid::Any,
382 }
383 }
384}
385
386impl<T> Sub<GenericFactoid<T>> for GenericFactoid<T>
387where
388 T: Sub<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
389{
390 type Output = GenericFactoid<T>;
391 fn sub(self, rhs: GenericFactoid<T>) -> Self::Output {
392 match (self.concretize(), rhs.concretize()) {
393 (Some(a), Some(b)) => GenericFactoid::Only(a - b),
394 _ => GenericFactoid::Any,
395 }
396 }
397}
398
399impl<T, R> Mul<R> for GenericFactoid<T>
400where
401 T: Mul<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
402{
403 type Output = GenericFactoid<T>;
404 fn mul(self, rhs: R) -> Self::Output {
405 if let Some(a) = self.concretize() {
406 GenericFactoid::Only(a * rhs)
407 } else {
408 GenericFactoid::Any
409 }
410 }
411}
412
413impl<T, R> Div<R> for GenericFactoid<T>
414where
415 T: Div<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
416{
417 type Output = GenericFactoid<T>;
418 fn div(self, rhs: R) -> Self::Output {
419 if let Some(a) = self.concretize() {
420 GenericFactoid::Only(a / rhs)
421 } else {
422 GenericFactoid::Any
423 }
424 }
425}
426
427impl<T, R> Rem<R> for GenericFactoid<T>
428where
429 T: Rem<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
430{
431 type Output = GenericFactoid<T>;
432 fn rem(self, rhs: R) -> Self::Output {
433 if let Some(a) = self.concretize() {
434 GenericFactoid::Only(a % rhs)
435 } else {
436 GenericFactoid::Any
437 }
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::GenericFactoid::*;
    use super::*;

    /// Asserts that `shape` unifies with itself into itself.
    fn assert_self_unifies(shape: ShapeFactoid) {
        let unified = shape.unify(&shape).unwrap();
        assert_eq!(unified, shape);
    }

    /// Asserts that unifying `a` with `b` succeeds and yields `expected`.
    fn assert_unifies_to(a: ShapeFactoid, b: ShapeFactoid, expected: ShapeFactoid) {
        let unified = a.unify(&b).unwrap();
        assert_eq!(unified, expected);
    }

    /// Asserts that unifying `a` with `b` fails.
    fn assert_unify_fails(a: ShapeFactoid, b: ShapeFactoid) {
        assert!(a.unify(&b).is_err());
    }

    // Datum types.

    #[test]
    fn unify_same_datum_type() {
        let fact = TypeFactoid::Only(DatumType::F32);
        let unified = fact.unify(&fact).unwrap();
        assert_eq!(unified, fact);
    }

    #[test]
    fn unify_different_datum_types_only() {
        let f32_fact = TypeFactoid::Only(DatumType::F32);
        let f64_fact = TypeFactoid::Only(DatumType::F64);
        let outcome = f32_fact.unify(&f64_fact);
        assert!(outcome.is_err());
    }

    #[test]
    fn unify_different_datum_types_any_left() {
        let known = TypeFactoid::Only(DatumType::F32);
        let unified = TypeFactoid::Any.unify(&known).unwrap();
        assert_eq!(unified, known);
    }

    #[test]
    fn unify_different_datum_types_any_right() {
        let known = TypeFactoid::Only(DatumType::F32);
        let unified = known.unify(&TypeFactoid::Any).unwrap();
        assert_eq!(unified, known);
    }

    // Shapes.

    #[test]
    fn unify_same_shape_1() {
        assert_self_unifies(ShapeFactoid::closed(tvec![]));
    }

    #[test]
    fn unify_same_shape_2() {
        assert_self_unifies(ShapeFactoid::closed(tvec![Any]));
    }

    #[test]
    fn unify_same_shape_3() {
        assert_self_unifies(ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]));
    }

    #[test]
    fn unify_different_shapes_1() {
        assert_unify_fails(
            ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]),
            ShapeFactoid::closed(tvec![Only(1.into())]),
        );
    }

    #[test]
    fn unify_different_shapes_2() {
        assert_unify_fails(
            ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into())]),
            ShapeFactoid::closed(tvec![Any]),
        );
    }

    #[test]
    fn unify_different_shapes_3() {
        assert_unify_fails(
            ShapeFactoid::open(tvec![Only(1.into()), Only(2.into())]),
            ShapeFactoid::closed(tvec![Any]),
        );
    }

    #[test]
    fn unify_different_shapes_4() {
        assert_unifies_to(
            ShapeFactoid::closed(tvec![Any]),
            ShapeFactoid::closed(tvec![Any]),
            ShapeFactoid::closed(tvec![Any]),
        );
    }

    #[test]
    fn unify_different_shapes_5() {
        assert_unifies_to(
            ShapeFactoid::closed(tvec![Any]),
            ShapeFactoid::closed(tvec![Only(1.into())]),
            ShapeFactoid::closed(tvec![Only(1.into())]),
        );
    }

    #[test]
    fn unify_different_shapes_6() {
        assert_unifies_to(
            ShapeFactoid::open(tvec![]),
            ShapeFactoid::closed(tvec![Only(1.into())]),
            ShapeFactoid::closed(tvec![Only(1.into())]),
        );
    }

    #[test]
    fn unify_different_shapes_7() {
        assert_unifies_to(
            ShapeFactoid::open(tvec![Any, Only(2.into())]),
            ShapeFactoid::closed(tvec![Only(1.into()), Any, Any]),
            ShapeFactoid::closed(tvec![Only(1.into()), Only(2.into()), Any]),
        );
    }

    // Values.

    #[test]
    fn unify_same_value() {
        let fact = ValueFact::Only(rctensor0(12f32));
        let unified = fact.unify(&fact).unwrap();
        assert_eq!(unified, fact);
    }

    #[test]
    fn unify_different_values_only() {
        let short = ValueFact::Only(rctensor1(&[12f32]));
        let long = ValueFact::Only(rctensor1(&[12f32, 42.0]));
        assert!(short.unify(&long).is_err());
    }

    #[test]
    fn unify_different_values_any_left() {
        let known = ValueFact::Only(rctensor1(&[12f32]));
        let unified = ValueFact::Any.unify(&known).unwrap();
        assert_eq!(unified, known);
    }

    #[test]
    fn unify_different_values_any_right() {
        let known = ValueFact::Only(rctensor1(&[12f32]));
        let unified = known.unify(&ValueFact::Any).unwrap();
        assert_eq!(unified, known);
    }
}