1use crate::internal::*;
3use downcast_rs::Downcast;
4use std::fmt;
5
6#[derive(Clone, PartialEq, Eq, Hash)]
7pub struct ShapeFact {
8 dims: TVec<TDim>,
9 concrete: Option<TVec<usize>>,
10}
11
12impl ShapeFact {
13 #[inline]
14 pub fn rank(&self) -> usize {
15 self.dims.len()
16 }
17
18 fn compute_concrete(&mut self) {
19 assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
20 self.concrete =
21 self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
22 }
23
24 #[inline]
26 pub fn as_concrete(&self) -> Option<&[usize]> {
27 self.concrete.as_deref()
28 }
29
30 #[inline]
32 pub fn is_concrete(&self) -> bool {
33 self.concrete.is_some()
34 }
35
36 #[inline]
38 pub fn to_tvec(&self) -> TVec<TDim> {
39 self.dims.clone()
40 }
41
42 #[inline]
44 pub fn volume(&self) -> TDim {
45 self.dims.iter().product()
46 }
47
48 #[inline]
49 pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<ShapeFact>> {
50 if self.is_concrete() {
51 Ok(Cow::Borrowed(self))
52 } else {
53 Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
54 }
55 }
56
57 #[inline]
58 pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<usize>>> {
59 if let Some(c) = &self.concrete {
60 Ok(Cow::Borrowed(c))
61 } else {
62 Ok(Cow::Owned(
63 self.iter()
64 .map(|d| d.eval_to_i64(values).map(|d| d as usize))
65 .collect::<TractResult<TVec<_>>>()?,
66 ))
67 }
68 }
69
70 #[inline]
71 pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<isize>>> {
72 if let Some(c) = &self.concrete {
73 #[allow(unknown_lints, clippy::missing_transmute_annotations)]
74 Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
76 } else {
77 Ok(Cow::Owned(
78 self.iter()
79 .map(|d| d.eval_to_i64(values).map(|d| d as isize))
80 .collect::<TractResult<TVec<_>>>()?,
81 ))
82 }
83 }
84
85 pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
86 let mut dims =
87 ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
88 dims.compute_concrete();
89 dims
90 }
91
92 pub fn dims(&self) -> &[TDim] {
93 self.dims.as_slice()
94 }
95
96 pub fn set(&mut self, ix: usize, dim: TDim) {
97 self.dims[ix] = dim;
98 self.compute_concrete();
99 }
100
101 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
102 self.dims.insert(axis, 1.into());
103 if let Some(concrete) = &mut self.concrete {
104 concrete.insert(axis, 1);
105 }
106 Ok(())
107 }
108
109 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
110 self.dims.remove(axis);
111 if let Some(concrete) = &mut self.concrete {
112 concrete.remove(axis);
113 } else {
114 self.compute_concrete();
115 };
116 Ok(())
117 }
118
119 pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
120 if self.rank() == _other.rank() {
121 self.dims
122 .iter()
123 .zip(_other.dims.iter())
124 .all(|(dim, other_dim)| dim.compatible_with(other_dim))
125 } else {
126 false
127 }
128 }
129
130 pub fn scalar() -> ShapeFact {
131 let void: &[usize] = &[];
132 Self::from(void)
133 }
134
135 pub fn consistent(&self) -> TractResult<()> {
136 ensure!(
137 self.concrete
138 == self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
139 );
140 Ok(())
141 }
142}
143
144impl std::ops::Deref for ShapeFact {
145 type Target = [TDim];
146 fn deref(&self) -> &[TDim] {
147 &self.dims
148 }
149}
150
151impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
152 fn from(it: T) -> ShapeFact {
153 ShapeFact::from_dims(it)
154 }
155}
156
157pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
160 fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>>;
161
162 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
163 self.to_typed_fact()?.matches(t, symbols)
164 }
165
166 fn same_as(&self, _other: &dyn Fact) -> bool;
167
168 fn compatible_with(&self, _other: &dyn Fact) -> bool;
170
171 fn datum_type(&self) -> Option<DatumType>;
172}
173
174impl_downcast!(Fact);
175dyn_clone::clone_trait_object!(Fact);
176
177impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
178 fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
179 ShapeFact::from_dims(iter.into_iter().map(|d| d.to_dim()))
180 }
181}
182
183impl fmt::Debug for ShapeFact {
184 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
185 use tract_itertools::Itertools;
186 write!(fmt, "{}", self.iter().join(","))
187 }
188}
189
190impl AsRef<[TDim]> for ShapeFact {
191 fn as_ref(&self) -> &[TDim] {
192 &self.dims
193 }
194}
195
196#[derive(Clone, PartialEq, Eq, Hash)]
198pub struct TypedFact {
199 pub datum_type: DatumType,
201 pub shape: ShapeFact,
203 pub konst: Option<Arc<Tensor>>,
205 pub uniform: Option<Arc<Tensor>>,
207 pub opaque_fact: Option<Box<dyn OpaqueFact>>,
209}
210
211impl TypedFact {
212 pub fn scalar<T>() -> TypedFact
213 where
214 T: Datum,
215 {
216 Self::dt_scalar(T::datum_type())
217 }
218
219 pub fn shape<T, S>(shape: S) -> TypedFact
220 where
221 T: Datum,
222 S: Into<ShapeFact>,
223 {
224 Self::dt_shape(T::datum_type(), shape)
225 }
226
227 pub fn shape_and_dt_of(t: &Tensor) -> TypedFact {
228 TypedFact {
229 datum_type: t.datum_type(),
230 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
231 uniform: None,
232 konst: None,
233 opaque_fact: None,
234 }
235 }
236
237 pub fn mem_size(&self) -> TDim {
238 self.shape.volume() * self.datum_type.size_of()
239 + self.opaque_fact.as_ref().map(|it| it.mem_size()).unwrap_or(0.into())
240 }
241
242 pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
243 TypedFact {
244 datum_type,
245 shape: ShapeFact::scalar(),
246 konst: None,
247 uniform: None,
248 opaque_fact: None,
249 }
250 }
251
252 pub fn dt_shape<S>(datum_type: DatumType, shape: S) -> TypedFact
253 where
254 S: Into<ShapeFact>,
255 {
256 TypedFact { datum_type, shape: shape.into(), konst: None, uniform: None, opaque_fact: None }
257 }
258
259 pub fn rank(&self) -> usize {
260 if cfg!(debug_assertions) {
261 self.consistent().unwrap();
262 }
263 self.shape.rank()
264 }
265
266 fn format_dt_shape_nocheck(&self) -> String {
267 if self.shape.rank() > 0 {
268 format!("{:?},{:?}", self.shape, self.datum_type)
269 } else {
270 format!("{:?}", self.datum_type)
271 }
272 }
273
274 pub fn format_dt_shape(&self) -> String {
275 if cfg!(debug_assertions) {
276 self.consistent().unwrap()
277 }
278 self.format_dt_shape_nocheck()
279 }
280
281 pub fn consistent(&self) -> TractResult<()> {
282 self.shape.consistent()?;
283 if let Some(k) = &self.konst {
284 if !self.matches(k.as_ref(), None)? {
285 bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
286 }
287 }
288 if let Some(u) = &self.uniform {
289 if self.datum_type != u.datum_type() {
290 bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type);
291 }
292 }
293 if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
294 if let Some(k) = k.as_uniform() {
295 if &k != u {
296 bail!("Uniform value and uniform constant mismatch: {:?}, {:?}", u, k);
297 }
298 } else {
299 bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k);
300 }
301 }
302 Ok(())
303 }
304
305 pub fn without_value(&self) -> Self {
306 let mut new = self.clone();
307 new.konst = None;
308 new.uniform = None;
309 new
310 }
311
312 pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
313 self.opaque_fact = Some(opaque_fact.into());
314 self
315 }
316}
317
318impl Fact for TypedFact {
319 fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
320 if cfg!(debug_assertions) {
321 self.consistent()?
322 }
323 Ok(Cow::Borrowed(self))
324 }
325
326 fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
327 if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
328 return Ok(false);
329 }
330 for i in 0..t.rank() {
331 if let Ok(dim) =
332 self.shape[i].eval(symbols.unwrap_or(&SymbolValues::default())).to_usize()
333 {
334 if dim != t.shape()[i] {
335 return Ok(false);
336 }
337 }
338 }
339 Ok(true)
340 }
341
342 fn same_as(&self, other: &dyn Fact) -> bool {
343 if cfg!(debug_assertions) {
344 self.consistent().unwrap()
345 }
346 if let Some(other) = other.downcast_ref::<Self>() {
347 if cfg!(debug_assertions) {
348 other.consistent().unwrap()
349 }
350 self == other
351 } else {
352 false
353 }
354 }
355
356 fn compatible_with(&self, other: &dyn Fact) -> bool {
357 if cfg!(debug_assertions) {
358 self.consistent().unwrap()
359 }
360 if let Some(other) = other.downcast_ref::<Self>() {
361 if cfg!(debug_assertions) {
362 other.consistent().unwrap()
363 }
364 self.datum_type == other.datum_type
365 && self.shape.compatible_with(&other.shape)
366 && self
367 .opaque_fact
368 .as_ref()
369 .zip(other.opaque_fact.as_ref())
370 .map(|(a, b)| a.compatible_with(&**b))
371 .unwrap_or(true)
372 } else {
373 false
374 }
375 }
376
377 fn datum_type(&self) -> Option<DatumType> {
378 Some(self.datum_type)
379 }
380}
381
382impl From<Tensor> for TypedFact {
383 fn from(t: Tensor) -> TypedFact {
384 TypedFact::from(t.into_arc_tensor())
385 }
386}
387
388impl From<Arc<Tensor>> for TypedFact {
389 fn from(t: Arc<Tensor>) -> TypedFact {
390 TypedFact {
391 datum_type: t.datum_type(),
392 shape: ShapeFact::from_dims(t.shape().iter().map(TDim::from)),
393 uniform: t.as_uniform().map(Arc::new),
394 opaque_fact: None,
395 konst: Some(t),
396 }
397 }
398}
399
400impl From<&TypedFact> for TypedFact {
401 fn from(fact: &TypedFact) -> TypedFact {
402 fact.clone()
403 }
404}
405
406impl<'a> From<&'a Arc<Tensor>> for TypedFact {
407 fn from(t: &'a Arc<Tensor>) -> TypedFact {
408 Arc::clone(t).into()
409 }
410}
411
412impl fmt::Debug for TypedFact {
413 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
414 write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
415 if self.datum_type.is_opaque() {
416 if let Some(of) = &self.opaque_fact {
417 write!(fmt, " 🔍 {:?} ", of)?
418 } else {
419 write!(fmt, " 🔍 <no opaque fact> ")?
420 }
421 }
422 if let Some(k) = &self.konst {
423 write!(fmt, "🟰 {:?}", k)?
424 }
425 Ok(())
426 }
427}
428
429pub trait DatumExt {
430 fn scalar_fact() -> TypedFact;
431 fn fact<S>(shape: S) -> TypedFact
432 where
433 S: Into<ShapeFact>;
434}
435
436impl<T: Datum> DatumExt for T {
437 #[allow(clippy::needless_borrow)]
438 fn scalar_fact() -> TypedFact {
439 TypedFact::shape::<Self, &[usize]>(&[])
440 }
441
442 fn fact<S>(shape: S) -> TypedFact
443 where
444 S: Into<ShapeFact>,
445 {
446 TypedFact::shape::<Self, _>(shape)
447 }
448}
449
450pub trait DatumTypeExt {
451 fn scalar_fact(&self) -> TypedFact;
452 fn fact<S>(&self, shape: S) -> TypedFact
453 where
454 S: Into<ShapeFact>;
455}
456
457impl DatumTypeExt for DatumType {
458 #[allow(clippy::needless_borrow)]
459 fn scalar_fact(&self) -> TypedFact {
460 TypedFact::dt_shape::<&[usize]>(*self, &[])
461 }
462
463 fn fact<S>(&self, shape: S) -> TypedFact
464 where
465 S: Into<ShapeFact>,
466 {
467 TypedFact::dt_shape(*self, shape)
468 }
469}