tract_core/ops/
mod.rs

//! Ops: the core operator traits (`Op`, `EvalOp`, `TypedOp`, `OpState`) and the operator submodules.
2use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7
8#[macro_use]
9pub mod macros;
10#[macro_use]
11pub mod element_wise;
12#[macro_use]
13pub mod binary;
14
15pub mod array;
16pub mod cast;
17pub mod change_axes;
18pub mod cnn;
19pub mod downsample;
20pub mod dummy;
21pub mod einsum;
22pub mod fft;
23pub mod identity;
24pub mod konst;
25pub mod logic;
26pub mod math;
27pub mod matmul;
28pub mod memory;
29pub mod nn;
30pub mod quant;
31pub mod scan;
32pub mod source;
33pub mod submodel;
34pub mod unimpl;
35
36pub use downsample::Downsample;
37pub use memory::*;
38
39use crate::internal::*;
40use crate::optim::OptimizerSession;
41
/// How precisely two implementations of an operation are expected to agree
/// when their outputs are compared.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Validation {
    /// The output is random.
    Random,
    /// The implementation may induce rounding errors.
    Rounding,
    /// The implementation must be accurate.
    Accurate,
}
52
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
54pub enum Cost {
55    Div(DatumType),
56    FMA(DatumType),
57    Buffer(DatumType),
58    Params(DatumType),
59}
60
61impl Cost {
62    pub fn is_compute(&self) -> bool {
63        use Cost::*;
64        match self {
65            FMA(_) | Div(_) => true,
66            Buffer(_) | Params(_) => false,
67        }
68    }
69}
70
71pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
72    fn unfreeze(&self) -> Box<dyn OpState>;
73}
74
75pub trait OpStateFreeze {
76    fn freeze(&self) -> Box<dyn FrozenOpState>;
77}
78
79dyn_clone::clone_trait_object!(FrozenOpState);
80
81pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
82    fn eval(
83        &mut self,
84        session: &mut SessionState,
85        op: &dyn Op,
86        inputs: TVec<TValue>,
87    ) -> TractResult<TVec<TValue>>;
88}
89dyn_clone::clone_trait_object!(OpState);
90impl_downcast!(OpState);
91
92pub trait EvalOp {
93    #[allow(unused_variables)]
94    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
95        bail!("stateless evaluation not implemented")
96    }
97
98    #[allow(unused_variables)]
99    fn eval_with_session(
100        &self,
101        session: &SessionState,
102        inputs: TVec<TValue>,
103    ) -> TractResult<TVec<TValue>> {
104        self.eval(inputs).context("Running legacy eval")
105    }
106
107    #[allow(unused_variables)]
108    fn state(
109        &self,
110        session: &mut SessionState,
111        node_id: usize,
112    ) -> TractResult<Option<Box<dyn OpState>>> {
113        Ok(None)
114    }
115
116    fn is_stateless(&self) -> bool;
117}
118
119/// A base operation
120pub trait Op: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp {
121    fn name(&self) -> Cow<str>;
122
123    /// The kind of accuracy check that should be performed on operation when
124    /// testing them.
125    fn validation(&self) -> Validation {
126        Validation::Accurate
127    }
128
129    /// Compare two ops.
130    // Should this one be and Eq or PartialEq impl instead ?
131    fn same_as(&self, _other: &dyn Op) -> bool {
132        false
133    }
134
135    /// Short (one-line) strings giving hints on internal implementation or
136    /// important configuration details to be displayed in dumps.
137    fn info(&self) -> TractResult<Vec<String>> {
138        Ok(vec![])
139    }
140
141    fn as_typed(&self) -> Option<&dyn TypedOp>;
142}
143
144pub trait TypedOp:
145    Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
146{
147    /// Reinterpret the TypedOp as an Op.
148    fn as_op(&self) -> &dyn Op;
149
150    /// Reinterpret the TypedOp as an Op, mutably.
151    fn as_op_mut(&mut self) -> &mut dyn Op;
152
153    /// Deduce output facts from input facts.
154    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
155
156    #[allow(unused_variables)]
157    fn axes_mapping(
158        &self,
159        inputs: &[&TypedFact],
160        outputs: &[&TypedFact],
161    ) -> TractResult<AxesMapping> {
162        AxesMapping::disconnected(inputs, outputs)
163    }
164
165    /// Fuse op after codegen to deal with local optimisations.
166    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
167        Ok(None)
168    }
169
170    /// Declutter the op to the tract_core operator set as much as possible.
171    #[allow(unused_variables)]
172    fn declutter_with_session(
173        &self,
174        session: &mut OptimizerSession,
175        model: &TypedModel,
176        node: &TypedNode,
177    ) -> TractResult<Option<TypedModelPatch>> {
178        self.declutter(model, node)
179    }
180
181    /// Declutter the op to the tract_core operator set as much as possible.
182    #[allow(unused_variables)]
183    fn declutter(
184        &self,
185        model: &TypedModel,
186        node: &TypedNode,
187    ) -> TractResult<Option<TypedModelPatch>> {
188        Ok(None)
189    }
190
191    /// Computes a cost hint of the operation.
192    ///
193    /// Each pair is a type of operation and a number per call on eval.
194    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
195        Ok(tvec!())
196    }
197
198    #[allow(unused_variables)]
199    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
200        Ok(tvec!())
201    }
202
203    #[allow(unused_variables)]
204    fn change_axes(
205        &self,
206        model: &TypedModel,
207        node: &TypedNode,
208        io: InOut,
209        change: &AxisOp,
210    ) -> TractResult<Option<AxisChangeConsequence>> {
211        Ok(None)
212    }
213
214    #[allow(unused_variables)]
215    #[allow(clippy::too_many_arguments)]
216    fn slice(
217        &self,
218        patch: &mut TypedModelPatch,
219        model: &TypedModel,
220        node: &TypedNode,
221        prefix: &str,
222        inputs: &[OutletId],
223        output_axis: usize,
224        start: usize,
225        end: usize,
226    ) -> TractResult<Option<TVec<OutletId>>> {
227        Ok(None)
228    }
229
230    /// Transforms the op in an equivalent one, operating on dt (i8 or u8).
231    ///
232    /// Returns None if the op can not be translated.
233    #[allow(unused_variables)]
234    fn quantize(
235        &self,
236        model: &TypedModel,
237        node: &TypedNode,
238        dt: DatumType,
239        scale: f32,
240        zero_point: i32,
241    ) -> TractResult<Option<Box<dyn TypedOp>>> {
242        Ok(None)
243    }
244
245    /// Transform the op into by providing a value to one or more symbols.
246    #[allow(unused_variables)]
247    fn concretize_dims(
248        &self,
249        source: &TypedModel,
250        node: &TypedNode,
251        target: &mut TypedModel,
252        mapping: &HashMap<OutletId, OutletId>,
253        values: &SymbolValues,
254    ) -> TractResult<TVec<OutletId>> {
255        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
256        target.wire_node(&node.name, node.op.clone(), &inputs)
257    }
258
259    /// Translate the op into the most efficient form possible for execution.
260    ///
261    /// This transformation is supposed to be final, no more pass are expected
262    /// to be run on the codegen networks.
263    #[allow(unused_variables)]
264    fn codegen(
265        &self,
266        model: &TypedModel,
267        node: &TypedNode,
268    ) -> TractResult<Option<TypedModelPatch>> {
269        Ok(None)
270    }
271
272    /// Nested model multipliers, with label (for profiling).
273    #[allow(unused_variables)]
274    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(Cow<str>, f64)> {
275        vec![]
276    }
277}
278
279impl_downcast!(Op);
280impl_downcast!(TypedOp);
281
282dyn_clone::clone_trait_object!(Op);
283dyn_clone::clone_trait_object!(TypedOp);
284
285impl<O: Op> From<O> for Box<dyn Op> {
286    fn from(it: O) -> Box<dyn Op> {
287        Box::new(it)
288    }
289}
290
291impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
292    fn from(it: O) -> Box<dyn TypedOp> {
293        Box::new(it)
294    }
295}
296
297impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
298    fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
299        it.clone()
300    }
301}
302
303impl AsRef<dyn Op> for dyn TypedOp {
304    fn as_ref(&self) -> &dyn Op {
305        self.as_op()
306    }
307}
308
309impl AsRef<dyn Op> for Box<dyn TypedOp> {
310    fn as_ref(&self) -> &dyn Op {
311        self.as_op()
312    }
313}
314
315impl AsMut<dyn Op> for dyn TypedOp {
316    fn as_mut(&mut self) -> &mut dyn Op {
317        self.as_op_mut()
318    }
319}
320
321impl AsMut<dyn Op> for Box<dyn TypedOp> {
322    fn as_mut(&mut self) -> &mut dyn Op {
323        self.as_op_mut()
324    }
325}
326
327impl std::fmt::Display for Box<dyn Op> {
328    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
329        write!(fmt, "{}", self.name())
330    }
331}
332
333impl std::fmt::Display for Box<dyn TypedOp> {
334    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
335        write!(fmt, "{}", self.name())
336    }
337}