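// tract_core/ops/mod.rs

//! Ops: the core operator traits ([`Op`], [`EvalOp`], [`TypedOp`], [`OpState`])
//! and the operator submodules of tract-core.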
2use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7use dyn_eq::DynEq;
8
9#[macro_use]
10pub mod macros;
11#[macro_use]
12pub mod element_wise;
13#[macro_use]
14pub mod binary;
15
16pub mod array;
17pub mod cast;
18pub mod change_axes;
19pub mod cnn;
20pub mod downsample;
21pub mod dummy;
22pub mod einsum;
23pub mod fft;
24pub mod identity;
25pub mod konst;
26pub mod logic;
27pub mod math;
28pub mod matmul;
29// pub mod memory;
30pub mod nn;
31pub mod quant;
32pub mod scan;
33pub mod source;
34pub mod submodel;
35pub mod unimpl;
36
37pub use downsample::Downsample;
38pub use memory::*;
39
40use crate::internal::*;
41use crate::optim::OptimizerSession;
42
/// How strictly the results of two implementations of the same operator are
/// expected to agree when they are compared in tests.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Validation {
    /// The output is random, so its values cannot be compared.
    Random,
    /// The implementation may introduce rounding errors.
    Rounding,
    /// The implementation must be accurate.
    Accurate,
}
53
54#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
55pub enum Cost {
56    Div(DatumType),
57    FMA(DatumType),
58    Buffer(DatumType),
59    Params(DatumType),
60    Custom(bool, String),
61}
62
63impl Cost {
64    pub fn is_compute(&self) -> bool {
65        use Cost::*;
66        match self {
67            FMA(_) | Div(_) => true,
68            Buffer(_) | Params(_) => false,
69            Custom(compute, _) => *compute,
70        }
71    }
72}
73
74impl std::fmt::Debug for Cost {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        use Cost::*;
77        match self {
78            Div(dt) => write!(f, "Div({dt:?})"),
79            FMA(dt) => write!(f, "FMA({dt:?})"),
80            Buffer(dt) => write!(f, "Buffer({dt:?})"),
81            Params(dt) => write!(f, "Params({dt:?})"),
82            Custom(_, name) => write!(f, "{name}"),
83        }
84    }
85}
86
87pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
88    fn unfreeze(&self) -> Box<dyn OpState>;
89}
90
91pub trait OpStateFreeze {
92    fn freeze(&self) -> Box<dyn FrozenOpState>;
93    /// Consuming freeze: moves data instead of cloning. Default delegates to freeze().
94    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
95        self.freeze()
96    }
97}
98
99dyn_clone::clone_trait_object!(FrozenOpState);
100
101pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
102    fn load_from(
103        &mut self,
104        _: &mut TurnState,
105        _: &mut dyn Iterator<Item = TValue>,
106    ) -> TractResult<()> {
107        Ok(())
108    }
109
110    fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
111        Ok(())
112    }
113
114    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
115        None
116    }
117
118    fn resolve_symbols(&mut self, _: &mut TurnState) -> TractResult<()> {
119        Ok(())
120    }
121
122    fn eval(
123        &mut self,
124        session: &mut TurnState,
125        op: &dyn Op,
126        inputs: TVec<TValue>,
127    ) -> TractResult<TVec<TValue>>;
128}
129dyn_clone::clone_trait_object!(OpState);
130impl_downcast!(OpState);
131
132pub trait EvalOp {
133    #[allow(unused_variables)]
134    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
135        bail!("stateless evaluation not implemented")
136    }
137
138    #[allow(unused_variables)]
139    fn eval_with_session(
140        &self,
141        node_id: usize,
142        session: &TurnState,
143        inputs: TVec<TValue>,
144    ) -> TractResult<TVec<TValue>> {
145        self.eval(inputs).context("Running legacy eval")
146    }
147
148    #[allow(unused_variables)]
149    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
150        Ok(None)
151    }
152
153    fn is_stateless(&self) -> bool;
154}
155
156/// A base operation
157pub trait Op:
158    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast + EvalOp
159{
160    fn name(&self) -> StaticName;
161
162    /// The kind of accuracy check that should be performed on operation when
163    /// testing them.
164    fn validation(&self) -> Validation {
165        Validation::Accurate
166    }
167
168    /// Short (one-line) strings giving hints on internal implementation or
169    /// important configuration details to be displayed in dumps.
170    fn info(&self) -> TractResult<Vec<String>> {
171        Ok(vec![])
172    }
173
174    fn as_typed(&self) -> Option<&dyn TypedOp>;
175}
176
177impl_downcast!(Op);
178dyn_clone::clone_trait_object!(Op);
179dyn_eq::eq_trait_object!(Op);
180
181pub trait TypedOp:
182    Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
183{
184    /// Reinterpret the TypedOp as an Op.
185    fn as_op(&self) -> &dyn Op;
186
187    /// Reinterpret the TypedOp as an Op, mutably.
188    fn as_op_mut(&mut self) -> &mut dyn Op;
189
190    /// Deduce output facts from input facts.
191    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
192
193    #[allow(unused_variables)]
194    fn axes_mapping(
195        &self,
196        inputs: &[&TypedFact],
197        outputs: &[&TypedFact],
198    ) -> TractResult<AxesMapping> {
199        AxesMapping::disconnected(inputs, outputs)
200    }
201
202    /// Fuse op after codegen to deal with local optimisations.
203    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
204        Ok(None)
205    }
206
207    /// Declutter the op to the tract_core operator set as much as possible.
208    #[allow(unused_variables)]
209    fn declutter_with_session(
210        &self,
211        session: &mut OptimizerSession,
212        model: &TypedModel,
213        node: &TypedNode,
214    ) -> TractResult<Option<TypedModelPatch>> {
215        self.declutter(model, node)
216    }
217
218    /// Declutter the op to the tract_core operator set as much as possible.
219    #[allow(unused_variables)]
220    fn declutter(
221        &self,
222        model: &TypedModel,
223        node: &TypedNode,
224    ) -> TractResult<Option<TypedModelPatch>> {
225        Ok(None)
226    }
227
228    /// Computes a cost hint of the operation.
229    ///
230    /// Each pair is a type of operation and a number per call on eval.
231    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
232        Ok(tvec!())
233    }
234
235    /// Derive ROI (region of interest) expressions for this node's inputs.
236    /// Called by the PropagateRoi pass. Default returns None (no propagation).
237    /// Override to introduce ROIs or bubble them through.
238    #[allow(unused_variables)]
239    fn input_roi(
240        &self,
241        model: &TypedModel,
242        node: &TypedNode,
243    ) -> TractResult<Option<TVec<Option<TDim>>>> {
244        Ok(None)
245    }
246
247    #[allow(unused_variables)]
248    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
249        Ok(tvec!())
250    }
251
252    #[allow(unused_variables)]
253    fn change_axes(
254        &self,
255        model: &TypedModel,
256        node: &TypedNode,
257        io: InOut,
258        change: &AxisOp,
259    ) -> TractResult<Option<AxisChangeConsequence>> {
260        Ok(None)
261    }
262
263    #[allow(unused_variables)]
264    #[allow(clippy::too_many_arguments)]
265    fn slice(
266        &self,
267        patch: &mut TypedModelPatch,
268        model: &TypedModel,
269        node: &TypedNode,
270        prefix: &str,
271        inputs: &[OutletId],
272        output_axis: usize,
273        start: &TDim,
274        end: &TDim,
275    ) -> TractResult<Option<TVec<OutletId>>> {
276        Ok(None)
277    }
278
279    /// Transforms the op in an equivalent one, operating on dt (i8 or u8).
280    ///
281    /// Returns None if the op can not be translated.
282    #[allow(unused_variables)]
283    fn quantize(
284        &self,
285        model: &TypedModel,
286        node: &TypedNode,
287        dt: DatumType,
288        scale: f32,
289        zero_point: i32,
290    ) -> TractResult<Option<Box<dyn TypedOp>>> {
291        Ok(None)
292    }
293
294    /// Transform the op into by providing a value to one or more symbols.
295    #[allow(unused_variables)]
296    fn concretize_dims(
297        &self,
298        source: &TypedModel,
299        node: &TypedNode,
300        target: &mut TypedModel,
301        mapping: &HashMap<OutletId, OutletId>,
302        values: &SymbolValues,
303    ) -> TractResult<TVec<OutletId>> {
304        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
305        target.wire_node(&node.name, node.op.clone(), &inputs)
306    }
307
308    /// Translate the op into the most efficient form possible for execution.
309    ///
310    /// This transformation is supposed to be final, no more pass are expected
311    /// to be run on the codegen networks.
312    #[allow(unused_variables)]
313    fn codegen(
314        &self,
315        model: &TypedModel,
316        node: &TypedNode,
317    ) -> TractResult<Option<TypedModelPatch>> {
318        Ok(None)
319    }
320
321    /// Nested model multipliers, with label (for profiling).
322    #[allow(unused_variables)]
323    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
324        vec![]
325    }
326}
327
328impl_downcast!(TypedOp);
329dyn_clone::clone_trait_object!(TypedOp);
330dyn_eq::eq_trait_object!(TypedOp);
331
332impl<O: Op> From<O> for Box<dyn Op> {
333    fn from(it: O) -> Box<dyn Op> {
334        Box::new(it)
335    }
336}
337
338impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
339    fn from(it: O) -> Box<dyn TypedOp> {
340        Box::new(it)
341    }
342}
343
344impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
345    fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
346        it.clone()
347    }
348}
349
350impl AsRef<dyn Op> for dyn TypedOp {
351    fn as_ref(&self) -> &dyn Op {
352        self.as_op()
353    }
354}
355
356impl AsRef<dyn Op> for Box<dyn TypedOp> {
357    fn as_ref(&self) -> &dyn Op {
358        self.as_op()
359    }
360}
361
362impl AsMut<dyn Op> for dyn TypedOp {
363    fn as_mut(&mut self) -> &mut dyn Op {
364        self.as_op_mut()
365    }
366}
367
368impl AsMut<dyn Op> for Box<dyn TypedOp> {
369    fn as_mut(&mut self) -> &mut dyn Op {
370        self.as_op_mut()
371    }
372}
373
374impl std::fmt::Display for Box<dyn Op> {
375    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
376        write!(fmt, "{}", self.name())
377    }
378}
379
380impl std::fmt::Display for Box<dyn TypedOp> {
381    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
382        write!(fmt, "{}", self.name())
383    }
384}