tract_hir/infer/rules/
mod.rs

//! A fluent interface for the analyser.
//!
//! This interface provides proxies for the different properties of tensors,
//! which lets each operator state its inference rules in a clear, declarative
//! fashion inside its `rules` method.
//!
//! Take these rules, for instance:
//! ```text
//! solver.equals(inputs.len(), 2);
//! solver.equals(inputs[0].datum_type, outputs[0].datum_type);
//! ```
//! Here, `inputs.len()`, `inputs[0].datum_type` and `outputs[0].datum_type` do
//! not hold the actual length and datum types; instead they act as declarative
//! placeholders for those values.
15
/// Builds a `Vec` from a comma-separated list of values, converting each one with
/// `$crate::infer::rules::expr::IntoExp::bex`.
///
/// `wrap![a, b]` expands to `vec![IntoExp::bex(a), IntoExp::bex(b)]`. A trailing comma
/// is accepted, so `wrap![a, b,]` is equivalent to `wrap![a, b]`.
#[macro_export]
macro_rules! wrap {
    // `$(,)?` accepts an optional trailing comma in this single arm. The macro therefore
    // never re-invokes itself by bare name. A bare-name `wrap![..]` in the expansion would
    // resolve in the *caller's* scope and fail for downstream crates that invoke this
    // macro by path (e.g. `tract_hir::wrap!(a,)`) without importing `wrap`.
    ($($x:expr),* $(,)?) => ({
        vec![$( $crate::infer::rules::expr::IntoExp::bex($x) ),*]
    });
}
24
25use crate::infer::*;
26
27mod cache;
28pub mod expr;
29mod path;
30mod proxies;
31mod solver;
32
33pub use self::proxies::*;
34pub use self::solver::Solver;
35
/// Value-less `TractResult` returned by [`InferenceRulesOp::rules`]: `Ok(())` once the
/// rules are registered, or the error that prevented it.
pub type InferenceResult = TractResult<()>;
37
38pub trait InferenceRulesOp {
39    /// Registers the inference rules of the operator.
40    fn rules<'r, 'p: 'r, 's: 'r>(
41        &'s self,
42        solver: &mut Solver<'r>,
43        inputs: &'p [TensorProxy],
44        outputs: &'p [TensorProxy],
45    ) -> InferenceResult;
46
47    fn as_op(&self) -> &dyn Op;
48    fn as_op_mut(&mut self) -> &mut dyn Op;
49
50    #[allow(unused_variables)]
51    fn to_typed(
52        &self,
53        source: &InferenceModel,
54        node: &InferenceNode,
55        target: &mut TypedModel,
56        mapping: &HashMap<OutletId, OutletId>,
57    ) -> TractResult<TVec<OutletId>> {
58        bail!("Node {} can not be typed", node)
59    }
60
61    fn nboutputs(&self) -> TractResult<usize> {
62        Ok(1)
63    }
64
65    #[allow(unused_variables)]
66    fn incorporate(
67        &self,
68        model: &InferenceModel,
69        node: &InferenceNode,
70    ) -> TractResult<Option<InferenceModelPatch>> {
71        Ok(None)
72    }
73}
74
75impl<O: InferenceRulesOp + Op> InferenceOp for O {
76    fn infer_facts(
77        &mut self,
78        inputs: TVec<&InferenceFact>,
79        outputs: TVec<&InferenceFact>,
80        observed: TVec<&InferenceFact>,
81    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
82        let inputs_proxy: TVec<TensorProxy> =
83            (0..inputs.len()).map(|ix| TensorProxy::new(tvec!(0, ix as isize).into())).collect();
84        let outputs_proxy: TVec<TensorProxy> =
85            (0..outputs.len()).map(|ix| TensorProxy::new(tvec!(1, ix as isize).into())).collect();
86
87        trace!("Building rules for {self:?}");
88        let mut solver = Solver::default();
89        self.rules(&mut solver, &inputs_proxy, &outputs_proxy)?;
90        trace!("Applying rules for {self:?}");
91        let (input, output) = solver.infer_facts((inputs, outputs))?;
92        trace!("Solver done");
93        Ok((input, output, observed.into_iter().cloned().collect()))
94    }
95
96    fn nboutputs(&self) -> TractResult<usize> {
97        self.nboutputs()
98    }
99
100    fn observe_outlets(
101        &self,
102        _model: &InferenceModel,
103        _node: &InferenceNode,
104    ) -> TractResult<Vec<OutletId>> {
105        Ok(vec![])
106    }
107
108    fn as_op(&self) -> &dyn Op {
109        self.as_op()
110    }
111
112    fn as_op_mut(&mut self) -> &mut dyn Op {
113        self.as_op_mut()
114    }
115
116    fn to_typed(
117        &self,
118        source: &InferenceModel,
119        node: &InferenceNode,
120        target: &mut TypedModel,
121        mapping: &HashMap<OutletId, OutletId>,
122    ) -> TractResult<TVec<OutletId>> {
123        self.to_typed(source, node, target, mapping)
124    }
125
126    fn incorporate(
127        &self,
128        model: &InferenceModel,
129        node: &InferenceNode,
130    ) -> TractResult<Option<InferenceModelPatch>> {
131        self.incorporate(model, node)
132    }
133}