tract_core/model/
fact.rs

//! Partial and complete tensor type representations.
2use crate::internal::*;
3use downcast_rs::Downcast;
4use std::fmt;
5
/// Shape of a tensor, stored as one `TDim` per axis.
///
/// `concrete` caches the same shape as plain `usize` values. It is `Some`
/// only when every entry of `dims` converts to a `usize`, and `None` otherwise.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFact {
    // Dimensions of the tensor, in axis order.
    dims: TVec<TDim>,
    // Integer view of `dims`; the mutating methods of `ShapeFact` keep it in sync.
    concrete: Option<TVec<usize>>,
}
11
12impl ShapeFact {
13    #[inline]
14    pub fn rank(&self) -> usize {
15        self.dims.len()
16    }
17
18    fn compute_concrete(&mut self) {
19        assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
20        self.concrete =
21            self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
22    }
23
24    /// Shape of the tensor, unless it has symbolic dimensions.
25    #[inline]
26    pub fn as_concrete(&self) -> Option<&[usize]> {
27        self.concrete.as_deref()
28    }
29
30    /// Do we have a symbol-less value ?
31    #[inline]
32    pub fn is_concrete(&self) -> bool {
33        self.concrete.is_some()
34    }
35
36    /// Convert the shape to an array of extended dimensions.
37    #[inline]
38    pub fn to_tvec(&self) -> TVec<TDim> {
39        self.dims.clone()
40    }
41
42    /// Compute the volume of the tensor.
43    #[inline]
44    pub fn volume(&self) -> TDim {
45        self.dims.iter().product()
46    }
47
48    #[inline]
49    pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<ShapeFact>> {
50        if self.is_concrete() {
51            Ok(Cow::Borrowed(self))
52        } else {
53            Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
54        }
55    }
56
57    #[inline]
58    pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<usize>>> {
59        if let Some(c) = &self.concrete {
60            Ok(Cow::Borrowed(c))
61        } else {
62            Ok(Cow::Owned(
63                self.iter()
64                    .map(|d| d.eval_to_i64(values).map(|d| d as usize))
65                    .collect::<TractResult<TVec<_>>>()?,
66            ))
67        }
68    }
69
70    #[inline]
71    pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<isize>>> {
72        if let Some(c) = &self.concrete {
73            #[allow(unknown_lints, clippy::missing_transmute_annotations)]
74            // TVec<usize> -> TVec<isize>
75            Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
76        } else {
77            Ok(Cow::Owned(
78                self.iter()
79                    .map(|d| d.eval_to_i64(values).map(|d| d as isize))
80                    .collect::<TractResult<TVec<_>>>()?,
81            ))
82        }
83    }
84
85    pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
86        let mut dims =
87            ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
88        dims.compute_concrete();
89        dims
90    }
91
92    pub fn dims(&self) -> &[TDim] {
93        self.dims.as_slice()
94    }
95
96    pub fn set(&mut self, ix: usize, dim: TDim) {
97        self.dims[ix] = dim;
98        self.compute_concrete();
99    }
100
101    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
102        self.dims.insert(axis, 1.into());
103        if let Some(concrete) = &mut self.concrete {
104            concrete.insert(axis, 1);
105        }
106        Ok(())
107    }
108
109    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
110        self.dims.remove(axis);
111        if let Some(concrete) = &mut self.concrete {
112            concrete.remove(axis);
113        } else {
114            self.compute_concrete();
115        };
116        Ok(())
117    }
118
119    pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
120        if self.rank() == _other.rank() {
121            self.dims
122                .iter()
123                .zip(_other.dims.iter())
124                .all(|(dim, other_dim)| dim.compatible_with(other_dim))
125        } else {
126            false
127        }
128    }
129
130    pub fn scalar() -> ShapeFact {
131        let void: &[usize] = &[];
132        Self::from(void)
133    }
134
135    pub fn consistent(&self) -> TractResult<()> {
136        ensure!(
137            self.concrete
138                == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
139        );
140        Ok(())
141    }
142}
143
144impl std::ops::Deref for ShapeFact {
145    type Target = [TDim];
146    fn deref(&self) -> &[TDim] {
147        &self.dims
148    }
149}
150
151impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
152    fn from(it: T) -> ShapeFact {
153        ShapeFact::from_dims(it)
154    }
155}
156
157/// Type information about a tensor: shape, and element type, in various state
158/// of determination.
159pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
160    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>>;
161
162    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
163        self.to_typed_fact()?.matches(t, symbols)
164    }
165
166    fn same_as(&self, _other: &dyn Fact) -> bool;
167
168    /// Ensure that self is same type as another fact or a subtype
169    fn compatible_with(&self, _other: &dyn Fact) -> bool;
170
171    fn datum_type(&self) -> Option<DatumType>;
172}
173
174impl_downcast!(Fact);
175dyn_clone::clone_trait_object!(Fact);
176
177impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
178    fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
179        ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
180    }
181}
182
183impl fmt::Debug for ShapeFact {
184    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
185        use tract_itertools::Itertools;
186        write!(fmt, "{}", self.iter().join(","))
187    }
188}
189
190impl AsRef<[TDim]> for ShapeFact {
191    fn as_ref(&self) -> &[TDim] {
192        &self.dims
193    }
194}
195
/// Fully determined tensor information for TypedModel.
///
/// Beyond element type and shape, a fact can carry what is known about the
/// tensor value itself (`konst`, `uniform`) and an extra opaque description.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypedFact {
    /// tensor element type
    pub datum_type: DatumType,
    /// tensor shape
    pub shape: ShapeFact,
    /// optional constant value: the whole tensor, when it is known
    pub konst: Option<Arc<Tensor>>,
    /// optional uniform value; must have the same datum type as the fact
    /// (presumably the single value repeated by every element -- confirm
    /// against `Tensor::as_uniform`)
    pub uniform: Option<Arc<Tensor>>,
    /// optional opaque fact, shown by the `Debug` output when the datum type
    /// is opaque
    pub opaque_fact: Option<Box<dyn OpaqueFact>>,
}
210
211impl TypedFact {
212    pub fn scalar<T>() -> TypedFact
213    where
214        T: Datum,
215    {
216        Self::dt_scalar(T::datum_type())
217    }
218
219    pub fn shape<T, S>(shape: S) -> TypedFact
220    where
221        T: Datum,
222        S: Into<ShapeFact>,
223    {
224        Self::dt_shape(T::datum_type(), shape)
225    }
226
227    pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
228        TypedFact {
229            datum_type: t.datum_type(),
230            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
231            uniform: None,
232            konst: None,
233            opaque_fact: None,
234        }
235    }
236
237    pub fn mem_size(&self) -> TDim {
238        self.shape.volume() * self.datum_type.size_of()
239            + self.opaque_fact.as_ref().map(|it| it.mem_size()).unwrap_or(0.into())
240    }
241
242    pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
243        TypedFact {
244            datum_type,
245            shape: ShapeFact::scalar(),
246            konst: None,
247            uniform: None,
248            opaque_fact: None,
249        }
250    }
251
252    pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
253    where
254        S: Into<ShapeFact>,
255    {
256        TypedFact { datum_type, shape: shape.into(), konst: None, uniform: None, opaque_fact: None }
257    }
258
259    pub fn rank(&self) -> usize {
260        if cfg!(debug_assertions) {
261            self.consistent().unwrap();
262        }
263        self.shape.rank()
264    }
265
266    fn format_dt_shape_nocheck(&self) -> String {
267        if self.shape.rank() > 0 {
268            format!("{:?},{:?}", self.shape, self.datum_type)
269        } else {
270            format!("{:?}", self.datum_type)
271        }
272    }
273
274    pub fn format_dt_shape(&self) -> String {
275        if cfg!(debug_assertions) {
276            self.consistent().unwrap()
277        }
278        self.format_dt_shape_nocheck()
279    }
280
281    pub fn consistent(&self) -> TractResult<()> {
282        self.shape.consistent()?;
283        if let Some(k) = &self.konst {
284            if !self.matches(k.as_ref(), None)? {
285                bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
286            }
287        }
288        if let Some(u) = &self.uniform {
289            if self.datum_type != u.datum_type() {
290                bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
291            }
292        }
293        if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
294            if let Some(k) = k.as_uniform() {
295                if &k != u {
296                    bail!("Uniform value and uniform constant mismatch: {:?}, {:?}", u, k);
297                }
298            } else {
299                bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
300            }
301        }
302        Ok(())
303    }
304
305    pub fn without_value(&self) -> Self {
306        let mut new = self.clone();
307        new.konst = None;
308        new.uniform = None;
309        new
310    }
311
312    pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
313        self.opaque_fact = Some(opaque_fact.into());
314        self
315    }
316}
317
318impl Fact for TypedFact {
319    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
320        if cfg!(debug_assertions) {
321            self.consistent()?
322        }
323        Ok(Cow::Borrowed(self))
324    }
325
326    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
327        if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
328            return Ok(false);
329        }
330        for i in 0..t.rank() {
331            if let Ok(dim) =
332                self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
333            {
334                if dim != t.shape()[i] {
335                    return Ok(false);
336                }
337            }
338        }
339        Ok(true)
340    }
341
342    fn same_as(&self, other: &dyn Fact) -> bool {
343        if cfg!(debug_assertions) {
344            self.consistent().unwrap()
345        }
346        if let Some(other) = other.downcast_ref::<Self>() {
347            if cfg!(debug_assertions) {
348                other.consistent().unwrap()
349            }
350            self == other
351        } else {
352            false
353        }
354    }
355
356    fn compatible_with(&self, other: &dyn Fact) -> bool {
357        if cfg!(debug_assertions) {
358            self.consistent().unwrap()
359        }
360        if let Some(other) = other.downcast_ref::<Self>() {
361            if cfg!(debug_assertions) {
362                other.consistent().unwrap()
363            }
364            self.datum_type == other.datum_type
365                && self.shape.compatible_with(&other.shape)
366                && self
367                    .opaque_fact
368                    .as_ref()
369                    .zip(other.opaque_fact.as_ref())
370                    .map(|(a, b)| a.compatible_with(&**b))
371                    .unwrap_or(true)
372        } else {
373            false
374        }
375    }
376
377    fn datum_type(&self) -> Option<DatumType> {
378        Some(self.datum_type)
379    }
380}
381
382impl From<Tensor> for TypedFact {
383    fn from(t: Tensor) -> TypedFact {
384        TypedFact::from(t.into_arc_tensor())
385    }
386}
387
388impl From<Arc<Tensor>> for TypedFact {
389    fn from(t: Arc<Tensor>) -> TypedFact {
390        TypedFact {
391            datum_type: t.datum_type(),
392            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
393            uniform: t.as_uniform().map(Arc::new),
394            opaque_fact: None,
395            konst: Some(t),
396        }
397    }
398}
399
400impl From<&TypedFact> for TypedFact {
401    fn from(fact: &TypedFact) -> TypedFact {
402        fact.clone()
403    }
404}
405
406impl<'a> From<&'a Arc<Tensor>> for TypedFact {
407    fn from(t: &'a Arc<Tensor>) -> TypedFact {
408        Arc::clone(t).into()
409    }
410}
411
412impl fmt::Debug for TypedFact {
413    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
414        write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
415        if self.datum_type.is_opaque() {
416            if let Some(of) = &self.opaque_fact {
417                write!(fmt, " 🔍 {:?} ", of)?
418            } else {
419                write!(fmt, " 🔍 <no opaque fact> ")?
420            }
421        }
422        if let Some(k) = &self.konst {
423            write!(fmt, "🟰 {:?}", k)?
424        }
425        Ok(())
426    }
427}
428
/// Shortcuts to build a `TypedFact` from a datum type known at compile time.
pub trait DatumExt {
    /// Fact for a tensor of this datum type with an empty (rank 0) shape.
    fn scalar_fact() -> TypedFact;
    /// Fact for a tensor of this datum type with the given shape.
    fn fact<S>(shape: S) -> TypedFact
    where
        S: Into<ShapeFact>;
}
435
436impl<T: Datum> DatumExt for T {
437    #[allow(clippy::needless_borrow)]
438    fn scalar_fact() -> TypedFact {
439        TypedFact::shape::<Self, &[usize]>(&[])
440    }
441
442    fn fact<S>(shape: S) -> TypedFact
443    where
444        S: Into<ShapeFact>,
445    {
446        TypedFact::shape::<Self, _>(shape)
447    }
448}
449
/// Shortcuts to build a `TypedFact` from a datum type known at runtime.
pub trait DatumTypeExt {
    /// Fact for a tensor of this datum type with an empty (rank 0) shape.
    fn scalar_fact(&self) -> TypedFact;
    /// Fact for a tensor of this datum type with the given shape.
    fn fact<S>(&self, shape: S) -> TypedFact
    where
        S: Into<ShapeFact>;
}
456
457impl DatumTypeExt for DatumType {
458    #[allow(clippy::needless_borrow)]
459    fn scalar_fact(&self) -> TypedFact {
460        TypedFact::dt_shape::<&[usize]>(*self, &[])
461    }
462
463    fn fact<S>(&self, shape: S) -> TypedFact
464    where
465        S: Into<ShapeFact>,
466    {
467        TypedFact::dt_shape(*self, shape)
468    }
469}