tract_hir/infer/
fact.rs

1use std::convert::TryFrom;
2use std::fmt;
3use std::sync::Arc;
4
5use super::factoid::*;
6use crate::internal::*;
7
8/// Partial information about a tensor.
9///
10/// The task of the analyser is to tag every edge in the graph with information
11/// about the tensors that flow through it - specifically their datum_type, their
12/// shape and possibly their value. During the analysis, however, we might only
13/// know some of that information (say, for instance, that an edge only carries
14/// tensors of rank 4, but without knowing their precise dimension).
15///
16/// This is where tensor facts come in: they hold partial information about the
17/// datum_type, shape and value of tensors that might flow through an edge of the
18/// graph. The analyser will first tag each edge with a fact, starting with the
19/// most general one and specializing it at each iteration. Eventually, it will
20/// reach a fixed point that - hopefully - holds enough information.
21#[derive(Clone, PartialEq, Eq, Default, Hash)]
22pub struct InferenceFact {
23    pub datum_type: TypeFactoid,
24    pub shape: ShapeFactoid,
25    pub value: ValueFact,
26}
27
28impl InferenceFact {
29    /// Constructs the most general tensor fact possible.
30    pub fn new() -> InferenceFact {
31        InferenceFact::default()
32    }
33
34    pub fn any() -> InferenceFact {
35        InferenceFact::default()
36    }
37
38    pub fn dt(dt: DatumType) -> InferenceFact {
39        InferenceFact::default().with_datum_type(dt)
40    }
41
42    pub fn dt_shape<S: Into<ShapeFactoid>>(dt: DatumType, shape: S) -> InferenceFact {
43        InferenceFact::dt(dt).with_shape(shape)
44    }
45
46    pub fn shape<S: Into<ShapeFactoid>>(shape: S) -> InferenceFact {
47        InferenceFact::default().with_shape(shape)
48    }
49
50    pub fn with_datum_type(self, dt: DatumType) -> InferenceFact {
51        InferenceFact { datum_type: dt.into(), ..self }
52    }
53
54    pub fn without_datum_type(self) -> InferenceFact {
55        InferenceFact { datum_type: TypeFactoid::Any, ..self }
56    }
57
58    pub fn with_shape<S: Into<ShapeFactoid>>(self, shape: S) -> InferenceFact {
59        InferenceFact { shape: shape.into(), ..self }
60    }
61
62    pub fn format_dt_shape(&self) -> String {
63        if !self.shape.open && self.shape.dims.len() == 0 {
64            self.datum_type
65                .concretize()
66                .map(|dt| format!("{dt:?}"))
67                .unwrap_or_else(|| "?".to_string())
68        } else {
69            format!(
70                "{:?},{}",
71                self.shape,
72                self.datum_type
73                    .concretize()
74                    .map(|dt| format!("{dt:?}"))
75                    .unwrap_or_else(|| "?".to_string())
76            )
77        }
78    }
79
80    pub fn dt_shape_from_tensor(t: &Tensor) -> InferenceFact {
81        InferenceFact::dt_shape(t.datum_type(), t.shape())
82    }
83
84    pub fn without_value(self) -> InferenceFact {
85        InferenceFact { value: GenericFactoid::Any, ..self }
86    }
87}
88
89impl Factoid for InferenceFact {
90    type Concrete = Arc<Tensor>;
91
92    /// Tries to transform the fact into a concrete value.
93    fn concretize(&self) -> Option<Self::Concrete> {
94        self.value.concretize()
95    }
96
97    /// Tries to unify the fact with another fact of the same type.
98    fn unify(&self, other: &Self) -> TractResult<Self> {
99        let tensor = InferenceFact {
100            datum_type: self.datum_type.unify(&other.datum_type)?,
101            shape: self.shape.unify(&other.shape)?,
102            value: self.value.unify(&other.value)?,
103        };
104
105        trace!("Unifying {self:?} with {other:?} into {tensor:?}.");
106
107        Ok(tensor)
108    }
109}
110
111impl fmt::Debug for InferenceFact {
112    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
113        if let Some(t) = self.value.concretize() {
114            write!(formatter, "{t:?}")
115        } else {
116            write!(formatter, "{}", self.format_dt_shape())
117        }
118    }
119}
120
121use crate::infer::factoid::Factoid;
122
123impl Fact for InferenceFact {
124    fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
125        Ok(Cow::Owned(TypedFact::try_from(self)?))
126    }
127
128    fn matches(&self, t: &Tensor, _symbols: Option<&SymbolValues>) -> TractResult<bool> {
129        if let Some(dt) = self.datum_type() {
130            if t.datum_type() != dt {
131                return Ok(false);
132            }
133        }
134        if let Some(shape) = self.shape.concretize() {
135            if *ShapeFact::from(t.shape()) != *shape {
136                return Ok(false);
137            }
138        }
139        if let Some(value) = self.value.concretize() {
140            if &*value != t {
141                return Ok(false);
142            }
143        }
144        Ok(true)
145    }
146
147    fn same_as(&self, other: &dyn Fact) -> bool {
148        if let Some(other) = other.downcast_ref::<Self>() {
149            self.unify(other).is_ok()
150        } else {
151            false
152        }
153    }
154
155    fn compatible_with(&self, other: &dyn Fact) -> bool {
156        if let Some(other) = other.downcast_ref::<Self>() {
157            self.unify(other).is_ok()
158        } else {
159            false
160        }
161    }
162
163    fn datum_type(&self) -> Option<DatumType> {
164        self.datum_type.concretize()
165    }
166}
167
168impl TryFrom<&InferenceFact> for TypedFact {
169    type Error = TractError;
170    fn try_from(fact: &InferenceFact) -> TractResult<TypedFact> {
171        if let (Some(datum_type), Some(shape)) =
172            (fact.datum_type.concretize(), fact.shape.concretize())
173        {
174            let shape = ShapeFact::from_dims(shape);
175            let konst = fact.value.concretize();
176            let uniform = konst.as_ref().and_then(|k| k.as_uniform()).map(Arc::new);
177            Ok(TypedFact { datum_type, shape, konst, uniform, opaque_fact: None })
178        } else {
179            bail!("Can not make a TypedFact out of {:?}", fact)
180        }
181    }
182}
183
184impl<'a> From<&'a InferenceFact> for InferenceFact {
185    fn from(t: &'a InferenceFact) -> InferenceFact {
186        t.clone()
187    }
188}
189
190impl<'a> From<&'a TypedFact> for InferenceFact {
191    fn from(t: &'a TypedFact) -> InferenceFact {
192        let mut fact = InferenceFact::dt_shape(t.datum_type, t.shape.iter());
193        if let Some(k) = &t.konst {
194            fact.value = Arc::clone(k).into();
195        }
196        fact
197    }
198}
199
200impl From<TypedFact> for InferenceFact {
201    fn from(t: TypedFact) -> InferenceFact {
202        InferenceFact::from(&t)
203    }
204}
205
206impl<'a> From<&'a Arc<Tensor>> for InferenceFact {
207    fn from(t: &'a Arc<Tensor>) -> InferenceFact {
208        InferenceFact::from(&TypedFact::from(Arc::clone(t)))
209    }
210}
211
212impl From<Arc<Tensor>> for InferenceFact {
213    fn from(t: Arc<Tensor>) -> InferenceFact {
214        InferenceFact::from(&TypedFact::from(t))
215    }
216}
217
218impl From<Tensor> for InferenceFact {
219    fn from(t: Tensor) -> InferenceFact {
220        let mut fact = InferenceFact::dt_shape(t.datum_type(), t.shape());
221        fact.value = t.into_arc_tensor().into();
222        fact
223    }
224}