tract_hir/infer/ops.rs

1use super::Factoid;
2use crate::infer::*;
3use std::fmt;
4use tract_data::TooEarly;
5
6tract_core::dyn_clone::clone_trait_object!(InferenceOp);
7
8/// An operation with tensor type inference
9pub trait InferenceOp: Op {
10    /// Infers properties about the input and output tensors.
11    ///
12    /// The `inputs` and `outputs` arguments correspond to properties about
13    /// the input and output tensors that are already known.
14    ///
15    /// The default implementation will call the private infer_facts method,
16    /// which is usually implemented using the InferenceRulesOp trait. It will
17    /// also try to eval() the op if its a EvalOp and if the inputs are
18    /// fully determined.
19    ///
20    /// Returns Err in case of an unrecoverable error during the inference,
21    /// and the refined properties about the inputs and outputs otherwise.
22    fn infer(
23        &mut self,
24        inputs: TVec<&InferenceFact>,
25        outputs: TVec<&InferenceFact>,
26        observed: TVec<&InferenceFact>,
27    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
28        let (infered_inputs, infered_outputs, observed) =
29            self.infer_facts(inputs, outputs, observed).context("Infering facts")?;
30
31        if self.is_stateless() && infered_inputs.iter().all(|i| i.value.is_concrete()) {
32            let input_values = infered_inputs
33                .iter()
34                .map(|i| i.value.concretize().unwrap().into_tvalue())
35                .collect(); // checked
36            match self.eval(input_values) {
37                Ok(values) => {
38                    let output_values =
39                        values.into_iter().map(|t| t.into_arc_tensor().into()).collect::<TVec<_>>();
40                    return Ok((infered_inputs, output_values, observed));
41                }
42                Err(e) if e.root_cause().downcast_ref::<TooEarly>().is_some() => (),
43                Err(e) => return Err(e).context("Eager eval during inference"),
44            }
45        }
46
47        Ok((infered_inputs, infered_outputs, observed))
48    }
49
50    /// Allow an op to specify a supplementary list of outlets facts that
51    /// will trigger inference again.
52    fn observe_outlets(
53        &self,
54        _model: &InferenceModel,
55        _node: &InferenceNode,
56    ) -> TractResult<Vec<OutletId>> {
57        Ok(vec![])
58    }
59
60    /// Infer properties about inputs and output tensors. This method does not
61    /// need to deal with the "trivial" stateless op with fully determined
62    /// inputs cases.
63    ///
64    /// Most of the time, it is implemented using InferenceRulesOp.
65    fn infer_facts(
66        &mut self,
67        inputs: TVec<&InferenceFact>,
68        outputs: TVec<&InferenceFact>,
69        observed: TVec<&InferenceFact>,
70    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)>;
71
72    /// Early pass on inference model, after analyse, but before translation to
73    /// typed network. Meant to deal with some framework idiosyncrasies that
74    /// manifest with temporaries nodes that can run some form of inference but
75    /// require refactoring the network before it can be evaluated.
76    ///
77    /// Called after succesful analyse, but before translating to typed model.
78    #[allow(unused_variables)]
79    fn incorporate(
80        &self,
81        model: &InferenceModel,
82        node: &InferenceNode,
83    ) -> TractResult<Option<InferenceModelPatch>> {
84        Ok(None)
85    }
86
87    fn nboutputs(&self) -> TractResult<usize> {
88        Ok(1)
89    }
90
91    /// Reinterpret the InferenceOp as an Op.
92    fn as_op(&self) -> &dyn Op;
93
94    /// Reinterpret the InferenceOp as an Op, mutably.
95    fn as_op_mut(&mut self) -> &mut dyn Op;
96
97    /// Called during translation to TypedModel.
98    #[allow(unused_variables)]
99    fn to_typed(
100        &self,
101        source: &InferenceModel,
102        node: &InferenceNode,
103        target: &mut TypedModel,
104        mapping: &HashMap<OutletId, OutletId>,
105    ) -> TractResult<TVec<OutletId>> {
106        bail!("Operator can not be made a TypedOp.")
107    }
108}
109
110impl std::fmt::Display for Box<dyn InferenceOp> {
111    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
112        write!(fmt, "{}", self.name())
113    }
114}
115
116impl<O: InferenceOp> From<O> for Box<dyn InferenceOp> {
117    fn from(it: O) -> Box<dyn InferenceOp> {
118        Box::new(it)
119    }
120}
121
122impl AsRef<dyn Op> for dyn InferenceOp {
123    fn as_ref(&self) -> &dyn Op {
124        self.as_op()
125    }
126}
127
128impl AsRef<dyn Op> for Box<dyn InferenceOp> {
129    fn as_ref(&self) -> &dyn Op {
130        self.as_op()
131    }
132}
133
134impl AsMut<dyn Op> for dyn InferenceOp {
135    fn as_mut(&mut self) -> &mut dyn Op {
136        self.as_op_mut()
137    }
138}
139
140impl AsMut<dyn Op> for Box<dyn InferenceOp> {
141    fn as_mut(&mut self) -> &mut dyn Op {
142        self.as_op_mut()
143    }
144}