1use std::convert::TryFrom;
2use std::fmt;
3use std::sync::Arc;
4
5use super::factoid::*;
6use crate::internal::*;
7
/// Groups a datum-type factoid, a shape factoid and a value fact for one tensor.
///
/// The `Default` value of this struct is what both `InferenceFact::new()` and
/// `InferenceFact::any()` return.
#[derive(Clone, PartialEq, Eq, Default, Hash)]
pub struct InferenceFact {
    /// Factoid for the datum type.
    pub datum_type: TypeFactoid,
    /// Factoid for the shape.
    pub shape: ShapeFactoid,
    /// Fact for the value; it concretizes to an `Arc<Tensor>` (see the
    /// `Factoid` implementation for this struct).
    pub value: ValueFact,
}
27
28impl InferenceFact {
29 pub fn new() -> InferenceFact {
31 InferenceFact::default()
32 }
33
34 pub fn any() -> InferenceFact {
35 InferenceFact::default()
36 }
37
38 pub fn dt(dt: DatumType) -> InferenceFact {
39 InferenceFact::default().with_datum_type(dt)
40 }
41
42 pub fn dt_shape<S: Into<ShapeFactoid>>(dt: DatumType, shape: S) -> InferenceFact {
43 InferenceFact::dt(dt).with_shape(shape)
44 }
45
46 pub fn shape<S: Into<ShapeFactoid>>(shape: S) -> InferenceFact {
47 InferenceFact::default().with_shape(shape)
48 }
49
50 pub fn with_datum_type(self, dt: DatumType) -> InferenceFact {
51 InferenceFact { datum_type: dt.into(), ..self }
52 }
53
54 pub fn without_datum_type(self) -> InferenceFact {
55 InferenceFact { datum_type: TypeFactoid::Any, ..self }
56 }
57
58 pub fn with_shape<S: Into<ShapeFactoid>>(self, shape: S) -> InferenceFact {
59 InferenceFact { shape: shape.into(), ..self }
60 }
61
62 pub fn format_dt_shape(&self) -> String {
63 if !self.shape.open && self.shape.dims.len() == 0 {
64 self.datum_type
65 .concretize()
66 .map(|dt| format!("{dt:?}"))
67 .unwrap_or_else(|| "?".to_string())
68 } else {
69 format!(
70 "{:?},{}",
71 self.shape,
72 self.datum_type
73 .concretize()
74 .map(|dt| format!("{dt:?}"))
75 .unwrap_or_else(|| "?".to_string())
76 )
77 }
78 }
79
80 pub fn dt_shape_from_tensor(t: &Tensor) -> InferenceFact {
81 InferenceFact::dt_shape(t.datum_type(), t.shape())
82 }
83
84 pub fn without_value(self) -> InferenceFact {
85 InferenceFact { value: GenericFactoid::Any, ..self }
86 }
87}
88
89impl Factoid for InferenceFact {
90 type Concrete = Arc<Tensor>;
91
92 fn concretize(&self) -> Option<Self::Concrete> {
94 self.value.concretize()
95 }
96
97 fn unify(&self, other: &Self) -> TractResult<Self> {
99 let tensor = InferenceFact {
100 datum_type: self.datum_type.unify(&other.datum_type)?,
101 shape: self.shape.unify(&other.shape)?,
102 value: self.value.unify(&other.value)?,
103 };
104
105 trace!("Unifying {self:?} with {other:?} into {tensor:?}.");
106
107 Ok(tensor)
108 }
109}
110
111impl fmt::Debug for InferenceFact {
112 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
113 if let Some(t) = self.value.concretize() {
114 write!(formatter, "{t:?}")
115 } else {
116 write!(formatter, "{}", self.format_dt_shape())
117 }
118 }
119}
120
121use crate::infer::factoid::Factoid;
122
123impl Fact for InferenceFact {
124 fn to_typed_fact(&self) -> TractResult<Cow<'_, TypedFact>> {
125 Ok(Cow::Owned(TypedFact::try_from(self)?))
126 }
127
128 fn matches(&self, t: &Tensor, _symbols: Option<&SymbolValues>) -> TractResult<bool> {
129 if let Some(dt) = self.datum_type() {
130 if t.datum_type() != dt {
131 return Ok(false);
132 }
133 }
134 if let Some(shape) = self.shape.concretize() {
135 if *ShapeFact::from(t.shape()) != *shape {
136 return Ok(false);
137 }
138 }
139 if let Some(value) = self.value.concretize() {
140 if &*value != t {
141 return Ok(false);
142 }
143 }
144 Ok(true)
145 }
146
147 fn same_as(&self, other: &dyn Fact) -> bool {
148 if let Some(other) = other.downcast_ref::<Self>() {
149 self.unify(other).is_ok()
150 } else {
151 false
152 }
153 }
154
155 fn compatible_with(&self, other: &dyn Fact) -> bool {
156 if let Some(other) = other.downcast_ref::<Self>() {
157 self.unify(other).is_ok()
158 } else {
159 false
160 }
161 }
162
163 fn datum_type(&self) -> Option<DatumType> {
164 self.datum_type.concretize()
165 }
166}
167
168impl TryFrom<&InferenceFact> for TypedFact {
169 type Error = TractError;
170 fn try_from(fact: &InferenceFact) -> TractResult<TypedFact> {
171 if let (Some(datum_type), Some(shape)) =
172 (fact.datum_type.concretize(), fact.shape.concretize())
173 {
174 let shape = ShapeFact::from_dims(shape);
175 let konst = fact.value.concretize();
176 let uniform = konst.as_ref().and_then(|k| k.as_uniform()).map(Arc::new);
177 Ok(TypedFact { datum_type, shape, konst, uniform, opaque_fact: None })
178 } else {
179 bail!("Can not make a TypedFact out of {:?}", fact)
180 }
181 }
182}
183
184impl<'a> From<&'a InferenceFact> for InferenceFact {
185 fn from(t: &'a InferenceFact) -> InferenceFact {
186 t.clone()
187 }
188}
189
190impl<'a> From<&'a TypedFact> for InferenceFact {
191 fn from(t: &'a TypedFact) -> InferenceFact {
192 let mut fact = InferenceFact::dt_shape(t.datum_type, t.shape.iter());
193 if let Some(k) = &t.konst {
194 fact.value = Arc::clone(k).into();
195 }
196 fact
197 }
198}
199
200impl From<TypedFact> for InferenceFact {
201 fn from(t: TypedFact) -> InferenceFact {
202 InferenceFact::from(&t)
203 }
204}
205
206impl<'a> From<&'a Arc<Tensor>> for InferenceFact {
207 fn from(t: &'a Arc<Tensor>) -> InferenceFact {
208 InferenceFact::from(&TypedFact::from(Arc::clone(t)))
209 }
210}
211
212impl From<Arc<Tensor>> for InferenceFact {
213 fn from(t: Arc<Tensor>) -> InferenceFact {
214 InferenceFact::from(&TypedFact::from(t))
215 }
216}
217
218impl From<Tensor> for InferenceFact {
219 fn from(t: Tensor) -> InferenceFact {
220 let mut fact = InferenceFact::dt_shape(t.datum_type(), t.shape());
221 fact.value = t.into_arc_tensor().into();
222 fact
223 }
224}