tract_hir/infer/rules/
expr.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
5use tract_num_traits::ToPrimitive;
6use tract_num_traits::Zero;
7
8use crate::internal::*;
9
10use self::super::super::factoid::*;
11use self::super::path::Path;
12use self::super::proxies::*;
13use self::super::solver::Context;
14
15/// A trait for values produced by expressions.
16pub trait Output: fmt::Debug + Clone + PartialEq {
17    /// Wraps self in the Wrapped type.
18    fn wrap(self) -> Wrapped {
19        Self::into_wrapped(self)
20    }
21
22    /// Wraps the fact in the Wrapped type.
23    fn into_wrapped(source: Self) -> Wrapped;
24
25    /// Retrieves the fact from the Wrapped type.
26    /// Panics if wrapped doesn't have the right constructor.
27    fn from_wrapped(wrapped: Wrapped) -> TractResult<Self>;
28}
29
30macro_rules! impl_output {
31    ($type:ty, $constr:ident, $name:expr) => {
32        impl Output for $type {
33            fn into_wrapped(source: Self) -> Wrapped {
34                Wrapped::$constr(source)
35            }
36
37            fn from_wrapped(wrapped: Wrapped) -> TractResult<$type> {
38                if let Wrapped::$constr(v) = wrapped {
39                    Ok(v)
40                } else {
41                    bail!("Tried to get a {} from {:?}.", $name, wrapped);
42                }
43            }
44        }
45    };
46}
47
48impl_output!(IntFactoid, Int, "Int");
49impl_output!(TypeFactoid, Type, "DatumType");
50impl_output!(ShapeFactoid, Shape, "Shape");
51impl_output!(ValueFact, Tensor, "Tensor");
52impl_output!(DimFact, Dim, "TDim");
53
54// Converts back and forth between Wrapped and usize.
55impl Output for usize {
56    fn into_wrapped(source: usize) -> Wrapped {
57        IntFactoid::into_wrapped((source as i64).into())
58    }
59
60    fn from_wrapped(wrapped: Wrapped) -> TractResult<usize> {
61        IntFactoid::from_wrapped(wrapped.clone())?
62            .concretize()
63            .and_then(|u| u.to_usize())
64            .with_context(|| format!("Tried to convert {wrapped:?} to a usize."))
65    }
66}
67
68// Converts back and forth between Wrapped and i64.
69impl Output for i64 {
70    fn into_wrapped(source: i64) -> Wrapped {
71        IntFactoid::into_wrapped(source.into())
72    }
73
74    fn from_wrapped(wrapped: Wrapped) -> TractResult<i64> {
75        IntFactoid::from_wrapped(wrapped.clone())?
76            .concretize()
77            .with_context(|| format!("Tried to convert {wrapped:?} to a i64."))
78    }
79}
80
81// Converts back and forth between Wrapped and Tensor.
82impl Output for Arc<Tensor> {
83    fn into_wrapped(source: Arc<Tensor>) -> Wrapped {
84        ValueFact::into_wrapped(source.into())
85    }
86
87    fn from_wrapped(wrapped: Wrapped) -> TractResult<Arc<Tensor>> {
88        ValueFact::from_wrapped(wrapped.clone())?
89            .concretize()
90            .with_context(|| format_err!("Tried to convert {:?} to a tensor.", wrapped))
91    }
92}
93
94// Converts back and forth between Wrapped and usize.
95impl Output for TDim {
96    fn into_wrapped(source: TDim) -> Wrapped {
97        DimFact::into_wrapped(source.into())
98    }
99
100    fn from_wrapped(wrapped: Wrapped) -> TractResult<TDim> {
101        DimFact::from_wrapped(wrapped.clone())?
102            .concretize()
103            .with_context(|| format_err!("Tried to convert {:?} to a usize.", wrapped))
104    }
105}
106
107/// A wrapper for all the types of values that expressions can produce.
108#[derive(Debug, Clone)]
109pub enum Wrapped {
110    Int(IntFactoid),
111    Type(TypeFactoid),
112    Shape(ShapeFactoid),
113    Tensor(ValueFact),
114    Dim(DimFact),
115}
116
117/// An expression that can be compared by the solver.
118pub trait TExp<T>: fmt::Debug {
119    /// Returns the current value of the expression in the given context.
120    fn get(&self, context: &Context) -> TractResult<T>;
121
122    /// Tries to set the value of the expression in the given context.
123    fn set(&self, context: &mut Context, value: T) -> TractResult<bool>;
124
125    /// Returns the paths that the expression depends on.
126    fn get_paths(&self) -> Vec<&Path>;
127}
128
129pub struct Exp<T>(Box<dyn TExp<T>>);
130impl<T: Factoid + Output + Clone + fmt::Debug> TExp<T> for Exp<T> {
131    /// Returns the current value of the expression in the given context.
132    fn get(&self, context: &Context) -> TractResult<T> {
133        self.0.get(context)
134    }
135
136    /// Tries to set the value of the expression in the given context.
137    fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
138        self.0.set(context, value)
139    }
140
141    /// Returns the paths that the expression depends on.
142    fn get_paths(&self) -> Vec<&Path> {
143        self.0.get_paths()
144    }
145}
146
147impl<T> fmt::Debug for Exp<T>
148where
149    T: Factoid + Output + Clone + ::std::fmt::Debug,
150{
151    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
152        write!(formatter, "{:?}", self.0)
153    }
154}
155
156pub trait IntoExp<T> {
157    /// Converts the value to an Expression.
158    fn bex(self) -> Exp<T>;
159}
160
161#[derive(new)]
162pub struct SumExp<T>(Vec<Exp<T>>)
163where
164    T: Factoid + Output + Clone + ::std::fmt::Debug + 'static;
165
166impl<T> TExp<T> for SumExp<T>
167where
168    T: Factoid + Output + Zero + Add<T> + Neg<Output = T> + Clone + ::std::fmt::Debug + 'static,
169{
170    /// Returns the current value of the expression in the given context.
171    fn get(&self, context: &Context) -> TractResult<T> {
172        self.0.iter().try_fold(T::zero(), |acc, it| Ok(acc + it.0.get(context)?))
173    }
174
175    /// Tries to set the value of the expression in the given context.
176    fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
177        let mut sum = T::zero();
178        let mut misses = vec![];
179
180        for item in &self.0 {
181            let fact = item.get(context)?;
182            if fact.is_concrete() {
183                sum = sum + fact;
184            } else {
185                misses.push(item);
186            }
187        }
188
189        if misses.len() > 1 {
190            Ok(false)
191        } else if misses.len() == 1 {
192            misses[0].set(context, value + -sum)?;
193            Ok(true)
194        } else if sum == value {
195            Ok(false)
196        } else {
197            bail!("{:?} set to {:?}, already is {:?}", self, value, sum)
198        }
199    }
200
201    /// Returns the paths that the rule depends on.
202    fn get_paths(&self) -> Vec<&Path> {
203        self.0.iter().flat_map(|e| e.get_paths()).collect()
204    }
205}
206
207impl<T> fmt::Debug for SumExp<T>
208where
209    T: Factoid + Output + Clone + ::std::fmt::Debug,
210{
211    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
212        for (ix, t) in self.0.iter().enumerate() {
213            if ix > 0 {
214                write!(formatter, " + ")?;
215            }
216            t.fmt(formatter)?;
217        }
218        Ok(())
219    }
220}
221
222/// A constant expression (e.g. `2` or `DatumType::DT_INT32`).
223pub struct ConstantExp<T>(T)
224where
225    T: Factoid + Output + Clone + ::std::fmt::Debug;
226
227impl<T> TExp<T> for ConstantExp<T>
228where
229    T: Factoid + Output + Clone + ::std::fmt::Debug,
230{
231    /// Returns the current value of the expression in the given context.
232    fn get(&self, _: &Context) -> TractResult<T> {
233        Ok(self.0.clone())
234    }
235
236    /// Tries to set the value of the expression in the given context.
237    fn set(&self, _: &mut Context, value: T) -> TractResult<bool> {
238        self.0.unify(&value)?;
239        Ok(false)
240    }
241
242    /// Returns the paths that the expression depends on.
243    fn get_paths(&self) -> Vec<&Path> {
244        vec![]
245    }
246}
247
248impl<T> fmt::Debug for ConstantExp<T>
249where
250    T: Factoid + Output + Clone + ::std::fmt::Debug,
251{
252    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
253        write!(formatter, "{:?}", self.0)
254    }
255}
256
257/// A reference to a variable.
258///
259/// For instance, `inputs[0].rank` is a reference to the rank of the first
260/// input. Internally, a reference holds a Vec<usize> called a path (see
261/// the documentation for `Proxy::get_path`).
262pub struct VariableExp<T>(Path, PhantomData<T>)
263where
264    T: Factoid + Output + Clone + ::std::fmt::Debug;
265
266impl<T> TExp<T> for VariableExp<T>
267where
268    T: Factoid + Output + Clone + ::std::fmt::Debug,
269{
270    /// Returns the current value of the expression in the given context.
271    fn get(&self, context: &Context) -> TractResult<T> {
272        context.get(&self.0).with_context(|| format!("while getting {:?}", self.0))
273    }
274
275    /// Tries to set the value of the expression in the given context.
276    fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
277        let old = self.get(context)?;
278        let new = old.unify(&value)?;
279        let diff = old != new;
280        context.set(&self.0, new).with_context(|| format!("while setting {:?}", self.0))?;
281        Ok(diff)
282    }
283
284    /// Returns the paths that the expression depends on.
285    fn get_paths(&self) -> Vec<&Path> {
286        vec![&self.0]
287    }
288}
289
290impl<T> fmt::Debug for VariableExp<T>
291where
292    T: Factoid + Output + Clone + ::std::fmt::Debug,
293{
294    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
295        write!(formatter, "{:?}", self.0)
296    }
297}
298
299/// A scalar product between a constant and another expression.
300pub struct ScaledExp<T>(i64, Exp<T>)
301where
302    T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone;
303
304impl<T> TExp<T> for ScaledExp<T>
305where
306    T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
307{
308    /// Returns the current value of the expression in the given context.
309    fn get(&self, context: &Context) -> TractResult<T> {
310        let v: T = self.1.get(context)?;
311        Ok(v * self.0)
312    }
313
314    /// Tries to set the value of the expression in the given context.
315    fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
316        let k = &self.0;
317        let m = value;
318
319        if m.is_zero() && k.is_zero() {
320            // We want to set 0 * x <- 0, so we don't have to do anything.
321            Ok(false)
322        } else if m.is_zero() {
323            // We want to set k * x <- 0, where k != 0, so we have to set x <- 0.
324            self.1.set(context, T::zero())
325        } else {
326            /*
327            // We want to set k * x <- m, where k and m != 0, so we will try
328            // to set x <- m / k using a checked division. This way, if m is
329            // not divisible by k, we will return Err instead of panicking.
330            let div = m.div(&V::from(*k)).ok_or(format!(
331            "Cannot set the value of ({:?}, _) to {:?} because \
332            {:?} is not divisible by {:?}.",
333            k, m, m, k
334            ))?;
335            */
336
337            let div = m.div(*k);
338            self.1.set(context, div)
339        }
340    }
341
342    /// Returns the paths that the expression depends on.
343    fn get_paths(&self) -> Vec<&Path> {
344        self.1.get_paths()
345    }
346}
347
348impl<T> fmt::Debug for ScaledExp<T>
349where
350    T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
351{
352    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
353        write!(formatter, "{}*{{{:?}}}", self.0, self.1)
354    }
355}
356
357/// Cast an IntFactoid into a DimFact
358pub struct IntoDimExp(Exp<IntFactoid>);
359
360impl TExp<DimFact> for IntoDimExp {
361    /// Returns the current value of the expression in the given context.
362    fn get(&self, context: &Context) -> TractResult<DimFact> {
363        let v: IntFactoid = self.0.get(context)?;
364        match v {
365            GenericFactoid::Only(i) => Ok(GenericFactoid::Only(i.to_dim())),
366            GenericFactoid::Any => Ok(GenericFactoid::Any),
367        }
368    }
369
370    /// Tries to set the value of the expression in the given context.
371    fn set(&self, context: &mut Context, value: DimFact) -> TractResult<bool> {
372        if let Some(concrete) = value.concretize() {
373            if let Ok(int) = concrete.to_i64() {
374                return self.0.set(context, GenericFactoid::Only(int));
375            }
376        }
377        Ok(false)
378    }
379
380    /// Returns the paths that the expression depends on.
381    fn get_paths(&self) -> Vec<&Path> {
382        self.0.get_paths()
383    }
384}
385
386impl fmt::Debug for IntoDimExp {
387    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
388        write!(formatter, "{{({:?}) as dim}}", self.0)
389    }
390}
391
392// ops and cast on Exp
393
394impl<T, E: TExp<T> + 'static> IntoExp<T> for E {
395    fn bex(self) -> Exp<T> {
396        Exp(Box::new(self))
397    }
398}
399
400// Type
401
402impl IntoExp<TypeFactoid> for TypeProxy {
403    fn bex(self) -> Exp<TypeFactoid> {
404        VariableExp(self.get_path().clone(), PhantomData).bex()
405    }
406}
407
408impl IntoExp<TypeFactoid> for &TypeProxy {
409    fn bex(self) -> Exp<TypeFactoid> {
410        VariableExp(self.get_path().clone(), PhantomData).bex()
411    }
412}
413
414impl IntoExp<TypeFactoid> for DatumType {
415    fn bex(self) -> Exp<TypeFactoid> {
416        ConstantExp(self.into()).bex()
417    }
418}
419
420impl IntoExp<TypeFactoid> for &DatumType {
421    fn bex(self) -> Exp<TypeFactoid> {
422        ConstantExp((*self).into()).bex()
423    }
424}
425
426// Int
427
428impl IntoExp<IntFactoid> for &IntProxy {
429    fn bex(self) -> Exp<IntFactoid> {
430        VariableExp(self.get_path().clone(), PhantomData).bex()
431    }
432}
433
434impl IntoExp<IntFactoid> for &ElementProxy {
435    fn bex(self) -> Exp<IntFactoid> {
436        VariableExp(self.get_path().clone(), PhantomData).bex()
437    }
438}
439
440impl IntoExp<IntFactoid> for i64 {
441    fn bex(self) -> Exp<IntFactoid> {
442        ConstantExp(self.into()).bex()
443    }
444}
445
446impl IntoExp<IntFactoid> for IntFactoid {
447    fn bex(self) -> Exp<IntFactoid> {
448        ConstantExp(self).bex()
449    }
450}
451
452impl<IE: IntoExp<IntFactoid>> Add<IE> for Exp<IntFactoid> {
453    type Output = Exp<IntFactoid>;
454    fn add(self, other: IE) -> Exp<IntFactoid> {
455        SumExp(vec![self.bex(), other.bex()]).bex()
456    }
457}
458
459impl<IE: IntoExp<IntFactoid>> Sub<IE> for Exp<IntFactoid> {
460    type Output = Exp<IntFactoid>;
461    fn sub(self, other: IE) -> Exp<IntFactoid> {
462        SumExp(vec![self.bex(), -1 * other.bex()]).bex()
463    }
464}
465
466impl Mul<Exp<IntFactoid>> for i64 {
467    type Output = Exp<IntFactoid>;
468    fn mul(self, other: Exp<IntFactoid>) -> Exp<IntFactoid> {
469        ScaledExp(self, other).bex()
470    }
471}
472
473// Dim
474
475impl IntoExp<DimFact> for &DimProxy {
476    fn bex(self) -> Exp<DimFact> {
477        VariableExp(self.get_path().clone(), PhantomData).bex()
478    }
479}
480
481impl IntoExp<DimFact> for TDim {
482    fn bex(self) -> Exp<DimFact> {
483        ConstantExp(self.into()).bex()
484    }
485}
486
487impl IntoExp<DimFact> for &TDim {
488    fn bex(self) -> Exp<DimFact> {
489        ConstantExp(self.clone().into()).bex()
490    }
491}
492
493impl<IE: IntoExp<DimFact>> Add<IE> for Exp<DimFact> {
494    type Output = Exp<DimFact>;
495    fn add(self, other: IE) -> Exp<DimFact> {
496        SumExp(vec![self.bex(), other.bex()]).bex()
497    }
498}
499
500impl<IE: IntoExp<DimFact>> Sub<IE> for Exp<DimFact> {
501    type Output = Exp<DimFact>;
502    fn sub(self, other: IE) -> Exp<DimFact> {
503        SumExp(vec![self.bex(), -1 * other.bex()]).bex()
504    }
505}
506
507impl Mul<Exp<DimFact>> for i64 {
508    type Output = Exp<DimFact>;
509    fn mul(self, other: Exp<DimFact>) -> Exp<DimFact> {
510        ScaledExp(self, other).bex()
511    }
512}
513
514impl IntoExp<DimFact> for GenericFactoid<TDim> {
515    fn bex(self) -> Exp<GenericFactoid<TDim>> {
516        ConstantExp(self).bex()
517    }
518}
519
520// Cast to dim
521
522pub trait ToDimExp {
523    fn to_dim(self) -> Exp<DimFact>;
524}
525
526impl ToDimExp for Exp<IntFactoid> {
527    fn to_dim(self) -> Exp<DimFact> {
528        IntoDimExp(self).bex()
529    }
530}
531
532// Shape
533
534impl IntoExp<ShapeFactoid> for ShapeFactoid {
535    fn bex(self) -> Exp<ShapeFactoid> {
536        ConstantExp(self).bex()
537    }
538}
539
540impl IntoExp<ShapeFactoid> for ShapeProxy {
541    fn bex(self) -> Exp<ShapeFactoid> {
542        VariableExp(self.get_path().clone(), PhantomData).bex()
543    }
544}
545
546impl IntoExp<ShapeFactoid> for &ShapeProxy {
547    fn bex(self) -> Exp<ShapeFactoid> {
548        VariableExp(self.get_path().clone(), PhantomData).bex()
549    }
550}
551
552impl IntoExp<ShapeFactoid> for TVec<TDim> {
553    fn bex(self) -> Exp<ShapeFactoid> {
554        ConstantExp(self.into_iter().collect()).bex()
555    }
556}
557
558// Arc<Tensor>
559
560impl IntoExp<ValueFact> for ValueProxy {
561    fn bex(self) -> Exp<ValueFact> {
562        VariableExp(self.get_path().clone(), PhantomData).bex()
563    }
564}
565
566impl IntoExp<ValueFact> for &ValueProxy {
567    fn bex(self) -> Exp<ValueFact> {
568        VariableExp(self.get_path().clone(), PhantomData).bex()
569    }
570}
571
572impl IntoExp<ValueFact> for Arc<Tensor> {
573    fn bex(self) -> Exp<ValueFact> {
574        ConstantExp(self.into()).bex()
575    }
576}