tract_core/ops/
mod.rs

1//! Ops
2use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7
8#[macro_use]
9pub mod macros;
10#[macro_use]
11pub mod element_wise;
12#[macro_use]
13pub mod binary;
14
15pub mod array;
16pub mod cast;
17pub mod change_axes;
18pub mod cnn;
19pub mod downsample;
20pub mod dummy;
21pub mod einsum;
22pub mod fft;
23pub mod identity;
24pub mod konst;
25pub mod logic;
26pub mod math;
27pub mod matmul;
28pub mod memory;
29pub mod nn;
30pub mod quant;
31pub mod scan;
32pub mod source;
33pub mod submodel;
34pub mod unimpl;
35
36pub use downsample::Downsample;
37pub use memory::*;
38
39use crate::internal::*;
40use crate::optim::OptimizerSession;
41
/// How strictly the outputs of two implementations of an op are expected to
/// agree when they are compared in tests.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Validation {
    /// The op produces random output.
    Random,
    /// The implementation may introduce rounding errors.
    Rounding,
    /// The implementation must be accurate.
    Accurate,
}
52
53#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
54pub enum Cost {
55    Div(DatumType),
56    FMA(DatumType),
57    Buffer(DatumType),
58    Params(DatumType),
59    Custom(bool, String),
60}
61
62impl Cost {
63    pub fn is_compute(&self) -> bool {
64        use Cost::*;
65        match self {
66            FMA(_) | Div(_) => true,
67            Buffer(_) | Params(_) => false,
68            Custom(compute, _) => *compute,
69        }
70    }
71}
72
73impl std::fmt::Debug for Cost {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        use Cost::*;
76        match self {
77            Div(dt) => write!(f, "Div({dt:?})"),
78            FMA(dt) => write!(f, "FMA({dt:?})"),
79            Buffer(dt) => write!(f, "Buffer({dt:?})"),
80            Params(dt) => write!(f, "Params({dt:?})"),
81            Custom(_, name) => write!(f, "{name}"),
82        }
83    }
84}
85
86pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
87    fn unfreeze(&self) -> Box<dyn OpState>;
88}
89
90pub trait OpStateFreeze {
91    fn freeze(&self) -> Box<dyn FrozenOpState>;
92}
93
94dyn_clone::clone_trait_object!(FrozenOpState);
95
96pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
97    fn load_from(&mut self, _: &mut SessionState, _: &mut Vec<TValue>) -> TractResult<()> {
98        Ok(())
99    }
100
101    fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
102        Ok(())
103    }
104
105    fn init_tensor_fact(&self) -> Option<TypedFact> {
106        None
107    }
108
109    fn resolve_symbols(&mut self, _: &mut SessionState) -> TractResult<()> {
110        Ok(())
111    }
112
113    fn eval(
114        &mut self,
115        session: &mut SessionState,
116        op: &dyn Op,
117        inputs: TVec<TValue>,
118    ) -> TractResult<TVec<TValue>>;
119}
120dyn_clone::clone_trait_object!(OpState);
121impl_downcast!(OpState);
122
123pub trait EvalOp {
124    #[allow(unused_variables)]
125    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
126        bail!("stateless evaluation not implemented")
127    }
128
129    #[allow(unused_variables)]
130    fn eval_with_session(
131        &self,
132        node_id: usize,
133        session: &SessionState,
134        inputs: TVec<TValue>,
135    ) -> TractResult<TVec<TValue>> {
136        self.eval(inputs).context("Running legacy eval")
137    }
138
139    #[allow(unused_variables)]
140    fn state(
141        &self,
142        session: &mut SessionState,
143        node_id: usize,
144    ) -> TractResult<Option<Box<dyn OpState>>> {
145        Ok(None)
146    }
147
148    fn is_stateless(&self) -> bool;
149}
150
151/// A base operation
152pub trait Op: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp {
153    fn name(&self) -> StaticName;
154
155    /// The kind of accuracy check that should be performed on operation when
156    /// testing them.
157    fn validation(&self) -> Validation {
158        Validation::Accurate
159    }
160
161    /// Compare two ops.
162    // Should this one be and Eq or PartialEq impl instead ?
163    fn same_as(&self, _other: &dyn Op) -> bool {
164        false
165    }
166
167    /// Short (one-line) strings giving hints on internal implementation or
168    /// important configuration details to be displayed in dumps.
169    fn info(&self) -> TractResult<Vec<String>> {
170        Ok(vec![])
171    }
172
173    fn as_typed(&self) -> Option<&dyn TypedOp>;
174}
175
176pub trait TypedOp:
177    Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
178{
179    /// Reinterpret the TypedOp as an Op.
180    fn as_op(&self) -> &dyn Op;
181
182    /// Reinterpret the TypedOp as an Op, mutably.
183    fn as_op_mut(&mut self) -> &mut dyn Op;
184
185    /// Deduce output facts from input facts.
186    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
187
188    #[allow(unused_variables)]
189    fn axes_mapping(
190        &self,
191        inputs: &[&TypedFact],
192        outputs: &[&TypedFact],
193    ) -> TractResult<AxesMapping> {
194        AxesMapping::disconnected(inputs, outputs)
195    }
196
197    /// Fuse op after codegen to deal with local optimisations.
198    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
199        Ok(None)
200    }
201
202    /// Declutter the op to the tract_core operator set as much as possible.
203    #[allow(unused_variables)]
204    fn declutter_with_session(
205        &self,
206        session: &mut OptimizerSession,
207        model: &TypedModel,
208        node: &TypedNode,
209    ) -> TractResult<Option<TypedModelPatch>> {
210        self.declutter(model, node)
211    }
212
213    /// Declutter the op to the tract_core operator set as much as possible.
214    #[allow(unused_variables)]
215    fn declutter(
216        &self,
217        model: &TypedModel,
218        node: &TypedNode,
219    ) -> TractResult<Option<TypedModelPatch>> {
220        Ok(None)
221    }
222
223    /// Computes a cost hint of the operation.
224    ///
225    /// Each pair is a type of operation and a number per call on eval.
226    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
227        Ok(tvec!())
228    }
229
230    #[allow(unused_variables)]
231    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
232        Ok(tvec!())
233    }
234
235    #[allow(unused_variables)]
236    fn change_axes(
237        &self,
238        model: &TypedModel,
239        node: &TypedNode,
240        io: InOut,
241        change: &AxisOp,
242    ) -> TractResult<Option<AxisChangeConsequence>> {
243        Ok(None)
244    }
245
246    #[allow(unused_variables)]
247    #[allow(clippy::too_many_arguments)]
248    fn slice(
249        &self,
250        patch: &mut TypedModelPatch,
251        model: &TypedModel,
252        node: &TypedNode,
253        prefix: &str,
254        inputs: &[OutletId],
255        output_axis: usize,
256        start: &TDim,
257        end: &TDim,
258    ) -> TractResult<Option<TVec<OutletId>>> {
259        Ok(None)
260    }
261
262    /// Transforms the op in an equivalent one, operating on dt (i8 or u8).
263    ///
264    /// Returns None if the op can not be translated.
265    #[allow(unused_variables)]
266    fn quantize(
267        &self,
268        model: &TypedModel,
269        node: &TypedNode,
270        dt: DatumType,
271        scale: f32,
272        zero_point: i32,
273    ) -> TractResult<Option<Box<dyn TypedOp>>> {
274        Ok(None)
275    }
276
277    /// Transform the op into by providing a value to one or more symbols.
278    #[allow(unused_variables)]
279    fn concretize_dims(
280        &self,
281        source: &TypedModel,
282        node: &TypedNode,
283        target: &mut TypedModel,
284        mapping: &HashMap<OutletId, OutletId>,
285        values: &SymbolValues,
286    ) -> TractResult<TVec<OutletId>> {
287        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
288        target.wire_node(&node.name, node.op.clone(), &inputs)
289    }
290
291    /// Translate the op into the most efficient form possible for execution.
292    ///
293    /// This transformation is supposed to be final, no more pass are expected
294    /// to be run on the codegen networks.
295    #[allow(unused_variables)]
296    fn codegen(
297        &self,
298        model: &TypedModel,
299        node: &TypedNode,
300    ) -> TractResult<Option<TypedModelPatch>> {
301        Ok(None)
302    }
303
304    /// Nested model multipliers, with label (for profiling).
305    #[allow(unused_variables)]
306    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
307        vec![]
308    }
309}
310
311impl_downcast!(Op);
312impl_downcast!(TypedOp);
313
314dyn_clone::clone_trait_object!(Op);
315dyn_clone::clone_trait_object!(TypedOp);
316
317impl<O: Op> From<O> for Box<dyn Op> {
318    fn from(it: O) -> Box<dyn Op> {
319        Box::new(it)
320    }
321}
322
323impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
324    fn from(it: O) -> Box<dyn TypedOp> {
325        Box::new(it)
326    }
327}
328
329impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
330    fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
331        it.clone()
332    }
333}
334
335impl AsRef<dyn Op> for dyn TypedOp {
336    fn as_ref(&self) -> &dyn Op {
337        self.as_op()
338    }
339}
340
341impl AsRef<dyn Op> for Box<dyn TypedOp> {
342    fn as_ref(&self) -> &dyn Op {
343        self.as_op()
344    }
345}
346
347impl AsMut<dyn Op> for dyn TypedOp {
348    fn as_mut(&mut self) -> &mut dyn Op {
349        self.as_op_mut()
350    }
351}
352
353impl AsMut<dyn Op> for Box<dyn TypedOp> {
354    fn as_mut(&mut self) -> &mut dyn Op {
355        self.as_op_mut()
356    }
357}
358
359impl std::fmt::Display for Box<dyn Op> {
360    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
361        write!(fmt, "{}", self.name())
362    }
363}
364
365impl std::fmt::Display for Box<dyn TypedOp> {
366    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
367        write!(fmt, "{}", self.name())
368    }
369}