1use crate::internal::*;
3use downcast_rs::Downcast;
4use std::fmt;
5use tract_linalg::block_quant::BlockQuantFact;
6
7#[derive(Clone, PartialEq, Eq, Hash)]
8pub struct ShapeFact {
9 dims: TVec<TDim>,
10 concrete: Option<TVec<usize>>,
11}
12
13impl ShapeFact {
14 #[inline]
15 pub fn rank(&self) -> usize {
16 self.dims.len()
17 }
18
19 fn compute_concrete(&mut self) {
20 assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
21 self.concrete =
22 self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
23 }
24
25 #[inline]
27 pub fn as_concrete(&self) -> Option<&[usize]> {
28 self.concrete.as_deref()
29 }
30
31 #[inline]
33 pub fn is_concrete(&self) -> bool {
34 self.concrete.is_some()
35 }
36
37 #[inline]
39 pub fn to_tvec(&self) -> TVec<TDim> {
40 self.dims.clone()
41 }
42
43 #[inline]
45 pub fn volume(&self) -> TDim {
46 self.dims.iter().product()
47 }
48
49 #[inline]
50 pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<'_, ShapeFact>> {
51 if self.is_concrete() {
52 Ok(Cow::Borrowed(self))
53 } else {
54 Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
55 }
56 }
57
58 #[inline]
59 pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<usize>>> {
60 if let Some(c) = &self.concrete {
61 Ok(Cow::Borrowed(c))
62 } else {
63 Ok(Cow::Owned(
64 self.iter()
65 .map(|d| d.eval_to_i64(values).map(|d| d as usize))
66 .collect::<TractResult<TVec<_>>>()?,
67 ))
68 }
69 }
70
71 #[inline]
72 pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<'_, TVec<isize>>> {
73 if let Some(c) = &self.concrete {
74 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
75 Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
77 } else {
78 Ok(Cow::Owned(
79 self.iter()
80 .map(|d| d.eval_to_i64(values).map(|d| d as isize))
81 .collect::<TractResult<TVec<_>>>()?,
82 ))
83 }
84 }
85
86 pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
87 let mut dims =
88 ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
89 dims.compute_concrete();
90 dims
91 }
92
93 pub fn dims(&self) -> &[TDim] {
94 self.dims.as_slice()
95 }
96
97 pub fn set(&mut self, ix: usize, dim: TDim) {
98 self.dims[ix] = dim;
99 self.compute_concrete();
100 }
101
102 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
103 self.dims.insert(axis, 1.into());
104 if let Some(concrete) = &mut self.concrete {
105 concrete.insert(axis, 1);
106 }
107 Ok(())
108 }
109
110 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
111 self.dims.remove(axis);
112 if let Some(concrete) = &mut self.concrete {
113 concrete.remove(axis);
114 } else {
115 self.compute_concrete();
116 };
117 Ok(())
118 }
119
120 pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
121 if self.rank() == _other.rank() {
122 self.dims
123 .iter()
124 .zip(_other.dims.iter())
125 .all(|(dim, other_dim)| dim.compatible_with(other_dim))
126 } else {
127 false
128 }
129 }
130
131 pub fn scalar() -> ShapeFact {
132 let void: &[usize] = &[];
133 Self::from(void)
134 }
135
136 pub fn consistent(&self) -> TractResult<()> {
137 ensure!(
138 self.concrete
139 == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
140 );
141 Ok(())
142 }
143}
144
145impl std::ops::Deref for ShapeFact {
146 type Target = [TDim];
147 fn deref(&self) -> &[TDim] {
148 &self.dims
149 }
150}
151
152impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
153 fn from(it: T) -> ShapeFact {
154 ShapeFact::from_dims(it)
155 }
156}
157
158pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
161 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>>;
162
163 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
164 self.to_typed_fact()?.matches(t, symbols)
165 }
166
167 fn same_as(&self, _other: &dyn Fact) -> bool;
168
169 fn compatible_with(&self, _other: &dyn Fact) -> bool;
171
172 fn datum_type(&self) -> Option<DatumType>;
173}
174
175impl_downcast!(Fact);
176dyn_clone::clone_trait_object!(Fact);
177
178impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
179 fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
180 ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
181 }
182}
183
184impl fmt::Debug for ShapeFact {
185 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
186 use tract_itertools::Itertools;
187 write!(fmt, "{}", self.iter().join(","))
188 }
189}
190
191impl AsRef<[TDim]> for ShapeFact {
192 fn as_ref(&self) -> &[TDim] {
193 &self.dims
194 }
195}
196
197#[derive(Clone, PartialEq, Eq, Hash)]
199pub struct TypedFact {
200 pub datum_type: DatumType,
202 pub shape: ShapeFact,
204 pub konst: Option<Arc<Tensor>>,
206 pub uniform: Option<Arc<Tensor>>,
208 pub opaque_fact: Option<Box<dyn OpaqueFact>>,
210}
211
212impl TypedFact {
213 pub fn scalar<T>() -> TypedFact
214 where
215 T: Datum,
216 {
217 Self::dt_scalar(T::datum_type())
218 }
219
220 pub fn shape<T, S>(shape: S) -> TypedFact
221 where
222 T: Datum,
223 S: Into<ShapeFact>,
224 {
225 Self::dt_shape(T::datum_type(), shape)
226 }
227
228 pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
229 TypedFact {
230 datum_type: t.datum_type(),
231 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
232 uniform: None,
233 konst: None,
234 opaque_fact: None,
235 }
236 }
237
238 pub fn mem_size(&self) -> TDim {
239 self.shape.volume() * self.datum_type.size_of()
240 + self.opaque_fact().iter().flat_map(|it| it.buffer_sizes()).sum::<TDim>()
241 }
242
243 pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
244 TypedFact {
245 datum_type,
246 shape: ShapeFact::scalar(),
247 konst: None,
248 uniform: None,
249 opaque_fact: None,
250 }
251 }
252
253 pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
254 where
255 S: Into<ShapeFact>,
256 {
257 TypedFact { datum_type, shape: shape.into(), konst: None, uniform: None, opaque_fact: None }
258 }
259
260 pub fn rank(&self) -> usize {
261 if cfg!(debug_assertions) {
262 self.consistent().unwrap();
263 }
264 self.shape.rank()
265 }
266
267 fn format_dt_shape_nocheck(&self) -> String {
268 if self.shape.rank() > 0 {
269 format!("{:?},{:?}", self.shape, self.datum_type)
270 } else {
271 format!("{:?}", self.datum_type)
272 }
273 }
274
275 pub fn format_dt_shape(&self) -> String {
276 if cfg!(debug_assertions) {
277 self.consistent().unwrap()
278 }
279 self.format_dt_shape_nocheck()
280 }
281
282 pub fn consistent(&self) -> TractResult<()> {
283 self.shape.consistent()?;
284 ensure!(self.datum_type.is_opaque() == self.opaque_fact.is_some());
285 if let Some(k) = &self.konst {
286 if !self.matches(k.as_ref(), None)? {
287 bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
288 }
289 if let Some(bqf) = self.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
290 {
291 for o in k.as_slice::<Opaque>().unwrap() {
292 ensure!(o.is::<BlobWithFact>());
293 ensure!(
294 o.downcast_ref::<BlobWithFact>()
295 .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
296 .is_some()
297 );
298 ensure!(
299 o.downcast_ref::<BlobWithFact>()
300 .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
301 .unwrap()
302 == bqf
303 );
304 }
305 }
306 }
307 if let Some(u) = &self.uniform {
308 if self.datum_type != u.datum_type() {
309 bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
310 }
311 }
312 if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
313 if let Some(k) = k.as_uniform() {
314 if &k != u {
315 bail!(
316 "Uniform value and uniform constant mismatch: value:{u:?}, uniform:{k:?}",
317 );
318 }
319 } else {
320 bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
321 }
322 }
323 Ok(())
324 }
325
326 pub fn without_value(&self) -> Self {
327 let mut new = self.clone();
328 new.konst = None;
329 new.uniform = None;
330 new
331 }
332
333 pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
334 self.opaque_fact = Some(opaque_fact.into());
335 self
336 }
337
338 pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
339 self.opaque_fact.as_deref()
340 }
341}
342
343impl Fact for TypedFact {
344 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
345 if cfg!(debug_assertions) {
346 self.consistent()?
347 }
348 Ok(Cow::Borrowed(self))
349 }
350
351 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
352 if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
353 return Ok(false);
354 }
355 for i in 0..t.rank() {
356 if let Ok(dim) =
357 self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
358 {
359 if dim != t.shape()[i] {
360 return Ok(false);
361 }
362 }
363 }
364 Ok(true)
365 }
366
367 fn same_as(&self, other: &dyn Fact) -> bool {
368 if cfg!(debug_assertions) {
369 self.consistent().unwrap()
370 }
371 if let Some(other) = other.downcast_ref::<Self>() {
372 if cfg!(debug_assertions) {
373 other.consistent().unwrap()
374 }
375 self == other
376 } else {
377 false
378 }
379 }
380
381 fn compatible_with(&self, other: &dyn Fact) -> bool {
382 if cfg!(debug_assertions) {
383 self.consistent().unwrap()
384 }
385 if let Some(other) = other.downcast_ref::<Self>() {
386 if cfg!(debug_assertions) {
387 other.consistent().unwrap()
388 }
389 self.datum_type == other.datum_type
390 && self.shape.compatible_with(&other.shape)
391 && self
392 .opaque_fact()
393 .zip(other.opaque_fact())
394 .map(|(a, b)| a.compatible_with(b))
395 .unwrap_or(true)
396 } else {
397 false
398 }
399 }
400
401 fn datum_type(&self) -> Option<DatumType> {
402 Some(self.datum_type)
403 }
404}
405
406impl From<Tensor> for TypedFact {
407 fn from(t: Tensor) -> TypedFact {
408 TypedFact::from(t.into_arc_tensor())
409 }
410}
411
412impl From<Arc<Tensor>> for TypedFact {
413 fn from(t: Arc<Tensor>) -> TypedFact {
414 TypedFact {
415 datum_type: t.datum_type(),
416 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
417 uniform: t.as_uniform().map(Arc::new),
418 opaque_fact: None,
419 konst: Some(t),
420 }
421 }
422}
423
424impl From<&TypedFact> for TypedFact {
425 fn from(fact: &TypedFact) -> TypedFact {
426 fact.clone()
427 }
428}
429
430impl<'a> From<&'a Arc<Tensor>> for TypedFact {
431 fn from(t: &'a Arc<Tensor>) -> TypedFact {
432 Arc::clone(t).into()
433 }
434}
435
436impl fmt::Debug for TypedFact {
437 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
438 write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
439 if self.datum_type.is_opaque() {
440 if let Some(of) = &self.opaque_fact {
441 write!(fmt, " 🔍 {of:?} ")?
442 } else {
443 write!(fmt, " 🔍 <no opaque fact> ")?
444 }
445 }
446 if let Some(k) = &self.konst {
447 write!(fmt, "🟰 {k:?}")?
448 }
449 Ok(())
450 }
451}
452
453pub trait DatumExt {
454 fn scalar_fact() -> TypedFact;
455 fn fact<S>(shape: S) -> TypedFact
456 where
457 S: Into<ShapeFact>;
458}
459
460impl<T: Datum> DatumExt for T {
461 #[allow(clippy::needless_borrow)]
462 fn scalar_fact() -> TypedFact {
463 TypedFact::shape::<Self, &[usize]>(&[])
464 }
465
466 fn fact<S>(shape: S) -> TypedFact
467 where
468 S: Into<ShapeFact>,
469 {
470 TypedFact::shape::<Self, _>(shape)
471 }
472}
473
474pub trait DatumTypeExt {
475 fn scalar_fact(&self) -> TypedFact;
476 fn fact<S>(&self, shape: S) -> TypedFact
477 where
478 S: Into<ShapeFact>;
479}
480
481impl DatumTypeExt for DatumType {
482 #[allow(clippy::needless_borrow)]
483 fn scalar_fact(&self) -> TypedFact {
484 TypedFact::dt_shape::<&[usize]>(*self, &[])
485 }
486
487 fn fact<S>(&self, shape: S) -> TypedFact
488 where
489 S: Into<ShapeFact>,
490 {
491 TypedFact::dt_shape(*self, shape)
492 }
493}