Skip to main content

tract_core/ops/
mod.rs

1//! Ops
2use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7use dyn_eq::DynEq;
8
9#[macro_use]
10pub mod macros;
11#[macro_use]
12pub mod element_wise;
13#[macro_use]
14pub mod binary;
15
16pub mod array;
17pub mod cast;
18pub mod change_axes;
19pub mod cnn;
20pub mod downsample;
21pub mod dummy;
22pub mod einsum;
23pub mod fft;
24pub mod identity;
25pub mod konst;
26pub mod logic;
27pub mod lstm_cell;
28pub mod math;
29pub mod matmul;
30// pub mod memory;
31pub mod nn;
32pub mod quant;
33pub mod scan;
34pub mod source;
35pub mod submodel;
36pub mod unimpl;
37
38pub use downsample::Downsample;
39pub use memory::*;
40
41use crate::internal::*;
42use crate::optim::OptimizerSession;
43
/// How precisely two implementations of an op are expected to agree when
/// their outputs are compared.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Validation {
    /// The output is random.
    Random,
    /// The implementation may introduce rounding errors.
    Rounding,
    /// The implementation must be exact.
    Accurate,
}
54
55#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
56pub enum Cost {
57    Div(DatumType),
58    FMA(DatumType),
59    Buffer(DatumType),
60    Params(DatumType),
61    Custom(bool, String),
62}
63
64impl Cost {
65    pub fn is_compute(&self) -> bool {
66        use Cost::*;
67        match self {
68            FMA(_) | Div(_) => true,
69            Buffer(_) | Params(_) => false,
70            Custom(compute, _) => *compute,
71        }
72    }
73}
74
75impl std::fmt::Debug for Cost {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        use Cost::*;
78        match self {
79            Div(dt) => write!(f, "Div({dt:?})"),
80            FMA(dt) => write!(f, "FMA({dt:?})"),
81            Buffer(dt) => write!(f, "Buffer({dt:?})"),
82            Params(dt) => write!(f, "Params({dt:?})"),
83            Custom(_, name) => write!(f, "{name}"),
84        }
85    }
86}
87
88pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
89    fn unfreeze(&self) -> Box<dyn OpState>;
90}
91
92pub trait OpStateFreeze {
93    fn freeze(&self) -> Box<dyn FrozenOpState>;
94    /// Consuming freeze: moves data instead of cloning. Default delegates to freeze().
95    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
96        self.freeze()
97    }
98}
99
100dyn_clone::clone_trait_object!(FrozenOpState);
101
102pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
103    fn load_from(
104        &mut self,
105        _: &mut TurnState,
106        _: &mut dyn Iterator<Item = TValue>,
107    ) -> TractResult<()> {
108        Ok(())
109    }
110
111    fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
112        Ok(())
113    }
114
115    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
116        None
117    }
118
119    fn resolve_symbols(&mut self, _: &mut TurnState) -> TractResult<()> {
120        Ok(())
121    }
122
123    fn eval(
124        &mut self,
125        session: &mut TurnState,
126        op: &dyn Op,
127        inputs: TVec<TValue>,
128    ) -> TractResult<TVec<TValue>>;
129}
130dyn_clone::clone_trait_object!(OpState);
131impl_downcast!(OpState);
132
133pub trait EvalOp {
134    #[allow(unused_variables)]
135    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
136        bail!("stateless evaluation not implemented")
137    }
138
139    #[allow(unused_variables)]
140    fn eval_with_session(
141        &self,
142        node_id: usize,
143        session: &TurnState,
144        inputs: TVec<TValue>,
145    ) -> TractResult<TVec<TValue>> {
146        self.eval(inputs).context("Running legacy eval")
147    }
148
149    #[allow(unused_variables)]
150    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
151        Ok(None)
152    }
153
154    fn is_stateless(&self) -> bool;
155}
156
157/// A base operation
158pub trait Op:
159    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast + EvalOp
160{
161    fn name(&self) -> StaticName;
162
163    /// The kind of accuracy check that should be performed on operation when
164    /// testing them.
165    fn validation(&self) -> Validation {
166        Validation::Accurate
167    }
168
169    /// Short (one-line) strings giving hints on internal implementation or
170    /// important configuration details to be displayed in dumps.
171    fn info(&self) -> TractResult<Vec<String>> {
172        Ok(vec![])
173    }
174
175    fn as_typed(&self) -> Option<&dyn TypedOp>;
176}
177
178impl_downcast!(Op);
179dyn_clone::clone_trait_object!(Op);
180dyn_eq::eq_trait_object!(Op);
181
182pub trait TypedOp:
183    Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
184{
185    /// Reinterpret the TypedOp as an Op.
186    fn as_op(&self) -> &dyn Op;
187
188    /// Reinterpret the TypedOp as an Op, mutably.
189    fn as_op_mut(&mut self) -> &mut dyn Op;
190
191    /// Deduce output facts from input facts.
192    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
193
194    #[allow(unused_variables)]
195    fn axes_mapping(
196        &self,
197        inputs: &[&TypedFact],
198        outputs: &[&TypedFact],
199    ) -> TractResult<AxesMapping> {
200        AxesMapping::disconnected(inputs, outputs)
201    }
202
203    /// Fuse op after codegen to deal with local optimisations.
204    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
205        Ok(None)
206    }
207
208    /// Declutter the op to the tract_core operator set as much as possible.
209    #[allow(unused_variables)]
210    fn declutter_with_session(
211        &self,
212        session: &mut OptimizerSession,
213        model: &TypedModel,
214        node: &TypedNode,
215    ) -> TractResult<Option<TypedModelPatch>> {
216        self.declutter(model, node)
217    }
218
219    /// Declutter the op to the tract_core operator set as much as possible.
220    #[allow(unused_variables)]
221    fn declutter(
222        &self,
223        model: &TypedModel,
224        node: &TypedNode,
225    ) -> TractResult<Option<TypedModelPatch>> {
226        Ok(None)
227    }
228
229    /// Computes a cost hint of the operation.
230    ///
231    /// Each pair is a type of operation and a number per call on eval.
232    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
233        Ok(tvec!())
234    }
235
236    /// Derive ROI (region of interest) expressions for this node's inputs.
237    /// Called by the PropagateRoi pass. Default returns None (no propagation).
238    /// Override to introduce ROIs or bubble them through.
239    #[allow(unused_variables)]
240    fn input_roi(
241        &self,
242        model: &TypedModel,
243        node: &TypedNode,
244    ) -> TractResult<Option<TVec<Option<TDim>>>> {
245        Ok(None)
246    }
247
248    #[allow(unused_variables)]
249    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
250        Ok(tvec!())
251    }
252
253    #[allow(unused_variables)]
254    fn change_axes(
255        &self,
256        model: &TypedModel,
257        node: &TypedNode,
258        io: InOut,
259        change: &AxisOp,
260    ) -> TractResult<Option<AxisChangeConsequence>> {
261        Ok(None)
262    }
263
264    #[allow(unused_variables)]
265    #[allow(clippy::too_many_arguments)]
266    fn slice(
267        &self,
268        patch: &mut TypedModelPatch,
269        model: &TypedModel,
270        node: &TypedNode,
271        prefix: &str,
272        inputs: &[OutletId],
273        output_axis: usize,
274        start: &TDim,
275        end: &TDim,
276    ) -> TractResult<Option<TVec<OutletId>>> {
277        Ok(None)
278    }
279
280    /// Transforms the op in an equivalent one, operating on dt (i8 or u8).
281    ///
282    /// Returns None if the op can not be translated.
283    #[allow(unused_variables)]
284    fn quantize(
285        &self,
286        model: &TypedModel,
287        node: &TypedNode,
288        dt: DatumType,
289        scale: f32,
290        zero_point: i32,
291    ) -> TractResult<Option<Box<dyn TypedOp>>> {
292        Ok(None)
293    }
294
295    /// Transform the op by substituting one or more symbols with TDim
296    /// expressions (a concrete integer is `TDim::Val(v)`; an expression
297    /// can be any other TDim, including symbolic ones).
298    #[allow(unused_variables)]
299    fn set_symbols(
300        &self,
301        source: &TypedModel,
302        node: &TypedNode,
303        target: &mut TypedModel,
304        mapping: &HashMap<OutletId, OutletId>,
305        subs: &HashMap<Symbol, TDim>,
306    ) -> TractResult<TVec<OutletId>> {
307        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
308        target.wire_node(&node.name, node.op.clone(), &inputs)
309    }
310
311    /// Translate the op into the most efficient form possible for execution.
312    ///
313    /// This transformation is supposed to be final, no more pass are expected
314    /// to be run on the codegen networks.
315    #[allow(unused_variables)]
316    fn codegen(
317        &self,
318        model: &TypedModel,
319        node: &TypedNode,
320    ) -> TractResult<Option<TypedModelPatch>> {
321        Ok(None)
322    }
323
324    /// Nested model multipliers, with label (for profiling).
325    #[allow(unused_variables)]
326    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
327        vec![]
328    }
329}
330
331impl_downcast!(TypedOp);
332dyn_clone::clone_trait_object!(TypedOp);
333dyn_eq::eq_trait_object!(TypedOp);
334
335impl<O: Op> From<O> for Box<dyn Op> {
336    fn from(it: O) -> Box<dyn Op> {
337        Box::new(it)
338    }
339}
340
341impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
342    fn from(it: O) -> Box<dyn TypedOp> {
343        Box::new(it)
344    }
345}
346
347impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
348    fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
349        it.clone()
350    }
351}
352
353impl AsRef<dyn Op> for dyn TypedOp {
354    fn as_ref(&self) -> &dyn Op {
355        self.as_op()
356    }
357}
358
359impl AsRef<dyn Op> for Box<dyn TypedOp> {
360    fn as_ref(&self) -> &dyn Op {
361        self.as_op()
362    }
363}
364
365impl AsMut<dyn Op> for dyn TypedOp {
366    fn as_mut(&mut self) -> &mut dyn Op {
367        self.as_op_mut()
368    }
369}
370
371impl AsMut<dyn Op> for Box<dyn TypedOp> {
372    fn as_mut(&mut self) -> &mut dyn Op {
373        self.as_op_mut()
374    }
375}
376
377impl std::fmt::Display for Box<dyn Op> {
378    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
379        write!(fmt, "{}", self.name())
380    }
381}
382
383impl std::fmt::Display for Box<dyn TypedOp> {
384    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
385        write!(fmt, "{}", self.name())
386    }
387}