Skip to main content

tract_core/model/
fact.rs

//! Partial and complete tensor type representations.
2use crate::internal::*;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
7
8#[derive(Clone, PartialEq, Eq, Hash)]
9pub struct ShapeFact {
10    dims: TVec<TDim>,
11    concrete: Option<TVec<usize>>,
12}
13
14impl ShapeFact {
15    #[inline]
16    pub fn rank(&self) -> usize {
17        self.dims.len()
18    }
19
20    fn compute_concrete(&mut self) {
21        assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
22        self.concrete =
23            self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
24    }
25
26    /// Shape of the tensor, unless it has symbolic dimensions.
27    #[inline]
28    pub fn as_concrete(&self) -> Option<&[usize]> {
29        self.concrete.as_deref()
30    }
31
32    /// Do we have a symbol-less value ?
33    #[inline]
34    pub fn is_concrete(&self) -> bool {
35        self.concrete.is_some()
36    }
37
38    /// Convert the shape to an array of extended dimensions.
39    #[inline]
40    pub fn to_tvec(&self) -> TVec<TDim> {
41        self.dims.clone()
42    }
43
44    /// Compute the volume of the tensor.
45    #[inline]
46    pub fn volume(&self) -> TDim {
47        self.dims.iter().product()
48    }
49
50    #[inline]
51    pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<'_, ShapeFact>> {
52        if self.is_concrete() {
53            Ok(Cow::Borrowed(self))
54        } else {
55            Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
56        }
57    }
58
59    #[inline]
60    pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<usize>>> {
61        if let Some(c) = &self.concrete {
62            Ok(Cow::Borrowed(c))
63        } else {
64            Ok(Cow::Owned(
65                self.iter()
66                    .map(|d| d.eval_to_i64(values).map(|d| d as usize))
67                    .collect::<TractResult<TVec<_>>>()?,
68            ))
69        }
70    }
71
72    #[inline]
73    pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<isize>>> {
74        if let Some(c) = &self.concrete {
75            #[allow(unknown_lints, clippy::missing_transmute_annotations)]
76            // TVec<usize> -> TVec<isize>
77            Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
78        } else {
79            Ok(Cow::Owned(
80                self.iter()
81                    .map(|d| d.eval_to_i64(values).map(|d| d as isize))
82                    .collect::<TractResult<TVec<_>>>()?,
83            ))
84        }
85    }
86
87    pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
88        let mut dims =
89            ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
90        dims.compute_concrete();
91        dims
92    }
93
94    pub fn dims(&self) -> &[TDim] {
95        self.dims.as_slice()
96    }
97
98    pub fn set(&mut self, ix: usize, dim: TDim) {
99        self.dims[ix] = dim;
100        self.compute_concrete();
101    }
102
103    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
104        self.dims.insert(axis, 1.into());
105        if let Some(concrete) = &mut self.concrete {
106            concrete.insert(axis, 1);
107        }
108        Ok(())
109    }
110
111    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
112        self.dims.remove(axis);
113        if let Some(concrete) = &mut self.concrete {
114            concrete.remove(axis);
115        } else {
116            self.compute_concrete();
117        };
118        Ok(())
119    }
120
121    pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
122        if self.rank() == _other.rank() {
123            self.dims
124                .iter()
125                .zip(_other.dims.iter())
126                .all(|(dim, other_dim)| dim.compatible_with(other_dim))
127        } else {
128            false
129        }
130    }
131
132    pub fn scalar() -> ShapeFact {
133        let void: &[usize] = &[];
134        Self::from(void)
135    }
136
137    pub fn consistent(&self) -> TractResult<()> {
138        ensure!(
139            self.concrete
140                == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
141        );
142        Ok(())
143    }
144}
145
146impl std::ops::Deref for ShapeFact {
147    type Target = [TDim];
148    fn deref(&self) -> &[TDim] {
149        &self.dims
150    }
151}
152
153impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
154    fn from(it: T) -> ShapeFact {
155        ShapeFact::from_dims(it)
156    }
157}
158
159/// Type information about a tensor: shape, and element type, in various state
160/// of determination.
161pub trait Fact:
162    std::fmt::Debug + Downcast + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static
163{
164    fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>>;
165
166    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
167        self.to_typed_fact()?.matches(t, symbols)
168    }
169
170    /// Ensure that self is same type as another fact or a subtype
171    fn compatible_with(&self, _other: &dyn Fact) -> bool;
172
173    fn datum_type(&self) -> Option<DatumType>;
174}
175
176impl_downcast!(Fact);
177dyn_clone::clone_trait_object!(Fact);
178dyn_eq::eq_trait_object!(Fact);
179
180impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
181    fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
182        ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
183    }
184}
185
186impl fmt::Debug for ShapeFact {
187    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
188        use tract_itertools::Itertools;
189        write!(fmt, "{}", self.iter().join(","))
190    }
191}
192
193impl AsRef<[TDim]> for ShapeFact {
194    fn as_ref(&self) -> &[TDim] {
195        &self.dims
196    }
197}
198
199/// Fully determined tensor information for TypedModel.
200#[derive(Clone, PartialEq, Eq, Hash)]
201pub struct TypedFact {
202    /// tensor element type
203    pub datum_type: DatumType,
204    /// tensor shape
205    pub shape: ShapeFact,
206    /// optional constant value
207    pub konst: Option<Arc<Tensor>>,
208    /// optional uniform value
209    pub uniform: Option<Arc<Tensor>>,
210    /// optional exotic fact
211    pub exotic_fact: Option<Box<dyn ExoticFact>>,
212    /// Symbolic per-element value as a TDim expression, possibly involving
213    /// coordinate symbols 🎯0,🎯1,… and/or model symbols.
214    /// `None` means "unknown / not tracked".
215    pub uniform_tdim: Option<TDim>,
216    /// Boolean TDim expression in coordinate symbols defining which positions
217    /// in the tensor are relevant to downstream consumers.
218    /// `None` means "all positions matter" (no demand annotation).
219    pub region_of_interest: Option<TDim>,
220}
221
222impl TypedFact {
223    pub fn scalar<T>() -> TypedFact
224    where
225        T: Datum,
226    {
227        Self::dt_scalar(T::datum_type())
228    }
229
230    pub fn shape<T, S>(shape: S) -> TypedFact
231    where
232        T: Datum,
233        S: Into<ShapeFact>,
234    {
235        Self::dt_shape(T::datum_type(), shape)
236    }
237
238    pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
239        debug_assert!(
240            t.is_plain(),
241            "shape_and_dt_of called on exotic tensor, exotic_fact will be lost"
242        );
243        TypedFact {
244            datum_type: t.datum_type(),
245            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
246            uniform: None,
247            konst: None,
248            exotic_fact: None,
249            uniform_tdim: None,
250            region_of_interest: None,
251        }
252    }
253
254    pub fn mem_size(&self) -> TDim {
255        self.shape.volume() * self.datum_type.size_of()
256            + self.exotic_fact().iter().flat_map(|it| it.buffer_sizes()).sum::<TDim>()
257    }
258
259    pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
260        TypedFact {
261            datum_type,
262            shape: ShapeFact::scalar(),
263            konst: None,
264            uniform: None,
265            exotic_fact: None,
266            uniform_tdim: None,
267            region_of_interest: None,
268        }
269    }
270
271    pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
272    where
273        S: Into<ShapeFact>,
274    {
275        TypedFact {
276            datum_type,
277            shape: shape.into(),
278            konst: None,
279            uniform: None,
280            exotic_fact: None,
281            uniform_tdim: None,
282            region_of_interest: None,
283        }
284    }
285
286    pub fn rank(&self) -> usize {
287        if cfg!(debug_assertions) {
288            self.consistent().unwrap();
289        }
290        self.shape.rank()
291    }
292
293    fn format_dt_shape_nocheck(&self) -> String {
294        if self.shape.rank() > 0 {
295            format!("{:?},{:?}", self.shape, self.datum_type)
296        } else {
297            format!("{:?}", self.datum_type)
298        }
299    }
300
301    pub fn format_dt_shape(&self) -> String {
302        if cfg!(debug_assertions) {
303            self.consistent().unwrap()
304        }
305        self.format_dt_shape_nocheck()
306    }
307
308    pub fn consistent(&self) -> TractResult<()> {
309        self.shape.consistent()?;
310        if let Some(k) = &self.konst {
311            if !self.matches(k.as_ref(), None)? {
312                bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
313            }
314            if let Some(bqf) = self.exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
315            {
316                if let Some(bqs) = k.storage_as::<BlockQuantStorage>() {
317                    let inner_bqf =
318                        BlockQuantFact::new(dyn_clone::clone_box(bqs.format()), k.shape().into());
319                    ensure!(&inner_bqf == bqf, "BlockQuantStorage fact mismatch");
320                }
321            }
322        }
323        if let Some(u) = &self.uniform
324            && self.datum_type != u.datum_type()
325        {
326            bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
327        }
328        if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
329            if let Some(k) = k.as_uniform() {
330                if &k != u {
331                    bail!(
332                        "Uniform value and uniform constant mismatch: value:{u:?}, uniform:{k:?}",
333                    );
334                }
335            } else {
336                bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
337            }
338        }
339        Ok(())
340    }
341
342    pub fn without_value(&self) -> Self {
343        let mut new = self.clone();
344        new.konst = None;
345        new.uniform = None;
346        new.uniform_tdim = None;
347        new.region_of_interest = None;
348        new
349    }
350
351    pub fn with_exotic_fact<O: Into<Box<dyn ExoticFact>>>(mut self, exotic_fact: O) -> Self {
352        self.exotic_fact = Some(exotic_fact.into());
353        self
354    }
355
356    pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
357        self.exotic_fact.as_deref()
358    }
359
360    #[inline]
361    pub fn is_exotic(&self) -> bool {
362        self.exotic_fact.is_some()
363    }
364
365    #[inline]
366    pub fn is_plain(&self) -> bool {
367        self.exotic_fact.is_none()
368    }
369}
370
371impl Fact for TypedFact {
372    fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
373        if cfg!(debug_assertions) {
374            self.consistent()?
375        }
376        Ok(Cow::Borrowed(self))
377    }
378
379    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
380        if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
381            return Ok(false);
382        }
383        for i in 0..t.rank() {
384            if let Ok(dim) =
385                self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
386                && dim != t.shape()[i]
387            {
388                return Ok(false);
389            }
390        }
391        Ok(true)
392    }
393
394    fn compatible_with(&self, other: &dyn Fact) -> bool {
395        if cfg!(debug_assertions) {
396            self.consistent().unwrap()
397        }
398        if let Some(other) = other.downcast_ref::<Self>() {
399            if cfg!(debug_assertions) {
400                other.consistent().unwrap()
401            }
402            self.datum_type == other.datum_type
403                && self.shape.compatible_with(&other.shape)
404                && self
405                    .exotic_fact()
406                    .zip(other.exotic_fact())
407                    .map(|(a, b)| a.compatible_with(b))
408                    .unwrap_or(true)
409        } else {
410            false
411        }
412    }
413
414    fn datum_type(&self) -> Option<DatumType> {
415        Some(self.datum_type)
416    }
417}
418
419impl TryFrom<Tensor> for TypedFact {
420    type Error = TractError;
421    fn try_from(t: Tensor) -> TractResult<TypedFact> {
422        TypedFact::try_from(t.into_arc_tensor())
423    }
424}
425
426impl TryFrom<Arc<Tensor>> for TypedFact {
427    type Error = TractError;
428    fn try_from(t: Arc<Tensor>) -> TractResult<TypedFact> {
429        let exotic_fact = t.exotic_fact()?;
430        let uniform_tdim = if t.datum_type() == TDim::datum_type() && t.len() == 1 {
431            t.try_as_plain().ok().and_then(|d| d.as_slice::<TDim>().ok()).map(|s| s[0].clone())
432        } else if t.len() == 1
433            && t.try_as_plain().is_ok()
434            && (t.datum_type().is_integer() || t.datum_type().is::<bool>())
435        {
436            t.cast_to_scalar::<i64>().ok().map(TDim::Val)
437        } else {
438            None
439        };
440        Ok(TypedFact {
441            datum_type: t.datum_type(),
442            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
443            uniform: t.as_uniform().map(Arc::new),
444            exotic_fact,
445            konst: Some(t),
446            uniform_tdim,
447            region_of_interest: None,
448        })
449    }
450}
451
452impl From<&TypedFact> for TypedFact {
453    fn from(fact: &TypedFact) -> TypedFact {
454        fact.clone()
455    }
456}
457
458impl<'a> TryFrom<&'a Arc<Tensor>> for TypedFact {
459    type Error = TractError;
460    fn try_from(t: &'a Arc<Tensor>) -> TractResult<TypedFact> {
461        TypedFact::try_from(Arc::clone(t))
462    }
463}
464
465impl fmt::Debug for TypedFact {
466    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
467        write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
468        if self.is_exotic() {
469            if let Some(of) = &self.exotic_fact {
470                write!(fmt, " 🔍 {of:?} ")?
471            } else {
472                write!(fmt, " 🔍 <no exotic fact> ")?
473            }
474        }
475        if let Some(k) = &self.konst {
476            write!(fmt, "🟰 {k:?}")?
477        }
478        if let Some(u) = &self.uniform {
479            write!(fmt, " ◻️{u:?}")?
480        }
481        if let Some(u) = &self.uniform_tdim {
482            write!(fmt, " 📐{u}")?
483        }
484        if let Some(r) = &self.region_of_interest {
485            write!(fmt, " 🬳 {r}")?
486        }
487        Ok(())
488    }
489}
490
491pub trait DatumExt {
492    fn scalar_fact() -> TypedFact;
493    fn fact<S>(shape: S) -> TypedFact
494    where
495        S: Into<ShapeFact>;
496}
497
498impl<T: Datum> DatumExt for T {
499    #[allow(clippy::needless_borrow)]
500    fn scalar_fact() -> TypedFact {
501        TypedFact::shape::<Self, &[usize]>(&[])
502    }
503
504    fn fact<S>(shape: S) -> TypedFact
505    where
506        S: Into<ShapeFact>,
507    {
508        TypedFact::shape::<Self, _>(shape)
509    }
510}
511
512pub trait DatumTypeExt {
513    fn scalar_fact(&self) -> TypedFact;
514    fn fact<S>(&self, shape: S) -> TypedFact
515    where
516        S: Into<ShapeFact>;
517}
518
519impl DatumTypeExt for DatumType {
520    #[allow(clippy::needless_borrow)]
521    fn scalar_fact(&self) -> TypedFact {
522        TypedFact::dt_shape::<&[usize]>(*self, &[])
523    }
524
525    fn fact<S>(&self, shape: S) -> TypedFact
526    where
527        S: Into<ShapeFact>,
528    {
529        TypedFact::dt_shape(*self, shape)
530    }
531}