1use crate::internal::*;
3use downcast_rs::Downcast;
4use std::fmt;
5use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};
6
7#[derive(Clone, PartialEq, Eq, Hash)]
8pub struct ShapeFact {
9 dims: TVec<TDim>,
10 concrete: Option<TVec<usize>>,
11}
12
13impl ShapeFact {
14 #[inline]
15 pub fn rank(&self) -> usize {
16 self.dims.len()
17 }
18
19 fn compute_concrete(&mut self) {
20 assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
21 self.concrete =
22 self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
23 }
24
25 #[inline]
27 pub fn as_concrete(&self) -> Option<&[usize]> {
28 self.concrete.as_deref()
29 }
30
31 #[inline]
33 pub fn is_concrete(&self) -> bool {
34 self.concrete.is_some()
35 }
36
37 #[inline]
39 pub fn to_tvec(&self) -> TVec<TDim> {
40 self.dims.clone()
41 }
42
43 #[inline]
45 pub fn volume(&self) -> TDim {
46 self.dims.iter().product()
47 }
48
49 #[inline]
50 pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<ShapeFact>> {
51 if self.is_concrete() {
52 Ok(Cow::Borrowed(self))
53 } else {
54 Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
55 }
56 }
57
58 #[inline]
59 pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<usize>>> {
60 if let Some(c) = &self.concrete {
61 Ok(Cow::Borrowed(c))
62 } else {
63 Ok(Cow::Owned(
64 self.iter()
65 .map(|d| d.eval_to_i64(values).map(|d| d as usize))
66 .collect::<TractResult<TVec<_>>>()?,
67 ))
68 }
69 }
70
71 #[inline]
72 pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<isize>>> {
73 if let Some(c) = &self.concrete {
74 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
75 Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
77 } else {
78 Ok(Cow::Owned(
79 self.iter()
80 .map(|d| d.eval_to_i64(values).map(|d| d as isize))
81 .collect::<TractResult<TVec<_>>>()?,
82 ))
83 }
84 }
85
86 pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
87 let mut dims =
88 ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
89 dims.compute_concrete();
90 dims
91 }
92
93 pub fn dims(&self) -> &[TDim] {
94 self.dims.as_slice()
95 }
96
97 pub fn set(&mut self, ix: usize, dim: TDim) {
98 self.dims[ix] = dim;
99 self.compute_concrete();
100 }
101
102 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
103 self.dims.insert(axis, 1.into());
104 if let Some(concrete) = &mut self.concrete {
105 concrete.insert(axis, 1);
106 }
107 Ok(())
108 }
109
110 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
111 self.dims.remove(axis);
112 if let Some(concrete) = &mut self.concrete {
113 concrete.remove(axis);
114 } else {
115 self.compute_concrete();
116 };
117 Ok(())
118 }
119
120 pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
121 if self.rank() == _other.rank() {
122 self.dims
123 .iter()
124 .zip(_other.dims.iter())
125 .all(|(dim, other_dim)| dim.compatible_with(other_dim))
126 } else {
127 false
128 }
129 }
130
131 pub fn scalar() -> ShapeFact {
132 let void: &[usize] = &[];
133 Self::from(void)
134 }
135
136 pub fn consistent(&self) -> TractResult<()> {
137 ensure!(
138 self.concrete
139 == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
140 );
141 Ok(())
142 }
143}
144
145impl std::ops::Deref for ShapeFact {
146 type Target = [TDim];
147 fn deref(&self) -> &[TDim] {
148 &self.dims
149 }
150}
151
152impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
153 fn from(it: T) -> ShapeFact {
154 ShapeFact::from_dims(it)
155 }
156}
157
158pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
161 fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>>;
162
163 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
164 self.to_typed_fact()?.matches(t, symbols)
165 }
166
167 fn same_as(&self, _other: &dyn Fact) -> bool;
168
169 fn compatible_with(&self, _other: &dyn Fact) -> bool;
171
172 fn datum_type(&self) -> Option<DatumType>;
173}
174
// Adds downcasting helpers to `dyn Fact`; `TypedFact::same_as` and
// `TypedFact::compatible_with` rely on `other.downcast_ref::<Self>()`.
175impl_downcast!(Fact);
// Makes boxed `dyn Fact` trait objects cloneable through the `DynClone` supertrait of `Fact`.
176dyn_clone::clone_trait_object!(Fact);
177
178impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
179 fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
180 ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
181 }
182}
183
184impl fmt::Debug for ShapeFact {
185 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
186 use tract_itertools::Itertools;
187 write!(fmt, "{}", self.iter().join(","))
188 }
189}
190
191impl AsRef<[TDim]> for ShapeFact {
192 fn as_ref(&self) -> &[TDim] {
193 &self.dims
194 }
195}
196
197#[derive(Clone, PartialEq, Eq, Hash)]
199pub struct TypedFact {
200 pub datum_type: DatumType,
202 pub shape: ShapeFact,
204 pub konst: Option<Arc<Tensor>>,
206 pub uniform: Option<Arc<Tensor>>,
208 pub opaque_fact: Option<Box<dyn OpaqueFact>>,
210}
211
212impl TypedFact {
213 pub fn scalar<T>() -> TypedFact
214 where
215 T: Datum,
216 {
217 Self::dt_scalar(T::datum_type())
218 }
219
220 pub fn shape<T, S>(shape: S) -> TypedFact
221 where
222 T: Datum,
223 S: Into<ShapeFact>,
224 {
225 Self::dt_shape(T::datum_type(), shape)
226 }
227
228 pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
229 TypedFact {
230 datum_type: t.datum_type(),
231 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
232 uniform: None,
233 konst: None,
234 opaque_fact: None,
235 }
236 }
237
238 pub fn mem_size(&self) -> TDim {
239 self.shape.volume() * self.datum_type.size_of()
240 + self.opaque_fact().map(|it| it.mem_size()).unwrap_or(0.into())
241 }
242
243 pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
244 TypedFact {
245 datum_type,
246 shape: ShapeFact::scalar(),
247 konst: None,
248 uniform: None,
249 opaque_fact: None,
250 }
251 }
252
253 pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
254 where
255 S: Into<ShapeFact>,
256 {
257 TypedFact { datum_type, shape: shape.into(), konst: None, uniform: None, opaque_fact: None }
258 }
259
260 pub fn rank(&self) -> usize {
261 if cfg!(debug_assertions) {
262 self.consistent().unwrap();
263 }
264 self.shape.rank()
265 }
266
267 fn format_dt_shape_nocheck(&self) -> String {
268 if self.shape.rank() > 0 {
269 format!("{:?},{:?}", self.shape, self.datum_type)
270 } else {
271 format!("{:?}", self.datum_type)
272 }
273 }
274
275 pub fn format_dt_shape(&self) -> String {
276 if cfg!(debug_assertions) {
277 self.consistent().unwrap()
278 }
279 self.format_dt_shape_nocheck()
280 }
281
282 pub fn consistent(&self) -> TractResult<()> {
283 self.shape.consistent()?;
284 ensure!(self.datum_type.is_opaque() == self.opaque_fact.is_some());
285 if let Some(k) = &self.konst {
286 if !self.matches(k.as_ref(), None)? {
287 bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
288 }
289 if let Some(bqf) = self.opaque_fact().and_then(|of| of.downcast_ref::<BlockQuantFact>())
290 {
291 for o in k.as_slice::<Opaque>().unwrap() {
292 ensure!(o.is::<BlockQuantValue>());
293 ensure!(o.downcast_ref::<BlockQuantValue>().unwrap().fact == *bqf);
294 }
295 }
296 }
297 if let Some(u) = &self.uniform {
298 if self.datum_type != u.datum_type() {
299 bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
300 }
301 }
302 if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
303 if let Some(k) = k.as_uniform() {
304 if &k != u {
305 bail!("Uniform value and uniform constant mismatch: {:?}, {:?}", u, k);
306 }
307 } else {
308 bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
309 }
310 }
311 Ok(())
312 }
313
314 pub fn without_value(&self) -> Self {
315 let mut new = self.clone();
316 new.konst = None;
317 new.uniform = None;
318 new
319 }
320
321 pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
322 self.opaque_fact = Some(opaque_fact.into());
323 self
324 }
325
326 pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
327 self.opaque_fact.as_deref()
328 }
329}
330
331impl Fact for TypedFact {
332 fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
333 if cfg!(debug_assertions) {
334 self.consistent()?
335 }
336 Ok(Cow::Borrowed(self))
337 }
338
339 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
340 if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
341 return Ok(false);
342 }
343 for i in 0..t.rank() {
344 if let Ok(dim) =
345 self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
346 {
347 if dim != t.shape()[i] {
348 return Ok(false);
349 }
350 }
351 }
352 Ok(true)
353 }
354
355 fn same_as(&self, other: &dyn Fact) -> bool {
356 if cfg!(debug_assertions) {
357 self.consistent().unwrap()
358 }
359 if let Some(other) = other.downcast_ref::<Self>() {
360 if cfg!(debug_assertions) {
361 other.consistent().unwrap()
362 }
363 self == other
364 } else {
365 false
366 }
367 }
368
369 fn compatible_with(&self, other: &dyn Fact) -> bool {
370 if cfg!(debug_assertions) {
371 self.consistent().unwrap()
372 }
373 if let Some(other) = other.downcast_ref::<Self>() {
374 if cfg!(debug_assertions) {
375 other.consistent().unwrap()
376 }
377 self.datum_type == other.datum_type
378 && self.shape.compatible_with(&other.shape)
379 && self
380 .opaque_fact()
381 .zip(other.opaque_fact())
382 .map(|(a, b)| a.compatible_with(b))
383 .unwrap_or(true)
384 } else {
385 false
386 }
387 }
388
389 fn datum_type(&self) -> Option<DatumType> {
390 Some(self.datum_type)
391 }
392}
393
394impl From<Tensor> for TypedFact {
395 fn from(t: Tensor) -> TypedFact {
396 TypedFact::from(t.into_arc_tensor())
397 }
398}
399
400impl From<Arc<Tensor>> for TypedFact {
401 fn from(t: Arc<Tensor>) -> TypedFact {
402 TypedFact {
403 datum_type: t.datum_type(),
404 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
405 uniform: t.as_uniform().map(Arc::new),
406 opaque_fact: None,
407 konst: Some(t),
408 }
409 }
410}
411
412impl From<&TypedFact> for TypedFact {
413 fn from(fact: &TypedFact) -> TypedFact {
414 fact.clone()
415 }
416}
417
418impl<'a> From<&'a Arc<Tensor>> for TypedFact {
419 fn from(t: &'a Arc<Tensor>) -> TypedFact {
420 Arc::clone(t).into()
421 }
422}
423
424impl fmt::Debug for TypedFact {
425 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
426 write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
427 if self.datum_type.is_opaque() {
428 if let Some(of) = &self.opaque_fact {
429 write!(fmt, " 🔍 {:?} ", of)?
430 } else {
431 write!(fmt, " 🔍 <no opaque fact> ")?
432 }
433 }
434 if let Some(k) = &self.konst {
435 write!(fmt, "🟰 {:?}", k)?
436 }
437 Ok(())
438 }
439}
440
441pub trait DatumExt {
442 fn scalar_fact() -> TypedFact;
443 fn fact<S>(shape: S) -> TypedFact
444 where
445 S: Into<ShapeFact>;
446}
447
448impl<T: Datum> DatumExt for T {
449 #[allow(clippy::needless_borrow)]
450 fn scalar_fact() -> TypedFact {
451 TypedFact::shape::<Self, &[usize]>(&[])
452 }
453
454 fn fact<S>(shape: S) -> TypedFact
455 where
456 S: Into<ShapeFact>,
457 {
458 TypedFact::shape::<Self, _>(shape)
459 }
460}
461
462pub trait DatumTypeExt {
463 fn scalar_fact(&self) -> TypedFact;
464 fn fact<S>(&self, shape: S) -> TypedFact
465 where
466 S: Into<ShapeFact>;
467}
468
469impl DatumTypeExt for DatumType {
470 #[allow(clippy::needless_borrow)]
471 fn scalar_fact(&self) -> TypedFact {
472 TypedFact::dt_shape::<&[usize]>(*self, &[])
473 }
474
475 fn fact<S>(&self, shape: S) -> TypedFact
476 where
477 S: Into<ShapeFact>,
478 {
479 TypedFact::dt_shape(*self, shape)
480 }
481}