tract_hir/infer/
factoid.rs

1use std::fmt;
2use std::iter::FromIterator;
3use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
4
5use tract_num_traits::Zero;
6
7use crate::internal::*;
8
9/// Partial information about any value.
10pub trait Factoid: fmt::Debug + Clone + PartialEq + Default + Hash {
11    type Concrete: fmt::Debug;
12
13    /// Tries to transform the fact into a concrete value.
14    fn concretize(&self) -> Option<Self::Concrete>;
15
16    /// Returns whether the value is fully determined.
17    fn is_concrete(&self) -> bool {
18        self.concretize().is_some()
19    }
20
21    /// Tries to unify the fact with another fact of the same type.
22    fn unify(&self, other: &Self) -> TractResult<Self>;
23
24    /// Tries to unify the fact with another fact of the same type and update
25    /// self.
26    ///
27    /// Returns true if it actually changed something.
28    fn unify_with(&mut self, other: &Self) -> TractResult<bool> {
29        let new = self.unify(other)?;
30        let mut changed = false;
31        if &new != self {
32            changed = true;
33            *self = new;
34        }
35        Ok(changed)
36    }
37
38    /// Tries to unify the fact with another fact of the same type and update
39    /// both of them.
40    ///
41    /// Returns true if it actually changed something.
42    fn unify_with_mut(&mut self, other: &mut Self) -> TractResult<bool> {
43        let new = self.unify(other)?;
44        let mut changed = false;
45        if &new != self {
46            changed = true;
47            *self = new.clone();
48        }
49        if &new != other {
50            changed = true;
51            *other = new;
52        }
53        Ok(changed)
54    }
55
56    /// Tries to unify all facts in the list.
57    ///
58    ///
59    /// Returns true if it actually changed something.
60    fn unify_all(facts: &mut [&mut Self]) -> TractResult<bool> {
61        let mut overall_changed = false;
62        loop {
63            let mut changed = false;
64            for i in 0..facts.len() - 1 {
65                for j in i + 1..facts.len() {
66                    let (left, right) = facts.split_at_mut(j);
67                    let c = left[i].unify_with(right[0])?;
68                    changed = changed || c;
69                    overall_changed = changed || c;
70                }
71            }
72            if !changed {
73                return Ok(overall_changed);
74            }
75        }
76    }
77}
78
/// Partial information about a value of type `T`: either the exact value is
/// known (`Only`), or nothing is known about it yet (`Any`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum GenericFactoid<T: fmt::Debug + Clone + PartialEq + Hash> {
    Only(T),
    Any,
}

// Hand-written on purpose: deriving `Default` would not work when `T` itself
// is not `Default`, even though the default here (`Any`) holds no `T`.
#[allow(clippy::derivable_impls)]
impl<T: fmt::Debug + Clone + PartialEq + Hash> Default for GenericFactoid<T> {
    /// The default fact carries no information at all.
    fn default() -> Self {
        Self::Any
    }
}

