1use crate::internal::*;
3use downcast_rs::Downcast;
4use std::fmt;
5use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};
6
/// The shape of a tensor: one possibly-symbolic `TDim` per axis.
///
/// `concrete` caches the same dimensions as plain `usize`s whenever every
/// dimension converts to a `usize`, so concrete shapes can be read without
/// re-evaluating `dims`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFact {
    // One entry per axis; entries may be symbolic.
    dims: TVec<TDim>,
    // `Some` iff every entry of `dims` converts to `usize`; the mutating methods
    // of `ShapeFact` keep it in sync with `dims`.
    concrete: Option<TVec<usize>>,
}
12
13impl ShapeFact {
14    #[inline]
15    pub fn rank(&self) -> usize {
16        self.dims.len()
17    }
18
19    fn compute_concrete(&mut self) {
20        assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
21        self.concrete =
22            self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
23    }
24
25    #[inline]
27    pub fn as_concrete(&self) -> Option<&[usize]> {
28        self.concrete.as_deref()
29    }
30
31    #[inline]
33    pub fn is_concrete(&self) -> bool {
34        self.concrete.is_some()
35    }
36
37    #[inline]
39    pub fn to_tvec(&self) -> TVec<TDim> {
40        self.dims.clone()
41    }
42
43    #[inline]
45    pub fn volume(&self) -> TDim {
46        self.dims.iter().product()
47    }
48
49    #[inline]
50    pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<ShapeFact>> {
51        if self.is_concrete() {
52            Ok(Cow::Borrowed(self))
53        } else {
54            Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
55        }
56    }
57
58    #[inline]
59    pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<usize>>> {
60        if let Some(c) = &self.concrete {
61            Ok(Cow::Borrowed(c))
62        } else {
63            Ok(Cow::Owned(
64                self.iter()
65                    .map(|d| d.eval_to_i64(values).map(|d| d as usize))
66                    .collect::<TractResult<TVec<_>>>()?,
67            ))
68        }
69    }
70
71    #[inline]
72    pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<isize>>> {
73        if let Some(c) = &self.concrete {
74            #[allow(unknown_lints, clippy::missing_transmute_annotations)]
75            Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
77        } else {
78            Ok(Cow::Owned(
79                self.iter()
80                    .map(|d| d.eval_to_i64(values).map(|d| d as isize))
81                    .collect::<TractResult<TVec<_>>>()?,
82            ))
83        }
84    }
85
86    pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
87        let mut dims =
88            ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
89        dims.compute_concrete();
90        dims
91    }
92
93    pub fn dims(&self) -> &[TDim] {
94        self.dims.as_slice()
95    }
96
97    pub fn set(&mut self, ix: usize, dim: TDim) {
98        self.dims[ix] = dim;
99        self.compute_concrete();
100    }
101
102    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
103        self.dims.insert(axis, 1.into());
104        if let Some(concrete) = &mut self.concrete {
105            concrete.insert(axis, 1);
106        }
107        Ok(())
108    }
109
110    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
111        self.dims.remove(axis);
112        if let Some(concrete) = &mut self.concrete {
113            concrete.remove(axis);
114        } else {
115            self.compute_concrete();
116        };
117        Ok(())
118    }
119
120    pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
121        if self.rank() == _other.rank() {
122            self.dims
123                .iter()
124                .zip(_other.dims.iter())
125                .all(|(dim, other_dim)| dim.compatible_with(other_dim))
126        } else {
127            false
128        }
129    }
130
131    pub fn scalar() -> ShapeFact {
132        let void: &[usize] = &[];
133        Self::from(void)
134    }
135
136    pub fn consistent(&self) -> TractResult<()> {
137        ensure!(
138            self.concrete
139                == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
140        );
141        Ok(())
142    }
143}
144
145impl std::ops::Deref for ShapeFact {
146    type Target = [TDim];
147    fn deref(&self) -> &[TDim] {
148        &self.dims
149    }
150}
151
152impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
153    fn from(it: T) -> ShapeFact {
154        ShapeFact::from_dims(it)
155    }
156}
157
/// Common interface for facts (known type/shape information) about a value.
///
/// Implementors can be downcast to their concrete type (`Downcast`) and cloned
/// as trait objects (`DynClone`).
pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
    /// Converts this fact to a `TypedFact`, borrowing when possible.
    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>>;

    /// Whether tensor `t` agrees with this fact, optionally resolving symbolic
    /// dimensions with `symbols`.
    ///
    /// The default implementation converts to a `TypedFact` and delegates to it.
    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
        self.to_typed_fact()?.matches(t, symbols)
    }

    /// Whether `_other` is the same fact (for `TypedFact`: same concrete type and equal).
    fn same_as(&self, _other: &dyn Fact) -> bool;

    /// Whether `_other` is compatible with this fact (for `TypedFact`: same datum
    /// type, compatible shapes and compatible opaque facts).
    fn compatible_with(&self, _other: &dyn Fact) -> bool;

    /// The element type, if this fact knows it.
    fn datum_type(&self) -> Option<DatumType>;
}
174
// Enables `downcast_ref::<T>()` and friends on `dyn Fact`.
impl_downcast!(Fact);
// Makes boxed `dyn Fact` trait objects cloneable through `DynClone`.
dyn_clone::clone_trait_object!(Fact);
177
178impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
179    fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
180        ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
181    }
182}
183
184impl fmt::Debug for ShapeFact {
185    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
186        use tract_itertools::Itertools;
187        write!(fmt, "{}", self.iter().join(","))
188    }
189}
190
191impl AsRef<[TDim]> for ShapeFact {
192    fn as_ref(&self) -> &[TDim] {
193        &self.dims
194    }
195}
196
/// A fully-typed fact: element type and shape, plus optionally known values.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypedFact {
    /// Element type.
    pub datum_type: DatumType,
    /// Shape, possibly with symbolic dimensions.
    pub shape: ShapeFact,
    /// The full value, when it is a known constant.
    pub konst: Option<Arc<Tensor>>,
    /// The uniform value, when one is known; `consistent()` checks it against
    /// `konst.as_uniform()` when both are set.
    pub uniform: Option<Arc<Tensor>>,
    /// Extra metadata for opaque datum types; `consistent()` requires it to be
    /// present exactly when `datum_type` is opaque.
    pub opaque_fact: Option<Box<dyn OpaqueFact>>,
}
211
212impl TypedFact {
213    pub fn scalar<T>() -> TypedFact
214    where
215        T: Datum,
216    {
217        Self::dt_scalar(T::datum_type())
218    }
219
220    pub fn shape<T, S>(shape: S) -> TypedFact
221    where
222        T: Datum,
223        S: Into<ShapeFact>,
224    {
225        Self::dt_shape(T::datum_type(), shape)
226    }
227
228    pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
229        TypedFact {
230            datum_type: t.datum_type(),
231            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
232            uniform: None,
233            konst: None,
234            opaque_fact: None,
235        }
236    }
237
238    pub fn mem_size(&self) -> TDim {
239        self.shape.volume() * self.datum_type.size_of()
240            + self.opaque_fact().map(|it| it.mem_size()).unwrap_or(0.into())
241    }
242
243    pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
244        TypedFact {
245            datum_type,
246            shape: ShapeFact::scalar(),
247            konst: None,
248            uniform: None,
249            opaque_fact: None,
250        }
251    }
252
253    pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
254    where
255        S: Into<ShapeFact>,
256    {
257        TypedFact { datum_type, shape: shape.into(), konst: None, uniform: None, opaque_fact: None }
258    }
259
260    pub fn rank(&self) -> usize {
261        if cfg!(debug_assertions) {
262            self.consistent().unwrap();
263        }
264        self.shape.rank()
265    }
266
267    fn format_dt_shape_nocheck(&self) -> String {
268        if self.shape.rank() > 0 {
269            format!("{:?},{:?}", self.shape, self.datum_type)
270        } else {
271            format!("{:?}", self.datum_type)
272        }
273    }
274
275    pub fn format_dt_shape(&self) -> String {
276        if cfg!(debug_assertions) {
277            self.consistent().unwrap()
278        }
279        self.format_dt_shape_nocheck()
280    }
281
282    pub fn consistent(&self) -> TractResult<()> {
283        self.shape.consistent()?;
284        ensure!(self.datum_type.is_opaque() == self.opaque_fact.is_some());
285        if let Some(k) = &self.konst {
286            if !self.matches(k.as_ref(), None)? {
287                bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
288            }
289            if let Some(bqf) = self.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
290            {
291                for o in k.as_slice::<Opaque>().unwrap() {
292                    ensure!(o.is::<BlockQuantValue>());
293                    ensure!(o.downcast_ref::<BlockQuantValue>().unwrap().fact == *bqf);
294                }
295            }
296        }
297        if let Some(u) = &self.uniform {
298            if self.datum_type != u.datum_type() {
299                bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
300            }
301        }
302        if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
303            if let Some(k) = k.as_uniform() {
304                if &k != u {
305                    bail!("Uniform value and uniform constant mismatch: {:?}, {:?}", u, k);
306                }
307            } else {
308                bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
309            }
310        }
311        Ok(())
312    }
313
314    pub fn without_value(&self) -> Self {
315        let mut new = self.clone();
316        new.konst = None;
317        new.uniform = None;
318        new
319    }
320
321    pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
322        self.opaque_fact = Some(opaque_fact.into());
323        self
324    }
325
326    pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
327        self.opaque_fact.as_deref()
328    }
329}
330
331impl Fact for TypedFact {
332    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
333        if cfg!(debug_assertions) {
334            self.consistent()?
335        }
336        Ok(Cow::Borrowed(self))
337    }
338
339    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
340        if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
341            return Ok(false);
342        }
343        for i in 0..t.rank() {
344            if let Ok(dim) =
345                self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
346            {
347                if dim != t.shape()[i] {
348                    return Ok(false);
349                }
350            }
351        }
352        Ok(true)
353    }
354
355    fn same_as(&self, other: &dyn Fact) -> bool {
356        if cfg!(debug_assertions) {
357            self.consistent().unwrap()
358        }
359        if let Some(other) = other.downcast_ref::<Self>() {
360            if cfg!(debug_assertions) {
361                other.consistent().unwrap()
362            }
363            self == other
364        } else {
365            false
366        }
367    }
368
369    fn compatible_with(&self, other: &dyn Fact) -> bool {
370        if cfg!(debug_assertions) {
371            self.consistent().unwrap()
372        }
373        if let Some(other) = other.downcast_ref::<Self>() {
374            if cfg!(debug_assertions) {
375                other.consistent().unwrap()
376            }
377            self.datum_type == other.datum_type
378                && self.shape.compatible_with(&other.shape)
379                && self
380                    .opaque_fact()
381                    .zip(other.opaque_fact())
382                    .map(|(a, b)| a.compatible_with(b))
383                    .unwrap_or(true)
384        } else {
385            false
386        }
387    }
388
389    fn datum_type(&self) -> Option<DatumType> {
390        Some(self.datum_type)
391    }
392}
393
394impl From<Tensor> for TypedFact {
395    fn from(t: Tensor) -> TypedFact {
396        TypedFact::from(t.into_arc_tensor())
397    }
398}
399
400impl From<Arc<Tensor>> for TypedFact {
401    fn from(t: Arc<Tensor>) -> TypedFact {
402        TypedFact {
403            datum_type: t.datum_type(),
404            shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
405            uniform: t.as_uniform().map(Arc::new),
406            opaque_fact: None,
407            konst: Some(t),
408        }
409    }
410}
411
412impl From<&TypedFact> for TypedFact {
413    fn from(fact: &TypedFact) -> TypedFact {
414        fact.clone()
415    }
416}
417
418impl<'a> From<&'a Arc<Tensor>> for TypedFact {
419    fn from(t: &'a Arc<Tensor>) -> TypedFact {
420        Arc::clone(t).into()
421    }
422}
423
424impl fmt::Debug for TypedFact {
425    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
426        write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
427        if self.datum_type.is_opaque() {
428            if let Some(of) = &self.opaque_fact {
429                write!(fmt, " 🔍 {:?} ", of)?
430            } else {
431                write!(fmt, " 🔍 <no opaque fact> ")?
432            }
433        }
434        if let Some(k) = &self.konst {
435            write!(fmt, "🟰 {:?}", k)?
436        }
437        Ok(())
438    }
439}
440
/// Fact constructors callable directly on a `Datum` type (implemented below for
/// every `T: Datum`).
pub trait DatumExt {
    /// A rank-0 fact of this type.
    fn scalar_fact() -> TypedFact;
    /// A fact of this type with the given shape.
    fn fact<S>(shape: S) -> TypedFact
    where
        S: Into<ShapeFact>;
}
447
448impl<T: Datum> DatumExt for T {
449    #[allow(clippy::needless_borrow)]
450    fn scalar_fact() -> TypedFact {
451        TypedFact::shape::<Self, &[usize]>(&[])
452    }
453
454    fn fact<S>(shape: S) -> TypedFact
455    where
456        S: Into<ShapeFact>,
457    {
458        TypedFact::shape::<Self, _>(shape)
459    }
460}
461
/// Fact constructors callable on a runtime `DatumType` value.
pub trait DatumTypeExt {
    /// A rank-0 fact of this datum type.
    fn scalar_fact(&self) -> TypedFact;
    /// A fact of this datum type with the given shape.
    fn fact<S>(&self, shape: S) -> TypedFact
    where
        S: Into<ShapeFact>;
}
468
469impl DatumTypeExt for DatumType {
470    #[allow(clippy::needless_borrow)]
471    fn scalar_fact(&self) -> TypedFact {
472        TypedFact::dt_shape::<&[usize]>(*self, &[])
473    }
474
475    fn fact<S>(&self, shape: S) -> TypedFact
476    where
477        S: Into<ShapeFact>,
478    {
479        TypedFact::dt_shape(*self, shape)
480    }
481}