1use super::Factoid;
2use crate::infer::*;
3use std::fmt;
4use tract_data::TooEarly;
5
// Uses the `dyn_clone` crate (re-exported by `tract_core`) to implement `Clone` for
// boxed `dyn InferenceOp` trait objects, so `Box<dyn InferenceOp>` values can be cloned.
// NOTE(review): this relies on implementors being `DynClone`, presumably through the
// `Op` supertrait — confirm.
tract_core::dyn_clone::clone_trait_object!(InferenceOp);
7
8pub trait InferenceOp: Op {
10 fn infer(
23 &mut self,
24 inputs: TVec<&InferenceFact>,
25 outputs: TVec<&InferenceFact>,
26 observed: TVec<&InferenceFact>,
27 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
28 let (infered_inputs, infered_outputs, observed) =
29 self.infer_facts(inputs, outputs, observed).context("Infering facts")?;
30
31 if self.is_stateless() && infered_inputs.iter().all(|i| i.value.is_concrete()) {
32 let input_values = infered_inputs
33 .iter()
34 .map(|i| i.value.concretize().unwrap().into_tvalue())
35 .collect(); match self.eval(input_values) {
37 Ok(values) => {
38 let output_values =
39 values.into_iter().map(|t| t.into_arc_tensor().into()).collect::<TVec<_>>();
40 return Ok((infered_inputs, output_values, observed));
41 }
42 Err(e) if e.root_cause().downcast_ref::<TooEarly>().is_some() => (),
43 Err(e) => return Err(e).context("Eager eval during inference"),
44 }
45 }
46
47 Ok((infered_inputs, infered_outputs, observed))
48 }
49
50 fn observe_outlets(
53 &self,
54 _model: &InferenceModel,
55 _node: &InferenceNode,
56 ) -> TractResult<Vec<OutletId>> {
57 Ok(vec![])
58 }
59
60 fn infer_facts(
66 &mut self,
67 inputs: TVec<&InferenceFact>,
68 outputs: TVec<&InferenceFact>,
69 observed: TVec<&InferenceFact>,
70 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)>;
71
72 #[allow(unused_variables)]
79 fn incorporate(
80 &self,
81 model: &InferenceModel,
82 node: &InferenceNode,
83 ) -> TractResult<Option<InferenceModelPatch>> {
84 Ok(None)
85 }
86
87 fn nboutputs(&self) -> TractResult<usize> {
88 Ok(1)
89 }
90
91 fn as_op(&self) -> &dyn Op;
93
94 fn as_op_mut(&mut self) -> &mut dyn Op;
96
97 #[allow(unused_variables)]
99 fn to_typed(
100 &self,
101 source: &InferenceModel,
102 node: &InferenceNode,
103 target: &mut TypedModel,
104 mapping: &HashMap<OutletId, OutletId>,
105 ) -> TractResult<TVec<OutletId>> {
106 bail!("Operator can not be made a TypedOp.")
107 }
108}
109
110impl std::fmt::Display for Box<dyn InferenceOp> {
111 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
112 write!(fmt, "{}", self.name())
113 }
114}
115
116impl<O: InferenceOp> From<O> for Box<dyn InferenceOp> {
117 fn from(it: O) -> Box<dyn InferenceOp> {
118 Box::new(it)
119 }
120}
121
122impl AsRef<dyn Op> for dyn InferenceOp {
123 fn as_ref(&self) -> &dyn Op {
124 self.as_op()
125 }
126}
127
128impl AsRef<dyn Op> for Box<dyn InferenceOp> {
129 fn as_ref(&self) -> &dyn Op {
130 self.as_op()
131 }
132}
133
134impl AsMut<dyn Op> for dyn InferenceOp {
135 fn as_mut(&mut self) -> &mut dyn Op {
136 self.as_op_mut()
137 }
138}
139
140impl AsMut<dyn Op> for Box<dyn InferenceOp> {
141 fn as_mut(&mut self) -> &mut dyn Op {
142 self.as_op_mut()
143 }
144}