impl<T: Copy + Clone + fmt::Debug + PartialEq + Hash> Copy for GenericFactoid<T> {}
95
96impl<T: fmt::Debug + Clone + PartialEq + Hash> Factoid for GenericFactoid<T> {
97    type Concrete = T;
98
99    /// Tries to transform the fact into a concrete value.
100    fn concretize(&self) -> Option<T> {
101        match self {
102            GenericFactoid::Any => None,
103            GenericFactoid::Only(m) => Some(m.clone()),
104        }
105    }
106
107    /// Tries to unify the fact with another fact of the same type.
108    fn unify(&self, other: &Self) -> TractResult<Self> {
109        let fact = match (self, other) {
110            (_, GenericFactoid::Any) => self.clone(),
111            (GenericFactoid::Any, _) => other.clone(),
112            _ if self == other => self.clone(),
113            _ => bail!("Impossible to unify {:?} with {:?}.", self, other),
114        };
115
116        Ok(fact)
117    }
118}
119
120impl<T: fmt::Debug + Clone + PartialEq + Hash> From<T> for GenericFactoid<T> {
121    fn from(t: T) -> Self {
122        GenericFactoid::Only(t)
123    }
124}
125
126impl<T: fmt::Display + fmt::Debug + Clone + PartialEq + Hash> fmt::Display for GenericFactoid<T> {
127    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
128        match self {
129            GenericFactoid::Any => write!(formatter, "?"),
130            GenericFactoid::Only(u) => write!(formatter, "{u}"),
131        }
132    }
133}
134
135impl<T: fmt::Debug + Clone + PartialEq + Hash> fmt::Debug for GenericFactoid<T> {
136    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
137        match self {
138            GenericFactoid::Any => write!(formatter, "?"),
139            GenericFactoid::Only(u) => write!(formatter, "{u:?}"),
140        }
141    }
142}
143
144/// Partial information about a type.
145pub type TypeFactoid = GenericFactoid<DatumType>;
146
147/// Partial information about a shape.
148///
149/// A basic example of a shape fact is `shapefactoid![1, 2]`, which corresponds to
150/// the shape `[1, 2]` in Arc<Tensor>. We can use `_` in facts to denote unknown
151/// dimensions (e.g. `shapefactoid![1, 2, _]` corresponds to any shape `[1, 2, k]`
152/// with `k` a non-negative integer). We can also use `..` at the end of a fact
153/// to only specify its first dimensions, so `shapefactoid![1, 2; ..]` matches any
154/// shape that starts with `[1, 2]` (e.g. `[1, 2, i]` or `[1, 2, i, j]`), while
155/// `shapefactoid![..]` matches any shape.
156#[derive(Clone, PartialEq, Eq, Hash)]
157pub struct ShapeFactoid {
158    pub(super) open: bool,
159    pub(super) dims: TVec<GenericFactoid<TDim>>,
160}
161
162impl ShapeFactoid {
163    /// Constructs an open shape fact.
164    pub fn open(dims: TVec<DimFact>) -> ShapeFactoid {
165        ShapeFactoid { open: true, dims }
166    }
167
168    pub fn is_open(&self) -> bool {
169        self.open
170    }
171
172    /// Constructs a closed shape fact.
173    pub fn closed(dims: TVec<DimFact>) -> ShapeFactoid {
174        ShapeFactoid { open: false, dims }
175    }
176
177    pub fn rank(&self) -> IntFactoid {
178        if self.open {
179            GenericFactoid::Any
180        } else {
181            GenericFactoid::Only(self.dims.len() as i64)
182        }
183    }
184
185    pub fn ensure_rank_at_least(&mut self, n: usize) -> bool {
186        let mut changed = false;
187        while self.dims.len() <= n {
188            self.dims.push(GenericFactoid::Any);
189            changed = true;
190        }
191        changed
192    }
193
194    pub fn dim(&self, i: usize) -> Option<DimFact> {
195        self.dims().nth(i).cloned()
196    }
197
198    pub fn set_dim(&mut self, i: usize, d: TDim) -> bool {
199        let fact = GenericFactoid::Only(d.clone());
200        if self.dim(i).as_ref() == Some(&fact) {
201            return false;
202        }
203        self.dims[i] = GenericFactoid::Only(d);
204        true
205    }
206
207    pub fn dims(&self) -> impl Iterator<Item = &DimFact> {
208        self.dims.iter()
209    }
210
211    pub fn as_concrete_finite(&self) -> TractResult<Option<TVec<usize>>> {
212        if self.open {
213            return Ok(None);
214        }
215        Ok(self.dims.iter().map(|d| d.concretize().and_then(|d| d.to_usize().ok())).collect())
216    }
217
218    pub fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
219        let rank_compatible =
220            if self.is_open() { self.dims.len() <= t.rank() } else { self.dims.len() == t.rank() };
221        if !rank_compatible {
222            return Ok(false);
223        }
224
225        for i in 0..t.rank() {
226            let dim = self.dims.get(i).and_then(|el| el.concretize());
227            if let Some(dim) = dim.and_then(|dim| {
228                dim.eval(symbols.unwrap_or(&SymbolValues::default())).to_usize().ok()
229            }) {
230                if dim != t.shape()[i] {
231                    return Ok(false);
232                }
233            }
234        }
235        Ok(true)
236    }
237}
238
239impl Factoid for ShapeFactoid {
240    type Concrete = TVec<TDim>;
241
242    /// Tries to transform the fact into a `Vec<usize>`, or returns `None`.
243    fn concretize(self: &ShapeFactoid) -> Option<TVec<TDim>> {
244        if self.open {
245            return None;
246        }
247
248        let dims: TVec<_> = self.dims().filter_map(|d| d.concretize()).collect();
249
250        if dims.len() < self.dims.len() {
251            None
252        } else {
253            Some(dims)
254        }
255    }
256
257    /// Tries to unify the fact with another fact of the same type.
258    fn unify(&self, other: &Self) -> TractResult<Self> {
259        let (x, y) = (self, other);
260
261        use tract_itertools::EitherOrBoth::{Both, Left, Right};
262        use tract_itertools::Itertools;
263
264        let xi = x.dims();
265        let yi = y.dims();
266
267        let dimensions: TVec<_> = xi
268            .zip_longest(yi)
269            .map(|r| match r {
270                Both(a, b) => a.unify(b),
271                Left(d) if y.open => Ok(d.clone()),
272                Right(d) if x.open => Ok(d.clone()),
273
274                Left(_) | Right(_) => bail!(
275                    "Impossible to unify closed shapes of different rank (found {:?} and {:?}).",
276                    x,
277                    y
278                ),
279            })
280            .collect::<TractResult<_>>()
281            .with_context(|| format!("Unifying shapes {x:?} and {y:?}"))?;
282
283        if x.open && y.open {
284            Ok(ShapeFactoid::open(dimensions))
285        } else {
286            Ok(ShapeFactoid::closed(dimensions))
287        }
288    }
289}
290
291impl Default for ShapeFactoid {
292    /// Returns the most general shape fact possible.
293    fn default() -> ShapeFactoid {
294        ShapeFactoid::open(tvec![])
295    }
296}
297
298impl FromIterator<TDim> for ShapeFactoid {
299    /// Converts an iterator over usize into a closed shape.
300    fn from_iter<I: IntoIterator<Item = TDim>>(iter: I) -> ShapeFactoid {
301        ShapeFactoid::closed(iter.into_iter().map(GenericFactoid::Only).collect())
302    }
303}
304
305impl FromIterator<usize> for ShapeFactoid {
306    /// Converts an iterator over usize into a closed shape.
307    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> ShapeFactoid {
308        ShapeFactoid::closed(iter.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
309    }
310}
311
312impl<D: ToDim, I: IntoIterator<Item = D>> From<I> for ShapeFactoid {
313    fn from(it: I) -> ShapeFactoid {
314        ShapeFactoid::closed(it.into_iter().map(|d| GenericFactoid::Only(d.to_dim())).collect())
315    }
316}
317
318impl fmt::Debug for ShapeFactoid {
319    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
320        for (ix, d) in self.dims.iter().enumerate() {
321            if ix != 0 {
322                write!(formatter, ",")?
323            }
324            write!(formatter, "{d}")?;
325        }
326        if self.open {
327            if self.dims.len() == 0 {
328                write!(formatter, "..")?;
329            } else {
330                write!(formatter, ",..")?;
331            }
332        }
333        Ok(())
334    }
335}
336
337pub type DimFact = GenericFactoid<TDim>;
338
339/// Partial information about a value.
340pub type ValueFact = GenericFactoid<Arc<Tensor>>;
341
342pub type IntFactoid = GenericFactoid<i64>;
343
344impl<T> Zero for GenericFactoid<T>
345where
346    T: Add<T, Output = T> + Zero + PartialEq + Clone + ::std::fmt::Debug + Hash,
347{
348    fn zero() -> GenericFactoid<T> {
349        GenericFactoid::Only(T::zero())
350    }
351    fn is_zero(&self) -> bool {
352        match self {
353            GenericFactoid::Only(t) => t.is_zero(),
354            _ => false,
355        }
356    }
357}
358
359impl<T> Neg for GenericFactoid<T>
360where
361    T: Neg<Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
362{
363    type Output = GenericFactoid<T>;
364    fn neg(self) -> GenericFactoid<T> {
365        match self {
366            GenericFactoid::Only(t) => GenericFactoid::Only(t.neg()),
367            any => any,
368        }
369    }
370}
371
372impl<T, I> Add<I> for GenericFactoid<T>
373where
374    T: Add<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
375    I: Into<GenericFactoid<T>>,
376{
377    type Output = GenericFactoid<T>;
378    fn add(self, rhs: I) -> Self::Output {
379        match (self.concretize(), rhs.into().concretize()) {
380            (Some(a), Some(b)) => GenericFactoid::Only(a + b),
381            _ => GenericFactoid::Any,
382        }
383    }
384}
385
386impl<T> Sub<GenericFactoid<T>> for GenericFactoid<T>
387where
388    T: Sub<T, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
389{
390    type Output = GenericFactoid<T>;
391    fn sub(self, rhs: GenericFactoid<T>) -> Self::Output {
392        match (self.concretize(), rhs.concretize()) {
393            (Some(a), Some(b)) => GenericFactoid::Only(a - b),
394            _ => GenericFactoid::Any,
395        }
396    }
397}
398
399impl<T, R> Mul<R> for GenericFactoid<T>
400where
401    T: Mul<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
402{
403    type Output = GenericFactoid<T>;
404    fn mul(self, rhs: R) -> Self::Output {
405        if let Some(a) = self.concretize() {
406            GenericFactoid::Only(a * rhs)
407        } else {
408            GenericFactoid::Any
409        }
410    }
411}
412
413impl<T, R> Div<R> for GenericFactoid<T>
414where
415    T: Div<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
416{
417    type Output = GenericFactoid<T>;
418    fn div(self, rhs: R) -> Self::Output {
419        if let Some(a) = self.concretize() {
420            GenericFactoid::Only(a / rhs)
421        } else {
422            GenericFactoid::Any
423        }
424    }
425}
426
427impl<T, R> Rem<R> for GenericFactoid<T>
428where
429    T: Rem<R, Output = T> + PartialEq + Clone + ::std::fmt::Debug + Hash,
430{
431    type Output = GenericFactoid<T>;
432    fn rem(self, rhs: R) -> Self::Output {
433        if let Some(a) = self.concretize() {
434            GenericFactoid::Only(a % rhs)
435        } else {
436            GenericFactoid::Any
437        }
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::GenericFactoid::*;
    use super::*;

    /// Builds a known dimension fact.
    fn known<D: Into<TDim>>(d: D) -> DimFact {
        Only(d.into())
    }

    /// Unifies two facts, panicking if unification fails.
    fn unified<F: Factoid>(a: &F, b: &F) -> F {
        a.unify(b).unwrap()
    }

    #[test]
    fn unify_same_datum_type() {
        let f32_fact = TypeFactoid::Only(DatumType::F32);
        assert_eq!(unified(&f32_fact, &f32_fact), f32_fact);
    }

    #[test]
    fn unify_different_datum_types_only() {
        let left = TypeFactoid::Only(DatumType::F32);
        let right = TypeFactoid::Only(DatumType::F64);
        assert!(left.unify(&right).is_err());
    }

    #[test]
    fn unify_different_datum_types_any_left() {
        let known_type = TypeFactoid::Only(DatumType::F32);
        assert_eq!(unified(&TypeFactoid::Any, &known_type), known_type);
    }

    #[test]
    fn unify_different_datum_types_any_right() {
        let known_type = TypeFactoid::Only(DatumType::F32);
        assert_eq!(unified(&known_type, &TypeFactoid::Any), known_type);
    }

    #[test]
    fn unify_same_shape_1() {
        let scalar = ShapeFactoid::closed(tvec![]);
        assert_eq!(unified(&scalar, &scalar), scalar);
    }

    #[test]
    fn unify_same_shape_2() {
        let rank_one = ShapeFactoid::closed(tvec![Any]);
        assert_eq!(unified(&rank_one, &rank_one), rank_one);
    }

    #[test]
    fn unify_same_shape_3() {
        let shape = ShapeFactoid::closed(tvec![known(1), known(2)]);
        assert_eq!(unified(&shape, &shape), shape);
    }

    #[test]
    fn unify_different_shapes_1() {
        let rank_two = ShapeFactoid::closed(tvec![known(1), known(2)]);
        let rank_one = ShapeFactoid::closed(tvec![known(1)]);
        assert!(rank_two.unify(&rank_one).is_err());
    }

    #[test]
    fn unify_different_shapes_2() {
        let rank_two = ShapeFactoid::closed(tvec![known(1), known(2)]);
        let rank_one = ShapeFactoid::closed(tvec![Any]);
        assert!(rank_two.unify(&rank_one).is_err());
    }

    #[test]
    fn unify_different_shapes_3() {
        let open_rank_two = ShapeFactoid::open(tvec![known(1), known(2)]);
        let closed_rank_one = ShapeFactoid::closed(tvec![Any]);
        assert!(open_rank_two.unify(&closed_rank_one).is_err());
    }

    #[test]
    fn unify_different_shapes_4() {
        let left = ShapeFactoid::closed(tvec![Any]);
        let right = ShapeFactoid::closed(tvec![Any]);
        assert_eq!(unified(&left, &right), ShapeFactoid::closed(tvec![Any]));
    }

    #[test]
    fn unify_different_shapes_5() {
        let unknown = ShapeFactoid::closed(tvec![Any]);
        let one = ShapeFactoid::closed(tvec![known(1)]);
        assert_eq!(unified(&unknown, &one), ShapeFactoid::closed(tvec![known(1)]));
    }

    #[test]
    fn unify_different_shapes_6() {
        let anything = ShapeFactoid::open(tvec![]);
        let one = ShapeFactoid::closed(tvec![known(1)]);
        assert_eq!(unified(&anything, &one), ShapeFactoid::closed(tvec![known(1)]));
    }

    #[test]
    fn unify_different_shapes_7() {
        let open = ShapeFactoid::open(tvec![Any, known(2)]);
        let closed = ShapeFactoid::closed(tvec![known(1), Any, Any]);
        let expected = ShapeFactoid::closed(tvec![known(1), known(2), Any]);
        assert_eq!(unified(&open, &closed), expected);
    }

    #[test]
    fn unify_same_value() {
        let twelve = ValueFact::Only(rctensor0(12f32));
        assert_eq!(unified(&twelve, &twelve), twelve);
    }

    #[test]
    fn unify_different_values_only() {
        let short = ValueFact::Only(rctensor1(&[12f32]));
        let long = ValueFact::Only(rctensor1(&[12f32, 42.0]));
        assert!(short.unify(&long).is_err());
    }

    #[test]
    fn unify_different_values_any_left() {
        let value = ValueFact::Only(rctensor1(&[12f32]));
        assert_eq!(unified(&ValueFact::Any, &value), value);
    }

    #[test]
    fn unify_different_values_any_right() {
        let value = ValueFact::Only(rctensor1(&[12f32]));
        assert_eq!(unified(&value, &ValueFact::Any), value);
    }
}