// tract_hir/infer/rules/mod.rs
/// Builds a `Vec` of rule expressions, converting each argument with
/// `$crate::infer::rules::expr::IntoExp::bex`.
///
/// Accepts zero or more comma-separated expressions with an optional trailing
/// comma: `wrap![a, b]` and `wrap![a, b,]` expand identically.
///
/// A single matcher with `$(,)?` handles the trailing comma. Recursing through
/// a bare `wrap!` would only resolve when the macro is in scope at the call
/// site, which is not the case for callers writing `tract_hir::wrap![...]`.
#[macro_export]
macro_rules! wrap {
    ($($x:expr),* $(,)?) => ({
        vec![$( $crate::infer::rules::expr::IntoExp::bex($x) ),*]
    });
}

25use crate::infer::*;
26
27mod cache;
28pub mod expr;
29mod path;
30mod proxies;
31mod solver;
32
33pub use self::proxies::*;
34pub use self::solver::Solver;
35
36pub type InferenceResult = TractResult<()>;
37
38pub trait InferenceRulesOp {
39 fn rules<'r, 'p: 'r, 's: 'r>(
41 &'s self,
42 solver: &mut Solver<'r>,
43 inputs: &'p [TensorProxy],
44 outputs: &'p [TensorProxy],
45 ) -> InferenceResult;
46
47 fn as_op(&self) -> &dyn Op;
48 fn as_op_mut(&mut self) -> &mut dyn Op;
49
50 #[allow(unused_variables)]
51 fn to_typed(
52 &self,
53 source: &InferenceModel,
54 node: &InferenceNode,
55 target: &mut TypedModel,
56 mapping: &HashMap<OutletId, OutletId>,
57 ) -> TractResult<TVec<OutletId>> {
58 bail!("Node {} can not be typed", node)
59 }
60
61 fn nboutputs(&self) -> TractResult<usize> {
62 Ok(1)
63 }
64
65 #[allow(unused_variables)]
66 fn incorporate(
67 &self,
68 model: &InferenceModel,
69 node: &InferenceNode,
70 ) -> TractResult<Option<InferenceModelPatch>> {
71 Ok(None)
72 }
73}
74
75impl<O: InferenceRulesOp + Op> InferenceOp for O {
76 fn infer_facts(
77 &mut self,
78 inputs: TVec<&InferenceFact>,
79 outputs: TVec<&InferenceFact>,
80 observed: TVec<&InferenceFact>,
81 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
82 let inputs_proxy: TVec<TensorProxy> =
83 (0..inputs.len()).map(|ix| TensorProxy::new(tvec!(0, ix as isize).into())).collect();
84 let outputs_proxy: TVec<TensorProxy> =
85 (0..outputs.len()).map(|ix| TensorProxy::new(tvec!(1, ix as isize).into())).collect();
86
87 trace!("Building rules for {self:?}");
88 let mut solver = Solver::default();
89 self.rules(&mut solver, &inputs_proxy, &outputs_proxy)?;
90 trace!("Applying rules for {self:?}");
91 let (input, output) = solver.infer_facts((inputs, outputs))?;
92 trace!("Solver done");
93 Ok((input, output, observed.into_iter().cloned().collect()))
94 }
95
96 fn nboutputs(&self) -> TractResult<usize> {
97 self.nboutputs()
98 }
99
100 fn observe_outlets(
101 &self,
102 _model: &InferenceModel,
103 _node: &InferenceNode,
104 ) -> TractResult<Vec<OutletId>> {
105 Ok(vec![])
106 }
107
108 fn as_op(&self) -> &dyn Op {
109 self.as_op()
110 }
111
112 fn as_op_mut(&mut self) -> &mut dyn Op {
113 self.as_op_mut()
114 }
115
116 fn to_typed(
117 &self,
118 source: &InferenceModel,
119 node: &InferenceNode,
120 target: &mut TypedModel,
121 mapping: &HashMap<OutletId, OutletId>,
122 ) -> TractResult<TVec<OutletId>> {
123 self.to_typed(source, node, target, mapping)
124 }
125
126 fn incorporate(
127 &self,
128 model: &InferenceModel,
129 node: &InferenceNode,
130 ) -> TractResult<Option<InferenceModelPatch>> {
131 self.incorporate(model, node)
132 }
133}