1use std::fmt;
2use std::marker::PhantomData;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use tract_num_traits::ToPrimitive;
6use tract_num_traits::Zero;
7
8use crate::internal::*;
9
10use self::super::super::factoid::*;
11use self::super::path::Path;
12use self::super::proxies::*;
13use self::super::solver::Context;
14
15pub trait Output: fmt::Debug + Clone + PartialEq {
17 fn wrap(self) -> Wrapped {
19 Self::into_wrapped(self)
20 }
21
22 fn into_wrapped(source: Self) -> Wrapped;
24
25 fn from_wrapped(wrapped: Wrapped) -> TractResult<Self>;
28}
29
/// Implements `Output` for `$type` by mapping it onto the `Wrapped::$constr`
/// variant. `$name` is the human-readable type name used in the error raised
/// when unwrapping a different variant.
macro_rules! impl_output {
    ($type:ty, $constr:ident, $name:expr) => {
        impl Output for $type {
            fn into_wrapped(source: Self) -> Wrapped {
                Wrapped::$constr(source)
            }

            fn from_wrapped(wrapped: Wrapped) -> TractResult<$type> {
                match wrapped {
                    Wrapped::$constr(inner) => Ok(inner),
                    // Any other variant is a type mismatch.
                    other => bail!("Tried to get a {} from {:?}.", $name, other),
                }
            }
        }
    };
}
47
// Each factoid type maps onto exactly one `Wrapped` variant; the string is the
// type name reported when unwrapping a mismatched variant.
impl_output!(IntFactoid, Int, "Int");
impl_output!(TypeFactoid, Type, "DatumType");
impl_output!(ShapeFactoid, Shape, "Shape");
impl_output!(ValueFact, Tensor, "Tensor");
impl_output!(DimFact, Dim, "TDim");
53
54impl Output for usize {
56 fn into_wrapped(source: usize) -> Wrapped {
57 IntFactoid::into_wrapped((source as i64).into())
58 }
59
60 fn from_wrapped(wrapped: Wrapped) -> TractResult<usize> {
61 IntFactoid::from_wrapped(wrapped.clone())?
62 .concretize()
63 .and_then(|u| u.to_usize())
64 .with_context(|| format!("Tried to convert {wrapped:?} to a usize."))
65 }
66}
67
68impl Output for i64 {
70 fn into_wrapped(source: i64) -> Wrapped {
71 IntFactoid::into_wrapped(source.into())
72 }
73
74 fn from_wrapped(wrapped: Wrapped) -> TractResult<i64> {
75 IntFactoid::from_wrapped(wrapped.clone())?
76 .concretize()
77 .with_context(|| format!("Tried to convert {wrapped:?} to a i64."))
78 }
79}
80
81impl Output for Arc<Tensor> {
83 fn into_wrapped(source: Arc<Tensor>) -> Wrapped {
84 ValueFact::into_wrapped(source.into())
85 }
86
87 fn from_wrapped(wrapped: Wrapped) -> TractResult<Arc<Tensor>> {
88 ValueFact::from_wrapped(wrapped.clone())?
89 .concretize()
90 .with_context(|| format_err!("Tried to convert {:?} to a tensor.", wrapped))
91 }
92}
93
94impl Output for TDim {
96 fn into_wrapped(source: TDim) -> Wrapped {
97 DimFact::into_wrapped(source.into())
98 }
99
100 fn from_wrapped(wrapped: Wrapped) -> TractResult<TDim> {
101 DimFact::from_wrapped(wrapped.clone())?
102 .concretize()
103 .with_context(|| format_err!("Tried to convert {:?} to a usize.", wrapped))
104 }
105}
106
/// Type-erased container holding any one of the supported fact/factoid types.
#[derive(Debug, Clone)]
pub enum Wrapped {
    /// Holds an `IntFactoid`.
    Int(IntFactoid),
    /// Holds a `TypeFactoid`.
    Type(TypeFactoid),
    /// Holds a `ShapeFactoid`.
    Shape(ShapeFactoid),
    /// Holds a `ValueFact`.
    Tensor(ValueFact),
    /// Holds a `DimFact`.
    Dim(DimFact),
}
116
/// An expression that evaluates to a `T` within a solver `Context`.
pub trait TExp<T>: fmt::Debug {
    /// Evaluates the expression in `context`.
    fn get(&self, context: &Context) -> TractResult<T>;

