Skip to main content

tract_core/model/
fact.rs

//! Partial and complete representations of tensor types.
2use crate::internal::*;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
7
8#[derive(Clone, PartialEq, Eq, Hash)]
9pub struct ShapeFact {
10    dims: TVec<TDim>,
11    concrete: Option<TVec<usize>>,
12}
13
14impl ShapeFact {
15    #[inline]
16    pub fn rank(&self) -> usize {
17        self.dims.len()
18    }
19
20    fn compute_concrete(&mut self) {
21        assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
22        self.concrete =
23            self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
24    }
25
26    /// Shape of the tensor, unless it has symbolic dimensions.
27    #[inline]
28    pub fn as_concrete(&self) -> Option<&[usize]> {
29        self.concrete.as_deref()
30    }
31
32    /// Do we have a symbol-less value ?
33    #[inline]
34    pub fn is_concrete(&self) -> bool {
35        self.concrete.is_some()
36    }
37
38    /// Convert the shape to an array of extended dimensions.
39    #[inline]
40    pub fn to_tvec(&self) -> TVec<TDim> {
41        self.dims.clone()
42    }
43
44    /// Compute the volume of the tensor.
45    #[inline]
46    pub fn volume(&self) -> TDim {
47        self.dims.iter().product()
48    }
49
50    #[inline]
51    pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<'_, ShapeFact>> {
52        if self.is_concrete() {
53            Ok(Cow::Borrowed(self))
54        } else {
55            Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
56        }
57    }
58
59    /// Substitute symbols by TDim expressions in every dim of the shape.
60    /// Concrete shapes pass through unchanged.
61    #[inline]
62    pub fn substitute(
63        &self,
64        subs: &std::collections::HashMap<Symbol, TDim>,
65    ) -> TractResult<Cow<'_, ShapeFact>> {
66        if self.is_concrete() {
67            Ok(Cow::Borrowed(self))
68        } else {
69            Ok(Cow::Owned(
70                self.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<ShapeFact>>()?,
71            ))
72        }
73    }
74
75    #[inline]
76    pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<usize>>> {
77        if let Some(c) = &self.concrete {
78            Ok(Cow::Borrowed(c))
79        } else {
80            Ok(Cow::Owned(
81                self.iter()
82                    .map(|d| d.eval_to_i64(values).map(|d| d as usize))
83                    .collect::<TractResult<TVec<_>>>()?,
84            ))
85        }
86    }
87
88    #[inline]
89    pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<isize>>> {
90        if let Some(c) = &self.concrete {
91            #[allow(unknown_lints, clippy::missing_transmute_annotations)]
92            // TVec<usize> -> TVec<isize>
93            Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
94        } else {
95            Ok(Cow::Owned(
96                self.iter()
97                    .map(|d| d.eval_to_i64(values).map(|d| d as isize))
98                    .collect::<TractResult<TVec<_>>>()?,
99            ))
100        }
101    }
102
103    pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
104        let mut dims =
105            ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
106        dims.compute_concrete();
107        dims
108    }
109
110    pub fn dims(&self) -> &[TDim] {
111        self.dims.as_slice()
112    }
113
114    pub fn set(&mut self, ix: usize, dim: TDim) {
115        self.dims[ix] = dim;
116        self.compute_concrete();
117    }
118
119    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
120        self.dims.insert(axis, 1.into());
121        if let Some(concrete) = &mut self.concrete {
122            concrete.insert(axis, 1);
123        }
124        Ok(())
125    }
126
127    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
128        self.dims.remove(axis);
129        if let Some(concrete) = &mut self.concrete {
130            concrete.remove(axis);
131        } else {
132            self.compute_concrete();
133        };
134        Ok(())
135    }
136
137    pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
138        if self.rank() == _other.rank() {
139            self.dims
140                .iter()
141                .zip(_other.dims.iter())
142                .all(|(dim, other_dim)| dim.compatible_with(other_dim))
143        } else {
144            false
145        }
146    }
147
148    pub fn scalar() -> ShapeFact {
149        let void: &[usize] = &[];
150        Self::from(void)
151    }
152
153    pub fn consistent(&self) -> TractResult<()> {
154        ensure!(
155            self.concrete
156                == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
157        );
158        Ok(())
159    }
160}
161
162impl std::ops::Deref for ShapeFact {
163    type Target = [TDim];
164    fn deref(&self) -> &[TDim] {
165        &self.dims
166    }
167}
168
169impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
170    fn from(it: T) -> ShapeFact {
171        ShapeFact::from_dims(it)
172    }
173}
174
175/// Type information about a tensor: shape, and element type, in various state
176/// of determination.
177pub trait Fact:
178    std::fmt::Debug + Downcast + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static
179{
180    fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>>;
181
182    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
183        self.to_typed_fact()?.matches(t, symbols)
184    }
185
186    /// Ensure that self is same type as another fact or a subtype
187    fn compatible_with(&self, _other: &dyn Fact) -> bool;
188
189    fn datum_type(&self) -> Option<DatumType>;
190}
191
192impl_downcast!(Fact);
193dyn_clone::clone_trait_object!(Fact);
194dyn_eq::eq_trait_object!(Fact);
195
196impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
197    fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
198        ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
199    }
200}
201
202impl fmt::Debug for ShapeFact {
203    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
204        use tract_itertools::Itertools;
205        write!(fmt, "{}", self.iter().join(","))
206    }
207}
208
209impl AsRef<[TDim]> for ShapeFact {
210    fn as_ref(&self) -> &[TDim] {
211        &self.dims
212    }
213}
214
215/// Fully determined tensor information for TypedModel.
216#[derive(Clone, PartialEq, Eq, Hash)]
217pub struct TypedFact {
218    /// tensor element type
219    pub datum_type: DatumType,
220    /// tensor shape
221    pub shape: ShapeFact,
222    /// optional constant value
223    pub konst: Option<Arc<Tensor>>,
224    /// optional uniform value
225    pub uniform: Option<Arc<Tensor>>,
226    /// optional exotic fact
227    pub exotic_fact: Option<Box<dyn ExoticFact>>,
228    /// Symbolic per-element value as a TDim expression, possibly involving
229    /// coordinate symbols 🎯0,🎯1,… and/or model symbols.
230    /// `None` means "unknown / not tracked".
231    pub uniform_tdim: Option<TDim>,
232    /// Boolean TDim expression in coordinate symbols defining which positions
233    /// in the tensor are relevant to downstream consumers.
234    /// `None` means "all positions matter" (no demand annotation).
235    pub region_of_interest: Option<TDim>,
236}
237
238impl TypedFact {
239    pub fn scalar<T>() -> TypedFact
240    where
241        T: Datum,
242    {
243        Self::dt_scalar(T::datum_type())
244    }
245
246    pub fn shape<T, S>(shape: S) -> TypedFact
247    where
248        T: Datum,
249        S: Into<ShapeFact>,
250    {
251        Self::dt_shape(T::datum_type(), shape)
252    }
253
254    pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
255        debug_assert!(
256            t.is_plain(),
257            "shape_and_dt_of called on exotic tensor, exotic_fact will be lost"
258        );
259        TypedFact {
260            datum_type: t.datum_type(),
261            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
262            uniform: None,
263            konst: None,
264            exotic_fact: None,
265            uniform_tdim: None,
266            region_of_interest: None,
267        }
268    }
269
270    pub fn mem_size(&self) -> TDim {
271        self.shape.volume() * self.datum_type.size_of()
272            + self.exotic_fact().iter().flat_map(|it| it.buffer_sizes()).sum::<TDim>()
273    }
274
275    pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
276        TypedFact {
277            datum_type,
278            shape: ShapeFact::scalar(),
279            konst: None,
280            uniform: None,
281            exotic_fact: None,
282            uniform_tdim: None,
283            region_of_interest: None,
284        }
285    }
286
287    pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
288    where
289        S: Into<ShapeFact>,
290    {
291        TypedFact {
292            datum_type,
293            shape: shape.into(),
294            konst: None,
295            uniform: None,
296            exotic_fact: None,
297            uniform_tdim: None,
298            region_of_interest: None,
299        }
300    }
301
302    pub fn rank(&self) -> usize {
303        if cfg!(debug_assertions) {
304            self.consistent().unwrap();
305        }
306        self.shape.rank()
307    }
308
309    fn format_dt_shape_nocheck(&self) -> String {
310        if self.shape.rank() > 0 {
311            format!("{:?},{:?}", self.shape, self.datum_type)
312        } else {
313            format!("{:?}", self.datum_type)
314        }
315    }
316
317    pub fn format_dt_shape(&self) -> String {
318        if cfg!(debug_assertions) {
319            self.consistent().unwrap()
320        }
321        self.format_dt_shape_nocheck()
322    }
323
324    pub fn consistent(&self) -> TractResult<()> {
325        self.shape.consistent()?;
326        if let Some(k) = &self.konst {
327            if !self.matches(k.as_ref(), None)? {
328                bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
329            }
330            if let Some(bqf) = self.exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
331            {
332                if let Some(bqs) = k.storage_as::<BlockQuantStorage>() {
333                    let inner_bqf =
334                        BlockQuantFact::new(dyn_clone::clone_box(bqs.format()), k.shape().into());
335                    ensure!(&inner_bqf == bqf, "BlockQuantStorage fact mismatch");
336                }
337            }
338        }
339        if let Some(u) = &self.uniform
340            && self.datum_type != u.datum_type()
341        {
342            bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
343        }
344        if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
345            if let Some(k) = k.as_uniform() {
346                if &k != u {
347                    bail!(
348                        "Uniform value and uniform constant mismatch: value:{u:?}, uniform:{k:?}",
349                    );
350                }
351            } else {
352                bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
353            }
354        }
355        Ok(())
356    }
357
358    pub fn without_value(&self) -> Self {
359        let mut new = self.clone();
360        new.konst = None;
361        new.uniform = None;
362        new.uniform_tdim = None;
363        new.region_of_interest = None;
364        new
365    }
366
367    pub fn with_exotic_fact<O: Into<Box<dyn ExoticFact>>>(mut self, exotic_fact: O) -> Self {
368        self.exotic_fact = Some(exotic_fact.into());
369        self
370    }
371
372    pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
373        self.exotic_fact.as_deref()
374    }
375
376    #[inline]
377    pub fn is_exotic(&self) -> bool {
378        self.exotic_fact.is_some()
379    }
380
381    #[inline]
382    pub fn is_plain(&self) -> bool {
383        self.exotic_fact.is_none()
384    }
385}
386
387impl Fact for TypedFact {
388    fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
389        if cfg!(debug_assertions) {
390            self.consistent()?
391        }
392        Ok(Cow::Borrowed(self))
393    }
394
395    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
396        if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
397            return Ok(false);
398        }
399        for i in 0..t.rank() {
400            if let Ok(dim) =
401                self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
402                && dim != t.shape()[i]
403            {
404                return Ok(false);
405            }
406        }
407        Ok(true)
408    }
409
410    fn compatible_with(&self, other: &dyn Fact) -> bool {
411        if cfg!(debug_assertions) {
412            self.consistent().unwrap()
413        }
414        if let Some(other) = other.downcast_ref::<Self>() {
415            if cfg!(debug_assertions) {
416                other.consistent().unwrap()
417            }
418            self.datum_type == other.datum_type
419                && self.shape.compatible_with(&other.shape)
420                && self
421                    .exotic_fact()
422                    .zip(other.exotic_fact())
423                    .map(|(a, b)| a.compatible_with(b))
424                    .unwrap_or(true)
425        } else {
426            false
427        }
428    }
429
430    fn datum_type(&self) -> Option<DatumType> {
431        Some(self.datum_type)
432    }
433}
434
435impl TryFrom<Tensor> for TypedFact {
436    type Error = TractError;
437    fn try_from(t: Tensor) -> TractResult<TypedFact> {
438        TypedFact::try_from(t.into_arc_tensor())
439    }
440}
441
442impl TryFrom<Arc<Tensor>> for TypedFact {
443    type Error = TractError;
444    fn try_from(t: Arc<Tensor>) -> TractResult<TypedFact> {
445        let exotic_fact = t.exotic_fact()?;
446        let uniform_tdim = if t.datum_type() == TDim::datum_type() && t.len() == 1 {
447            t.try_as_plain().ok().and_then(|d| d.as_slice::<TDim>().ok()).map(|s| s[0].clone())
448        } else if t.len() == 1
449            && t.try_as_plain().is_ok()
450            && (t.datum_type().is_integer() || t.datum_type().is::<bool>())
451        {
452            t.cast_to_scalar::<i64>().ok().map(TDim::Val)
453        } else {
454            None
455        };
456        Ok(TypedFact {
457            datum_type: t.datum_type(),
458            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
459            uniform: t.as_uniform().map(Arc::new),
460            exotic_fact,
461            konst: Some(t),
462            uniform_tdim,
463            region_of_interest: None,
464        })
465    }
466}
467
468impl From<&TypedFact> for TypedFact {
469    fn from(fact: &TypedFact) -> TypedFact {
470        fact.clone()
471    }
472}
473
474impl<'a> TryFrom<&'a Arc<Tensor>> for TypedFact {
475    type Error = TractError;
476    fn try_from(t: &'a Arc<Tensor>) -> TractResult<TypedFact> {
477        TypedFact::try_from(Arc::clone(t))
478    }
479}
480
481impl fmt::Debug for TypedFact {
482    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
483        write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
484        if self.is_exotic() {
485            if let Some(of) = &self.exotic_fact {
486                write!(fmt, " 🔍 {of:?} ")?
487            } else {
488                write!(fmt, " 🔍 <no exotic fact> ")?
489            }
490        }
491        if let Some(k) = &self.konst {
492            write!(fmt, "🟰 {k:?}")?
493        }
494        if let Some(u) = &self.uniform {
495            write!(fmt, " ◻️{u:?}")?
496        }
497        if let Some(u) = &self.uniform_tdim {
498            write!(fmt, " 📐{u}")?
499        }
500        if let Some(r) = &self.region_of_interest {
501            write!(fmt, " 🬳 {r}")?
502        }
503        Ok(())
504    }
505}
506
507pub trait DatumExt {
508    fn scalar_fact() -> TypedFact;
509    fn fact<S>(shape: S) -> TypedFact
510    where
511        S: Into<ShapeFact>;
512}
513
514impl<T: Datum> DatumExt for T {
515    #[allow(clippy::needless_borrow)]
516    fn scalar_fact() -> TypedFact {
517        TypedFact::shape::<Self, &[usize]>(&[])
518    }
519
520    fn fact<S>(shape: S) -> TypedFact
521    where
522        S: Into<ShapeFact>,
523    {
524        TypedFact::shape::<Self, _>(shape)
525    }
526}
527
528pub trait DatumTypeExt {
529    fn scalar_fact(&self) -> TypedFact;
530    fn fact<S>(&self, shape: S) -> TypedFact
531    where
532        S: Into<ShapeFact>;
533}
534
535impl DatumTypeExt for DatumType {
536    #[allow(clippy::needless_borrow)]
537    fn scalar_fact(&self) -> TypedFact {
538        TypedFact::dt_shape::<&[usize]>(*self, &[])
539    }
540
541    fn fact<S>(&self, shape: S) -> TypedFact
542    where
543        S: Into<ShapeFact>,
544    {
545        TypedFact::dt_shape(*self, shape)
546    }
547}