1use crate::internal::*;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
7
/// Shape of a tensor: one (possibly symbolic) `TDim` per axis.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFact {
    /// Dimensions, one per axis.
    dims: TVec<TDim>,
    /// Cache of `dims` as plain `usize`s: `Some` only when every dimension converts to
    /// `usize`. Must be kept in sync with `dims` (see `compute_concrete` / `consistent`).
    concrete: Option<TVec<usize>>,
}
13
14impl ShapeFact {
15 #[inline]
16 pub fn rank(&self) -> usize {
17 self.dims.len()
18 }
19
20 fn compute_concrete(&mut self) {
21 assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
22 self.concrete =
23 self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
24 }
25
26 #[inline]
28 pub fn as_concrete(&self) -> Option<&[usize]> {
29 self.concrete.as_deref()
30 }
31
32 #[inline]
34 pub fn is_concrete(&self) -> bool {
35 self.concrete.is_some()
36 }
37
38 #[inline]
40 pub fn to_tvec(&self) -> TVec<TDim> {
41 self.dims.clone()
42 }
43
44 #[inline]
46 pub fn volume(&self) -> TDim {
47 self.dims.iter().product()
48 }
49
50 #[inline]
51 pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<'_, ShapeFact>> {
52 if self.is_concrete() {
53 Ok(Cow::Borrowed(self))
54 } else {
55 Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
56 }
57 }
58
59 #[inline]
62 pub fn substitute(
63 &self,
64 subs: &std::collections::HashMap<Symbol, TDim>,
65 ) -> TractResult<Cow<'_, ShapeFact>> {
66 if self.is_concrete() {
67 Ok(Cow::Borrowed(self))
68 } else {
69 Ok(Cow::Owned(
70 self.iter().map(|d| d.substitute_all(subs)).collect::<TractResult<ShapeFact>>()?,
71 ))
72 }
73 }
74
75 #[inline]
76 pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<usize>>> {
77 if let Some(c) = &self.concrete {
78 Ok(Cow::Borrowed(c))
79 } else {
80 Ok(Cow::Owned(
81 self.iter()
82 .map(|d| d.eval_to_i64(values).map(|d| d as usize))
83 .collect::<TractResult<TVec<_>>>()?,
84 ))
85 }
86 }
87
88 #[inline]
89 pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<isize>>> {
90 if let Some(c) = &self.concrete {
91 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
92 Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
94 } else {
95 Ok(Cow::Owned(
96 self.iter()
97 .map(|d| d.eval_to_i64(values).map(|d| d as isize))
98 .collect::<TractResult<TVec<_>>>()?,
99 ))
100 }
101 }
102
103 pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
104 let mut dims =
105 ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
106 dims.compute_concrete();
107 dims
108 }
109
110 pub fn dims(&self) -> &[TDim] {
111 self.dims.as_slice()
112 }
113
114 pub fn set(&mut self, ix: usize, dim: TDim) {
115 self.dims[ix] = dim;
116 self.compute_concrete();
117 }
118
119 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
120 self.dims.insert(axis, 1.into());
121 if let Some(concrete) = &mut self.concrete {
122 concrete.insert(axis, 1);
123 }
124 Ok(())
125 }
126
127 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
128 self.dims.remove(axis);
129 if let Some(concrete) = &mut self.concrete {
130 concrete.remove(axis);
131 } else {
132 self.compute_concrete();
133 };
134 Ok(())
135 }
136
137 pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
138 if self.rank() == _other.rank() {
139 self.dims
140 .iter()
141 .zip(_other.dims.iter())
142 .all(|(dim, other_dim)| dim.compatible_with(other_dim))
143 } else {
144 false
145 }
146 }
147
148 pub fn scalar() -> ShapeFact {
149 let void: &[usize] = &[];
150 Self::from(void)
151 }
152
153 pub fn consistent(&self) -> TractResult<()> {
154 ensure!(
155 self.concrete
156 == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
157 );
158 Ok(())
159 }
160}
161
162impl std::ops::Deref for ShapeFact {
163 type Target = [TDim];
164 fn deref(&self) -> &[TDim] {
165 &self.dims
166 }
167}
168
169impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
170 fn from(it: T) -> ShapeFact {
171 ShapeFact::from_dims(it)
172 }
173}
174
175pub trait Fact:
178 std::fmt::Debug + Downcast + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static
179{
180 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>>;
181
182 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
183 self.to_typed_fact()?.matches(t, symbols)
184 }
185
186 fn compatible_with(&self, _other: &dyn Fact) -> bool;
188
189 fn datum_type(&self) -> Option<DatumType>;
190}
191
// Lets `&dyn Fact` be downcast to a concrete fact type (e.g. `downcast_ref::<TypedFact>()`).
impl_downcast!(Fact);
// Makes boxed `dyn Fact` trait objects cloneable.
dyn_clone::clone_trait_object!(Fact);
// Makes `dyn Fact` trait objects comparable for equality.
dyn_eq::eq_trait_object!(Fact);
195
196impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
197 fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
198 ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
199 }
200}
201
202impl fmt::Debug for ShapeFact {
203 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
204 use tract_itertools::Itertools;
205 write!(fmt, "{}", self.iter().join(","))
206 }
207}
208
209impl AsRef<[TDim]> for ShapeFact {
210 fn as_ref(&self) -> &[TDim] {
211 &self.dims
212 }
213}
214
/// Typed description of a tensor: datum type and shape, plus optional knowledge about
/// its content (constant value, uniform value, exotic storage, ...).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypedFact {
    /// Element type.
    pub datum_type: DatumType,
    /// Shape, possibly with symbolic dimensions.
    pub shape: ShapeFact,
    /// The tensor itself, when its value is known.
    pub konst: Option<Arc<Tensor>>,
    /// The single value the whole tensor is filled with, when known to be uniform.
    pub uniform: Option<Arc<Tensor>>,
    /// Extra fact describing a non-plain ("exotic") tensor, such as a `BlockQuantFact`.
    pub exotic_fact: Option<Box<dyn ExoticFact>>,
    /// Value expressed as a `TDim`; `TryFrom<Arc<Tensor>>` sets it for single-element
    /// `TDim`, integer or boolean tensors. NOTE(review): other producers may set it under
    /// different conditions — confirm.
    pub uniform_tdim: Option<TDim>,
    /// NOTE(review): semantics not visible in this file — presumably an expression
    /// restricting which part of the tensor matters downstream; confirm with its producers.
    pub region_of_interest: Option<TDim>,
}
237
238impl TypedFact {
239 pub fn scalar<T>() -> TypedFact
240 where
241 T: Datum,
242 {
243 Self::dt_scalar(T::datum_type())
244 }
245
246 pub fn shape<T, S>(shape: S) -> TypedFact
247 where
248 T: Datum,
249 S: Into<ShapeFact>,
250 {
251 Self::dt_shape(T::datum_type(), shape)
252 }
253
254 pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
255 debug_assert!(
256 t.is_plain(),
257 "shape_and_dt_of called on exotic tensor, exotic_fact will be lost"
258 );
259 TypedFact {
260 datum_type: t.datum_type(),
261 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
262 uniform: None,
263 konst: None,
264 exotic_fact: None,
265 uniform_tdim: None,
266 region_of_interest: None,
267 }
268 }
269
270 pub fn mem_size(&self) -> TDim {
271 self.shape.volume() * self.datum_type.size_of()
272 + self.exotic_fact().iter().flat_map(|it| it.buffer_sizes()).sum::<TDim>()
273 }
274
275 pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
276 TypedFact {
277 datum_type,
278 shape: ShapeFact::scalar(),
279 konst: None,
280 uniform: None,
281 exotic_fact: None,
282 uniform_tdim: None,
283 region_of_interest: None,
284 }
285 }
286
287 pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
288 where
289 S: Into<ShapeFact>,
290 {
291 TypedFact {
292 datum_type,
293 shape: shape.into(),
294 konst: None,
295 uniform: None,
296 exotic_fact: None,
297 uniform_tdim: None,
298 region_of_interest: None,
299 }
300 }
301
302 pub fn rank(&self) -> usize {
303 if cfg!(debug_assertions) {
304 self.consistent().unwrap();
305 }
306 self.shape.rank()
307 }
308
309 fn format_dt_shape_nocheck(&self) -> String {
310 if self.shape.rank() > 0 {
311 format!("{:?},{:?}", self.shape, self.datum_type)
312 } else {
313 format!("{:?}", self.datum_type)
314 }
315 }
316
317 pub fn format_dt_shape(&self) -> String {
318 if cfg!(debug_assertions) {
319 self.consistent().unwrap()
320 }
321 self.format_dt_shape_nocheck()
322 }
323
324 pub fn consistent(&self) -> TractResult<()> {
325 self.shape.consistent()?;
326 if let Some(k) = &self.konst {
327 if !self.matches(k.as_ref(), None)? {
328 bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
329 }
330 if let Some(bqf) = self.exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
331 {
332 if let Some(bqs) = k.storage_as::<BlockQuantStorage>() {
333 let inner_bqf =
334 BlockQuantFact::new(dyn_clone::clone_box(bqs.format()), k.shape().into());
335 ensure!(&inner_bqf == bqf, "BlockQuantStorage fact mismatch");
336 }
337 }
338 }
339 if let Some(u) = &self.uniform
340 && self.datum_type != u.datum_type()
341 {
342 bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
343 }
344 if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
345 if let Some(k) = k.as_uniform() {
346 if &k != u {
347 bail!(
348 "Uniform value and uniform constant mismatch: value:{u:?}, uniform:{k:?}",
349 );
350 }
351 } else {
352 bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
353 }
354 }
355 Ok(())
356 }
357
358 pub fn without_value(&self) -> Self {
359 let mut new = self.clone();
360 new.konst = None;
361 new.uniform = None;
362 new.uniform_tdim = None;
363 new.region_of_interest = None;
364 new
365 }
366
367 pub fn with_exotic_fact<O: Into<Box<dyn ExoticFact>>>(mut self, exotic_fact: O) -> Self {
368 self.exotic_fact = Some(exotic_fact.into());
369 self
370 }
371
372 pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
373 self.exotic_fact.as_deref()
374 }
375
376 #[inline]
377 pub fn is_exotic(&self) -> bool {
378 self.exotic_fact.is_some()
379 }
380
381 #[inline]
382 pub fn is_plain(&self) -> bool {
383 self.exotic_fact.is_none()
384 }
385}
386
387impl Fact for TypedFact {
388 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
389 if cfg!(debug_assertions) {
390 self.consistent()?
391 }
392 Ok(Cow::Borrowed(self))
393 }
394
395 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
396 if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
397 return Ok(false);
398 }
399 for i in 0..t.rank() {
400 if let Ok(dim) =
401 self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
402 && dim != t.shape()[i]
403 {
404 return Ok(false);
405 }
406 }
407 Ok(true)
408 }
409
410 fn compatible_with(&self, other: &dyn Fact) -> bool {
411 if cfg!(debug_assertions) {
412 self.consistent().unwrap()
413 }
414 if let Some(other) = other.downcast_ref::<Self>() {
415 if cfg!(debug_assertions) {
416 other.consistent().unwrap()
417 }
418 self.datum_type == other.datum_type
419 && self.shape.compatible_with(&other.shape)
420 && self
421 .exotic_fact()
422 .zip(other.exotic_fact())
423 .map(|(a, b)| a.compatible_with(b))
424 .unwrap_or(true)
425 } else {
426 false
427 }
428 }
429
430 fn datum_type(&self) -> Option<DatumType> {
431 Some(self.datum_type)
432 }
433}
434
435impl TryFrom<Tensor> for TypedFact {
436 type Error = TractError;
437 fn try_from(t: Tensor) -> TractResult<TypedFact> {
438 TypedFact::try_from(t.into_arc_tensor())
439 }
440}
441
442impl TryFrom<Arc<Tensor>> for TypedFact {
443 type Error = TractError;
444 fn try_from(t: Arc<Tensor>) -> TractResult<TypedFact> {
445 let exotic_fact = t.exotic_fact()?;
446 let uniform_tdim = if t.datum_type() == TDim::datum_type() && t.len() == 1 {
447 t.try_as_plain().ok().and_then(|d| d.as_slice::<TDim>().ok()).map(|s| s[0].clone())
448 } else if t.len() == 1
449 && t.try_as_plain().is_ok()
450 && (t.datum_type().is_integer() || t.datum_type().is::<bool>())
451 {
452 t.cast_to_scalar::<i64>().ok().map(TDim::Val)
453 } else {
454 None
455 };
456 Ok(TypedFact {
457 datum_type: t.datum_type(),
458 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
459 uniform: t.as_uniform().map(Arc::new),
460 exotic_fact,
461 konst: Some(t),
462 uniform_tdim,
463 region_of_interest: None,
464 })
465 }
466}
467
468impl From<&TypedFact> for TypedFact {
469 fn from(fact: &TypedFact) -> TypedFact {
470 fact.clone()
471 }
472}
473
474impl<'a> TryFrom<&'a Arc<Tensor>> for TypedFact {
475 type Error = TractError;
476 fn try_from(t: &'a Arc<Tensor>) -> TractResult<TypedFact> {
477 TypedFact::try_from(Arc::clone(t))
478 }
479}
480
481impl fmt::Debug for TypedFact {
482 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
483 write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
484 if self.is_exotic() {
485 if let Some(of) = &self.exotic_fact {
486 write!(fmt, " 🔍 {of:?} ")?
487 } else {
488 write!(fmt, " 🔍 <no exotic fact> ")?
489 }
490 }
491 if let Some(k) = &self.konst {
492 write!(fmt, "🟰 {k:?}")?
493 }
494 if let Some(u) = &self.uniform {
495 write!(fmt, " ◻️{u:?}")?
496 }
497 if let Some(u) = &self.uniform_tdim {
498 write!(fmt, " 📐{u}")?
499 }
500 if let Some(r) = &self.region_of_interest {
501 write!(fmt, " 🬳 {r}")?
502 }
503 Ok(())
504 }
505}
506
507pub trait DatumExt {
508 fn scalar_fact() -> TypedFact;
509 fn fact<S>(shape: S) -> TypedFact
510 where
511 S: Into<ShapeFact>;
512}
513
514impl<T: Datum> DatumExt for T {
515 #[allow(clippy::needless_borrow)]
516 fn scalar_fact() -> TypedFact {
517 TypedFact::shape::<Self, &[usize]>(&[])
518 }
519
520 fn fact<S>(shape: S) -> TypedFact
521 where
522 S: Into<ShapeFact>,
523 {
524 TypedFact::shape::<Self, _>(shape)
525 }
526}
527
528pub trait DatumTypeExt {
529 fn scalar_fact(&self) -> TypedFact;
530 fn fact<S>(&self, shape: S) -> TypedFact
531 where
532 S: Into<ShapeFact>;
533}
534
535impl DatumTypeExt for DatumType {
536 #[allow(clippy::needless_borrow)]
537 fn scalar_fact(&self) -> TypedFact {
538 TypedFact::dt_shape::<&[usize]>(*self, &[])
539 }
540
541 fn fact<S>(&self, shape: S) -> TypedFact
542 where
543 S: Into<ShapeFact>,
544 {
545 TypedFact::dt_shape(*self, shape)
546 }
547}