1use crate::internal::*;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};
7
/// Shape of a tensor: one possibly-symbolic `TDim` per axis, plus a cached `usize` view.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFact {
    /// One dimension per axis.
    dims: TVec<TDim>,
    /// Cached `usize` dimensions: `Some` only when every entry of `dims` converts to `usize`.
    /// Kept in sync by `compute_concrete` and the axis-editing methods.
    concrete: Option<TVec<usize>>,
}
13
14impl ShapeFact {
15 #[inline]
16 pub fn rank(&self) -> usize {
17 self.dims.len()
18 }
19
20 fn compute_concrete(&mut self) {
21 assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
22 self.concrete =
23 self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
24 }
25
26 #[inline]
28 pub fn as_concrete(&self) -> Option<&[usize]> {
29 self.concrete.as_deref()
30 }
31
32 #[inline]
34 pub fn is_concrete(&self) -> bool {
35 self.concrete.is_some()
36 }
37
38 #[inline]
40 pub fn to_tvec(&self) -> TVec<TDim> {
41 self.dims.clone()
42 }
43
44 #[inline]
46 pub fn volume(&self) -> TDim {
47 self.dims.iter().product()
48 }
49
50 #[inline]
51 pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<'_, ShapeFact>> {
52 if self.is_concrete() {
53 Ok(Cow::Borrowed(self))
54 } else {
55 Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
56 }
57 }
58
59 #[inline]
60 pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<usize>>> {
61 if let Some(c) = &self.concrete {
62 Ok(Cow::Borrowed(c))
63 } else {
64 Ok(Cow::Owned(
65 self.iter()
66 .map(|d| d.eval_to_i64(values).map(|d| d as usize))
67 .collect::<TractResult<TVec<_>>>()?,
68 ))
69 }
70 }
71
72 #[inline]
73 pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<isize>>> {
74 if let Some(c) = &self.concrete {
75 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
76 Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
78 } else {
79 Ok(Cow::Owned(
80 self.iter()
81 .map(|d| d.eval_to_i64(values).map(|d| d as isize))
82 .collect::<TractResult<TVec<_>>>()?,
83 ))
84 }
85 }
86
87 pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
88 let mut dims =
89 ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
90 dims.compute_concrete();
91 dims
92 }
93
94 pub fn dims(&self) -> &[TDim] {
95 self.dims.as_slice()
96 }
97
98 pub fn set(&mut self, ix: usize, dim: TDim) {
99 self.dims[ix] = dim;
100 self.compute_concrete();
101 }
102
103 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
104 self.dims.insert(axis, 1.into());
105 if let Some(concrete) = &mut self.concrete {
106 concrete.insert(axis, 1);
107 }
108 Ok(())
109 }
110
111 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
112 self.dims.remove(axis);
113 if let Some(concrete) = &mut self.concrete {
114 concrete.remove(axis);
115 } else {
116 self.compute_concrete();
117 };
118 Ok(())
119 }
120
121 pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
122 if self.rank() == _other.rank() {
123 self.dims
124 .iter()
125 .zip(_other.dims.iter())
126 .all(|(dim, other_dim)| dim.compatible_with(other_dim))
127 } else {
128 false
129 }
130 }
131
132 pub fn scalar() -> ShapeFact {
133 let void: &[usize] = &[];
134 Self::from(void)
135 }
136
137 pub fn consistent(&self) -> TractResult<()> {
138 ensure!(
139 self.concrete
140 == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
141 );
142 Ok(())
143 }
144}
145
146impl std::ops::Deref for ShapeFact {
147 type Target = [TDim];
148 fn deref(&self) -> &[TDim] {
149 &self.dims
150 }
151}
152
153impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
154 fn from(it: T) -> ShapeFact {
155 ShapeFact::from_dims(it)
156 }
157}
158
159pub trait Fact:
162 std::fmt::Debug + Downcast + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static
163{
164 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>>;
165
166 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
167 self.to_typed_fact()?.matches(t, symbols)
168 }
169
170 fn compatible_with(&self, _other: &dyn Fact) -> bool;
172
173 fn datum_type(&self) -> Option<DatumType>;
174}
175
// Enables `downcast_ref::<T>()` on `dyn Fact` (used by `TypedFact::compatible_with`).
impl_downcast!(Fact);
// Makes boxed `dyn Fact` values cloneable through the `DynClone` supertrait.
dyn_clone::clone_trait_object!(Fact);
// Implements equality for `dyn Fact` through the `DynEq` supertrait.
dyn_eq::eq_trait_object!(Fact);
179
180impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
181 fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
182 ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
183 }
184}
185
186impl fmt::Debug for ShapeFact {
187 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
188 use tract_itertools::Itertools;
189 write!(fmt, "{}", self.iter().join(","))
190 }
191}
192
193impl AsRef<[TDim]> for ShapeFact {
194 fn as_ref(&self) -> &[TDim] {
195 &self.dims
196 }
197}
198
/// Typed description of a value: element type, shape, and whatever is known about its contents.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypedFact {
    /// Element type.
    pub datum_type: DatumType,
    /// Dimensions, possibly symbolic.
    pub shape: ShapeFact,
    /// Full constant value, when known (set by the `TryFrom<Arc<Tensor>>` conversion).
    pub konst: Option<Arc<Tensor>>,
    /// Value shared by every element, when known; `consistent` checks it against `datum_type`
    /// and `konst`.
    pub uniform: Option<Arc<Tensor>>,
    /// Description of a non-plain ("exotic") value; `None` means the value is plain.
    pub exotic_fact: Option<Box<dyn ExoticFact>>,
    /// Value as a `TDim`; the `TryFrom<Arc<Tensor>>` conversion fills it for single-element
    /// `TDim`, integer or bool tensors. NOTE(review): other producers may set it differently.
    pub uniform_tdim: Option<TDim>,
    /// NOTE(review): meaning not visible in this file; it is cleared by `without_value` and
    /// printed by `Debug` — confirm semantics with its producers.
    pub region_of_interest: Option<TDim>,
}
221
222impl TypedFact {
223 pub fn scalar<T>() -> TypedFact
224 where
225 T: Datum,
226 {
227 Self::dt_scalar(T::datum_type())
228 }
229
230 pub fn shape<T, S>(shape: S) -> TypedFact
231 where
232 T: Datum,
233 S: Into<ShapeFact>,
234 {
235 Self::dt_shape(T::datum_type(), shape)
236 }
237
238 pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
239 debug_assert!(
240 t.is_plain(),
241 "shape_and_dt_of called on exotic tensor, exotic_fact will be lost"
242 );
243 TypedFact {
244 datum_type: t.datum_type(),
245 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
246 uniform: None,
247 konst: None,
248 exotic_fact: None,
249 uniform_tdim: None,
250 region_of_interest: None,
251 }
252 }
253
254 pub fn mem_size(&self) -> TDim {
255 self.shape.volume() * self.datum_type.size_of()
256 + self.exotic_fact().iter().flat_map(|it| it.buffer_sizes()).sum::<TDim>()
257 }
258
259 pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
260 TypedFact {
261 datum_type,
262 shape: ShapeFact::scalar(),
263 konst: None,
264 uniform: None,
265 exotic_fact: None,
266 uniform_tdim: None,
267 region_of_interest: None,
268 }
269 }
270
271 pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
272 where
273 S: Into<ShapeFact>,
274 {
275 TypedFact {
276 datum_type,
277 shape: shape.into(),
278 konst: None,
279 uniform: None,
280 exotic_fact: None,
281 uniform_tdim: None,
282 region_of_interest: None,
283 }
284 }
285
286 pub fn rank(&self) -> usize {
287 if cfg!(debug_assertions) {
288 self.consistent().unwrap();
289 }
290 self.shape.rank()
291 }
292
293 fn format_dt_shape_nocheck(&self) -> String {
294 if self.shape.rank() > 0 {
295 format!("{:?},{:?}", self.shape, self.datum_type)
296 } else {
297 format!("{:?}", self.datum_type)
298 }
299 }
300
301 pub fn format_dt_shape(&self) -> String {
302 if cfg!(debug_assertions) {
303 self.consistent().unwrap()
304 }
305 self.format_dt_shape_nocheck()
306 }
307
308 pub fn consistent(&self) -> TractResult<()> {
309 self.shape.consistent()?;
310 if let Some(k) = &self.konst {
311 if !self.matches(k.as_ref(), None)? {
312 bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
313 }
314 if let Some(bqf) = self.exotic_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
315 {
316 if let Some(bqs) = k.storage_as::<BlockQuantStorage>() {
317 let inner_bqf =
318 BlockQuantFact::new(dyn_clone::clone_box(bqs.format()), k.shape().into());
319 ensure!(&inner_bqf == bqf, "BlockQuantStorage fact mismatch");
320 }
321 }
322 }
323 if let Some(u) = &self.uniform
324 && self.datum_type != u.datum_type()
325 {
326 bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
327 }
328 if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
329 if let Some(k) = k.as_uniform() {
330 if &k != u {
331 bail!(
332 "Uniform value and uniform constant mismatch: value:{u:?}, uniform:{k:?}",
333 );
334 }
335 } else {
336 bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
337 }
338 }
339 Ok(())
340 }
341
342 pub fn without_value(&self) -> Self {
343 let mut new = self.clone();
344 new.konst = None;
345 new.uniform = None;
346 new.uniform_tdim = None;
347 new.region_of_interest = None;
348 new
349 }
350
351 pub fn with_exotic_fact<O: Into<Box<dyn ExoticFact>>>(mut self, exotic_fact: O) -> Self {
352 self.exotic_fact = Some(exotic_fact.into());
353 self
354 }
355
356 pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
357 self.exotic_fact.as_deref()
358 }
359
360 #[inline]
361 pub fn is_exotic(&self) -> bool {
362 self.exotic_fact.is_some()
363 }
364
365 #[inline]
366 pub fn is_plain(&self) -> bool {
367 self.exotic_fact.is_none()
368 }
369}
370
371impl Fact for TypedFact {
372 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
373 if cfg!(debug_assertions) {
374 self.consistent()?
375 }
376 Ok(Cow::Borrowed(self))
377 }
378
379 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
380 if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
381 return Ok(false);
382 }
383 for i in 0..t.rank() {
384 if let Ok(dim) =
385 self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
386 && dim != t.shape()[i]
387 {
388 return Ok(false);
389 }
390 }
391 Ok(true)
392 }
393
394 fn compatible_with(&self, other: &dyn Fact) -> bool {
395 if cfg!(debug_assertions) {
396 self.consistent().unwrap()
397 }
398 if let Some(other) = other.downcast_ref::<Self>() {
399 if cfg!(debug_assertions) {
400 other.consistent().unwrap()
401 }
402 self.datum_type == other.datum_type
403 && self.shape.compatible_with(&other.shape)
404 && self
405 .exotic_fact()
406 .zip(other.exotic_fact())
407 .map(|(a, b)| a.compatible_with(b))
408 .unwrap_or(true)
409 } else {
410 false
411 }
412 }
413
414 fn datum_type(&self) -> Option<DatumType> {
415 Some(self.datum_type)
416 }
417}
418
419impl TryFrom<Tensor> for TypedFact {
420 type Error = TractError;
421 fn try_from(t: Tensor) -> TractResult<TypedFact> {
422 TypedFact::try_from(t.into_arc_tensor())
423 }
424}
425
426impl TryFrom<Arc<Tensor>> for TypedFact {
427 type Error = TractError;
428 fn try_from(t: Arc<Tensor>) -> TractResult<TypedFact> {
429 let exotic_fact = t.exotic_fact()?;
430 let uniform_tdim = if t.datum_type() == TDim::datum_type() && t.len() == 1 {
431 t.try_as_plain().ok().and_then(|d| d.as_slice::<TDim>().ok()).map(|s| s[0].clone())
432 } else if t.len() == 1
433 && t.try_as_plain().is_ok()
434 && (t.datum_type().is_integer() || t.datum_type().is::<bool>())
435 {
436 t.cast_to_scalar::<i64>().ok().map(TDim::Val)
437 } else {
438 None
439 };
440 Ok(TypedFact {
441 datum_type: t.datum_type(),
442 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
443 uniform: t.as_uniform().map(Arc::new),
444 exotic_fact,
445 konst: Some(t),
446 uniform_tdim,
447 region_of_interest: None,
448 })
449 }
450}
451
452impl From<&TypedFact> for TypedFact {
453 fn from(fact: &TypedFact) -> TypedFact {
454 fact.clone()
455 }
456}
457
458impl<'a> TryFrom<&'a Arc<Tensor>> for TypedFact {
459 type Error = TractError;
460 fn try_from(t: &'a Arc<Tensor>) -> TractResult<TypedFact> {
461 TypedFact::try_from(Arc::clone(t))
462 }
463}
464
465impl fmt::Debug for TypedFact {
466 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
467 write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
468 if self.is_exotic() {
469 if let Some(of) = &self.exotic_fact {
470 write!(fmt, " 🔍 {of:?} ")?
471 } else {
472 write!(fmt, " 🔍 <no exotic fact> ")?
473 }
474 }
475 if let Some(k) = &self.konst {
476 write!(fmt, "🟰 {k:?}")?
477 }
478 if let Some(u) = &self.uniform {
479 write!(fmt, " ◻️{u:?}")?
480 }
481 if let Some(u) = &self.uniform_tdim {
482 write!(fmt, " 📐{u}")?
483 }
484 if let Some(r) = &self.region_of_interest {
485 write!(fmt, " 🬳 {r}")?
486 }
487 Ok(())
488 }
489}
490
491pub trait DatumExt {
492 fn scalar_fact() -> TypedFact;
493 fn fact<S>(shape: S) -> TypedFact
494 where
495 S: Into<ShapeFact>;
496}
497
498impl<T: Datum> DatumExt for T {
499 #[allow(clippy::needless_borrow)]
500 fn scalar_fact() -> TypedFact {
501 TypedFact::shape::<Self, &[usize]>(&[])
502 }
503
504 fn fact<S>(shape: S) -> TypedFact
505 where
506 S: Into<ShapeFact>,
507 {
508 TypedFact::shape::<Self, _>(shape)
509 }
510}
511
512pub trait DatumTypeExt {
513 fn scalar_fact(&self) -> TypedFact;
514 fn fact<S>(&self, shape: S) -> TypedFact
515 where
516 S: Into<ShapeFact>;
517}
518
519impl DatumTypeExt for DatumType {
520 #[allow(clippy::needless_borrow)]
521 fn scalar_fact(&self) -> TypedFact {
522 TypedFact::dt_shape::<&[usize]>(*self, &[])
523 }
524
525 fn fact<S>(&self, shape: S) -> TypedFact
526 where
527 S: Into<ShapeFact>,
528 {
529 TypedFact::dt_shape(*self, shape)
530 }
531}