1use super::Factoid;
2use crate::infer::*;
3use std::fmt;
4use tract_data::TooEarly;
5
6tract_core::dyn_clone::clone_trait_object!(InferenceOp);
7
8pub trait InferenceOp: Op {
10 fn infer(
23 &mut self,
24 inputs: TVec<&InferenceFact>,
25 outputs: TVec<&InferenceFact>,
26 observed: TVec<&InferenceFact>,
27 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
28 let (infered_inputs, infered_outputs, observed) =
29 self.infer_facts(inputs, outputs, observed).context("Infering facts")?;
30
31 if self.is_stateless() && infered_inputs.iter().all(|i| i.value.is_concrete()) {
32 let input_values = infered_inputs
33 .iter()
34 .map(|i| i.value.concretize().unwrap().into_tvalue())
35 .collect(); match self.eval(input_values) {
37 Ok(values) => {
38 let output_values = values
39 .into_iter()
40 .map(|t| t.into_arc_tensor().try_into())
41 .collect::<TractResult<TVec<_>>>()?;
42 return Ok((infered_inputs, output_values, observed));
43 }
44 Err(e) if e.root_cause().downcast_ref::<TooEarly>().is_some() => (),
45 Err(e) => return Err(e).context("Eager eval during inference"),
46 }
47 }
48
49 Ok((infered_inputs, infered_outputs, observed))
50 }
51
52 fn observe_outlets(
55 &self,
56 _model: &InferenceModel,
57 _node: &InferenceNode,
58 ) -> TractResult<Vec<OutletId>> {
59 Ok(vec![])
60 }
61
62 fn infer_facts(
68 &mut self,
69 inputs: TVec<&InferenceFact>,
70 outputs: TVec<&InferenceFact>,
71 observed: TVec<&InferenceFact>,
72 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)>;
73
74 #[allow(unused_variables)]
81 fn incorporate(
82 &self,
83 model: &InferenceModel,
84 node: &InferenceNode,
85 ) -> TractResult<Option<InferenceModelPatch>> {
86 Ok(None)
87 }
88
89 fn nboutputs(&self) -> TractResult<usize> {
90 Ok(1)
91 }
92
93 fn as_op(&self) -> &dyn Op;
95
96 fn as_op_mut(&mut self) -> &mut dyn Op;
98
99 #[allow(unused_variables)]
101 fn to_typed(
102 &self,
103 source: &InferenceModel,
104 node: &InferenceNode,
105 target: &mut TypedModel,
106 mapping: &HashMap<OutletId, OutletId>,
107 ) -> TractResult<TVec<OutletId>> {
108 bail!("Operator can not be made a TypedOp.")
109 }
110}
111
112impl std::fmt::Display for Box<dyn InferenceOp> {
113 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
114 write!(fmt, "{}", self.name())
115 }
116}
117
118impl<O: InferenceOp> From<O> for Box<dyn InferenceOp> {
119 fn from(it: O) -> Box<dyn InferenceOp> {
120 Box::new(it)
121 }
122}
123
124impl AsRef<dyn Op> for dyn InferenceOp {
125 fn as_ref(&self) -> &dyn Op {
126 self.as_op()
127 }
128}
129
130impl AsRef<dyn Op> for Box<dyn InferenceOp> {
131 fn as_ref(&self) -> &dyn Op {
132 self.as_op()
133 }
134}
135
136impl AsMut<dyn Op> for dyn InferenceOp {
137 fn as_mut(&mut self) -> &mut dyn Op {
138 self.as_op_mut()
139 }
140}
141
142impl AsMut<dyn Op> for Box<dyn InferenceOp> {
143 fn as_mut(&mut self) -> &mut dyn Op {
144 self.as_op_mut()
145 }
146}