    /// Tries to constrain the expression to `value` in `context`.
    ///
    /// NOTE(review): the returned `bool` appears to signal whether the
    /// context was (or may have been) modified — e.g. `VariableExp::set`
    /// returns `old != new`; confirm against the solver's expectations.
    fn set(&self, context: &mut Context, value: T) -> TractResult<bool>;

    /// Returns the paths referenced by this expression.
    fn get_paths(&self) -> Vec<&Path>;
}
128
129pub struct Exp<T>(Box<dyn TExp<T>>);
130impl<T: Factoid + Output + Clone + fmt::Debug> TExp<T> for Exp<T> {
131 fn get(&self, context: &Context) -> TractResult<T> {
133 self.0.get(context)
134 }
135
136 fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
138 self.0.set(context, value)
139 }
140
141 fn get_paths(&self) -> Vec<&Path> {
143 self.0.get_paths()
144 }
145}
146
147impl<T> fmt::Debug for Exp<T>
148where
149 T: Factoid + Output + Clone + ::std::fmt::Debug,
150{
151 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
152 write!(formatter, "{:?}", self.0)
153 }
154}
155
/// Conversion of a value into a boxed expression over `T`.
pub trait IntoExp<T> {
    /// Consumes `self` and returns the corresponding boxed expression.
    fn bex(self) -> Exp<T>;
}
160
161#[derive(new)]
162pub struct SumExp<T>(Vec<Exp<T>>)
163where
164 T: Factoid + Output + Clone + ::std::fmt::Debug + 'static;
165
166impl<T> TExp<T> for SumExp<T>
167where
168 T: Factoid + Output + Zero + Add<T> + Neg<Output = T> + Clone + ::std::fmt::Debug + 'static,
169{
170 fn get(&self, context: &Context) -> TractResult<T> {
172 self.0.iter().try_fold(T::zero(), |acc, it| Ok(acc + it.0.get(context)?))
173 }
174
175 fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
177 let mut sum = T::zero();
178 let mut misses = vec![];
179
180 for item in &self.0 {
181 let fact = item.get(context)?;
182 if fact.is_concrete() {
183 sum = sum + fact;
184 } else {
185 misses.push(item);
186 }
187 }
188
189 if misses.len() > 1 {
190 Ok(false)
191 } else if misses.len() == 1 {
192 misses[0].set(context, value + -sum)?;
193 Ok(true)
194 } else if sum == value {
195 Ok(false)
196 } else {
197 bail!("{:?} set to {:?}, already is {:?}", self, value, sum)
198 }
199 }
200
201 fn get_paths(&self) -> Vec<&Path> {
203 self.0.iter().flat_map(|e| e.get_paths()).collect()
204 }
205}
206
207impl<T> fmt::Debug for SumExp<T>
208where
209 T: Factoid + Output + Clone + ::std::fmt::Debug,
210{
211 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
212 for (ix, t) in self.0.iter().enumerate() {
213 if ix > 0 {
214 write!(formatter, " + ")?;
215 }
216 t.fmt(formatter)?;
217 }
218 Ok(())
219 }
220}
221
222pub struct ConstantExp<T>(T)
224where
225 T: Factoid + Output + Clone + ::std::fmt::Debug;
226
227impl<T> TExp<T> for ConstantExp<T>
228where
229 T: Factoid + Output + Clone + ::std::fmt::Debug,
230{
231 fn get(&self, _: &Context) -> TractResult<T> {
233 Ok(self.0.clone())
234 }
235
236 fn set(&self, _: &mut Context, value: T) -> TractResult<bool> {
238 self.0.unify(&value)?;
239 Ok(false)
240 }
241
242 fn get_paths(&self) -> Vec<&Path> {
244 vec![]
245 }
246}
247
248impl<T> fmt::Debug for ConstantExp<T>
249where
250 T: Factoid + Output + Clone + ::std::fmt::Debug,
251{
252 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
253 write!(formatter, "{:?}", self.0)
254 }
255}
256
257pub struct VariableExp<T>(Path, PhantomData<T>)
263where
264 T: Factoid + Output + Clone + ::std::fmt::Debug;
265
266impl<T> TExp<T> for VariableExp<T>
267where
268 T: Factoid + Output + Clone + ::std::fmt::Debug,
269{
270 fn get(&self, context: &Context) -> TractResult<T> {
272 context.get(&self.0).with_context(|| format!("while getting {:?}", self.0))
273 }
274
275 fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
277 let old = self.get(context)?;
278 let new = old.unify(&value)?;
279 let diff = old != new;
280 context.set(&self.0, new).with_context(|| format!("while setting {:?}", self.0))?;
281 Ok(diff)
282 }
283
284 fn get_paths(&self) -> Vec<&Path> {
286 vec![&self.0]
287 }
288}
289
290impl<T> fmt::Debug for VariableExp<T>
291where
292 T: Factoid + Output + Clone + ::std::fmt::Debug,
293{
294 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
295 write!(formatter, "{:?}", self.0)
296 }
297}
298
299pub struct ScaledExp<T>(i64, Exp<T>)
301where
302 T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone;
303
304impl<T> TExp<T> for ScaledExp<T>
305where
306 T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
307{
308 fn get(&self, context: &Context) -> TractResult<T> {
310 let v: T = self.1.get(context)?;
311 Ok(v * self.0)
312 }
313
314 fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
316 let k = &self.0;
317 let m = value;
318
319 if m.is_zero() && k.is_zero() {
320 Ok(false)
322 } else if m.is_zero() {
323 self.1.set(context, T::zero())
325 } else {
326 let div = m.div(*k);
338 self.1.set(context, div)
339 }
340 }
341
342 fn get_paths(&self) -> Vec<&Path> {
344 self.1.get_paths()
345 }
346}
347
348impl<T> fmt::Debug for ScaledExp<T>
349where
350 T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
351{
352 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
353 write!(formatter, "{}*{{{:?}}}", self.0, self.1)
354 }
355}
356
357pub struct IntoDimExp(Exp<IntFactoid>);
359
360impl TExp<DimFact> for IntoDimExp {
361 fn get(&self, context: &Context) -> TractResult<DimFact> {
363 let v: IntFactoid = self.0.get(context)?;
364 match v {
365 GenericFactoid::Only(i) => Ok(GenericFactoid::Only(i.to_dim())),
366 GenericFactoid::Any => Ok(GenericFactoid::Any),
367 }
368 }
369
370 fn set(&self, context: &mut Context, value: DimFact) -> TractResult<bool> {
372 if let Some(concrete) = value.concretize() {
373 if let Ok(int) = concrete.to_i64() {
374 return self.0.set(context, GenericFactoid::Only(int));
375 }
376 }
377 Ok(false)
378 }
379
380 fn get_paths(&self) -> Vec<&Path> {
382 self.0.get_paths()
383 }
384}
385
386impl fmt::Debug for IntoDimExp {
387 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
388 write!(formatter, "{{({:?}) as dim}}", self.0)
389 }
390}
391
392impl<T, E: TExp<T> + 'static> IntoExp<T> for E {
395 fn bex(self) -> Exp<T> {
396 Exp(Box::new(self))
397 }
398}
399
400impl IntoExp<TypeFactoid> for TypeProxy {
403 fn bex(self) -> Exp<TypeFactoid> {
404 VariableExp(self.get_path().clone(), PhantomData).bex()
405 }
406}
407
408impl IntoExp<TypeFactoid> for &TypeProxy {
409 fn bex(self) -> Exp<TypeFactoid> {
410 VariableExp(self.get_path().clone(), PhantomData).bex()
411 }
412}
413
414impl IntoExp<TypeFactoid> for DatumType {
415 fn bex(self) -> Exp<TypeFactoid> {
416 ConstantExp(self.into()).bex()
417 }
418}
419
420impl IntoExp<TypeFactoid> for &DatumType {
421 fn bex(self) -> Exp<TypeFactoid> {
422 ConstantExp((*self).into()).bex()
423 }
424}
425
426impl IntoExp<IntFactoid> for &IntProxy {
429 fn bex(self) -> Exp<IntFactoid> {
430 VariableExp(self.get_path().clone(), PhantomData).bex()
431 }
432}
433
434impl IntoExp<IntFactoid> for &ElementProxy {
435 fn bex(self) -> Exp<IntFactoid> {
436 VariableExp(self.get_path().clone(), PhantomData).bex()
437 }
438}
439
440impl IntoExp<IntFactoid> for i64 {
441 fn bex(self) -> Exp<IntFactoid> {
442 ConstantExp(self.into()).bex()
443 }
444}
445
446impl IntoExp<IntFactoid> for IntFactoid {
447 fn bex(self) -> Exp<IntFactoid> {
448 ConstantExp(self).bex()
449 }
450}
451
452impl<IE: IntoExp<IntFactoid>> Add<IE> for Exp<IntFactoid> {
453 type Output = Exp<IntFactoid>;
454 fn add(self, other: IE) -> Exp<IntFactoid> {
455 SumExp(vec![self.bex(), other.bex()]).bex()
456 }
457}
458
459impl<IE: IntoExp<IntFactoid>> Sub<IE> for Exp<IntFactoid> {
460 type Output = Exp<IntFactoid>;
461 fn sub(self, other: IE) -> Exp<IntFactoid> {
462 SumExp(vec![self.bex(), -1 * other.bex()]).bex()
463 }
464}
465
466impl Mul<Exp<IntFactoid>> for i64 {
467 type Output = Exp<IntFactoid>;
468 fn mul(self, other: Exp<IntFactoid>) -> Exp<IntFactoid> {
469 ScaledExp(self, other).bex()
470 }
471}
472
473impl IntoExp<DimFact> for &DimProxy {
476 fn bex(self) -> Exp<DimFact> {
477 VariableExp(self.get_path().clone(), PhantomData).bex()
478 }
479}
480
481impl IntoExp<DimFact> for TDim {
482 fn bex(self) -> Exp<DimFact> {
483 ConstantExp(self.into()).bex()
484 }
485}
486
487impl IntoExp<DimFact> for &TDim {
488 fn bex(self) -> Exp<DimFact> {
489 ConstantExp(self.clone().into()).bex()
490 }
491}
492
493impl<IE: IntoExp<DimFact>> Add<IE> for Exp<DimFact> {
494 type Output = Exp<DimFact>;
495 fn add(self, other: IE) -> Exp<DimFact> {
496 SumExp(vec![self.bex(), other.bex()]).bex()
497 }
498}
499
500impl<IE: IntoExp<DimFact>> Sub<IE> for Exp<DimFact> {
501 type Output = Exp<DimFact>;
502 fn sub(self, other: IE) -> Exp<DimFact> {
503 SumExp(vec![self.bex(), -1 * other.bex()]).bex()
504 }
505}
506
507impl Mul<Exp<DimFact>> for i64 {
508 type Output = Exp<DimFact>;
509 fn mul(self, other: Exp<DimFact>) -> Exp<DimFact> {
510 ScaledExp(self, other).bex()
511 }
512}
513
514impl IntoExp<DimFact> for GenericFactoid<TDim> {
515 fn bex(self) -> Exp<GenericFactoid<TDim>> {
516 ConstantExp(self).bex()
517 }
518}
519
/// Conversion of an expression into a dimension expression.
pub trait ToDimExp {
    /// Consumes `self` and returns an expression over `DimFact`.
    fn to_dim(self) -> Exp<DimFact>;
}
525
526impl ToDimExp for Exp<IntFactoid> {
527 fn to_dim(self) -> Exp<DimFact> {
528 IntoDimExp(self).bex()
529 }
530}
531
532impl IntoExp<ShapeFactoid> for ShapeFactoid {
535 fn bex(self) -> Exp<ShapeFactoid> {
536 ConstantExp(self).bex()
537 }
538}
539
540impl IntoExp<ShapeFactoid> for ShapeProxy {
541 fn bex(self) -> Exp<ShapeFactoid> {
542 VariableExp(self.get_path().clone(), PhantomData).bex()
543 }
544}
545
546impl IntoExp<ShapeFactoid> for &ShapeProxy {
547 fn bex(self) -> Exp<ShapeFactoid> {
548 VariableExp(self.get_path().clone(), PhantomData).bex()
549 }
550}
551
552impl IntoExp<ShapeFactoid> for TVec<TDim> {
553 fn bex(self) -> Exp<ShapeFactoid> {
554 ConstantExp(self.into_iter().collect()).bex()
555 }
556}
557
558impl IntoExp<ValueFact> for ValueProxy {
561 fn bex(self) -> Exp<ValueFact> {
562 VariableExp(self.get_path().clone(), PhantomData).bex()
563 }
564}
565
566impl IntoExp<ValueFact> for &ValueProxy {
567 fn bex(self) -> Exp<ValueFact> {
568 VariableExp(self.get_path().clone(), PhantomData).bex()
569 }
570}
571
572impl IntoExp<ValueFact> for Arc<Tensor> {
573 fn bex(self) -> Exp<ValueFact> {
574 ConstantExp(self.into()).bex()
575 }
576}