1use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7use dyn_eq::DynEq;
8
9#[macro_use]
10pub mod macros;
11#[macro_use]
12pub mod element_wise;
13#[macro_use]
14pub mod binary;
15
16pub mod array;
17pub mod cast;
18pub mod change_axes;
19pub mod cnn;
20pub mod downsample;
21pub mod dummy;
22pub mod einsum;
23pub mod fft;
24pub mod identity;
25pub mod konst;
26pub mod logic;
27pub mod lstm_cell;
28pub mod math;
29pub mod matmul;
30pub mod nn;
32pub mod quant;
33pub mod scan;
34pub mod source;
35pub mod submodel;
36pub mod unimpl;
37
38pub use downsample::Downsample;
39pub use memory::*;
40
41use crate::internal::*;
42use crate::optim::OptimizerSession;
43
/// Precision level an op declares for its implementation.
///
/// Returned by [`Op::validation`], whose default is [`Validation::Accurate`].
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names — confirm against the code that consumes this value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Validation {
    /// Presumably: outputs are not reproducible and should not be compared.
    Random,
    /// Presumably: outputs may differ from a reference by rounding errors.
    Rounding,
    /// Presumably: outputs are expected to match a reference exactly.
    Accurate,
}
54
55#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
56pub enum Cost {
57 Div(DatumType),
58 FMA(DatumType),
59 Buffer(DatumType),
60 Params(DatumType),
61 Custom(bool, String),
62}
63
64impl Cost {
65 pub fn is_compute(&self) -> bool {
66 use Cost::*;
67 match self {
68 FMA(_) | Div(_) => true,
69 Buffer(_) | Params(_) => false,
70 Custom(compute, _) => *compute,
71 }
72 }
73}
74
75impl std::fmt::Debug for Cost {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 use Cost::*;
78 match self {
79 Div(dt) => write!(f, "Div({dt:?})"),
80 FMA(dt) => write!(f, "FMA({dt:?})"),
81 Buffer(dt) => write!(f, "Buffer({dt:?})"),
82 Params(dt) => write!(f, "Params({dt:?})"),
83 Custom(_, name) => write!(f, "{name}"),
84 }
85 }
86}
87
88pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
89 fn unfreeze(&self) -> Box<dyn OpState>;
90}
91
92pub trait OpStateFreeze {
93 fn freeze(&self) -> Box<dyn FrozenOpState>;
94 fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
96 self.freeze()
97 }
98}
99
100dyn_clone::clone_trait_object!(FrozenOpState);
101
102pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
103 fn load_from(
104 &mut self,
105 _: &mut TurnState,
106 _: &mut dyn Iterator<Item = TValue>,
107 ) -> TractResult<()> {
108 Ok(())
109 }
110
111 fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
112 Ok(())
113 }
114
115 fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
116 None
117 }
118
119 fn resolve_symbols(&mut self, _: &mut TurnState) -> TractResult<()> {
120 Ok(())
121 }
122
123 fn eval(
124 &mut self,
125 session: &mut TurnState,
126 op: &dyn Op,
127 inputs: TVec<TValue>,
128 ) -> TractResult<TVec<TValue>>;
129}
130dyn_clone::clone_trait_object!(OpState);
131impl_downcast!(OpState);
132
133pub trait EvalOp {
134 #[allow(unused_variables)]
135 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
136 bail!("stateless evaluation not implemented")
137 }
138
139 #[allow(unused_variables)]
140 fn eval_with_session(
141 &self,
142 node_id: usize,
143 session: &TurnState,
144 inputs: TVec<TValue>,
145 ) -> TractResult<TVec<TValue>> {
146 self.eval(inputs).context("Running legacy eval")
147 }
148
149 #[allow(unused_variables)]
150 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
151 Ok(None)
152 }
153
154 fn is_stateless(&self) -> bool;
155}
156
157pub trait Op:
159 fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast + EvalOp
160{
161 fn name(&self) -> StaticName;
162
163 fn validation(&self) -> Validation {
166 Validation::Accurate
167 }
168
169 fn info(&self) -> TractResult<Vec<String>> {
172 Ok(vec![])
173 }
174
175 fn as_typed(&self) -> Option<&dyn TypedOp>;
176}
177
178impl_downcast!(Op);
179dyn_clone::clone_trait_object!(Op);
180dyn_eq::eq_trait_object!(Op);
181
182pub trait TypedOp:
183 Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
184{
185 fn as_op(&self) -> &dyn Op;
187
188 fn as_op_mut(&mut self) -> &mut dyn Op;
190
191 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
193
194 #[allow(unused_variables)]
195 fn axes_mapping(
196 &self,
197 inputs: &[&TypedFact],
198 outputs: &[&TypedFact],
199 ) -> TractResult<AxesMapping> {
200 AxesMapping::disconnected(inputs, outputs)
201 }
202
203 fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
205 Ok(None)
206 }
207
208 #[allow(unused_variables)]
210 fn declutter_with_session(
211 &self,
212 session: &mut OptimizerSession,
213 model: &TypedModel,
214 node: &TypedNode,
215 ) -> TractResult<Option<TypedModelPatch>> {
216 self.declutter(model, node)
217 }
218
219 #[allow(unused_variables)]
221 fn declutter(
222 &self,
223 model: &TypedModel,
224 node: &TypedNode,
225 ) -> TractResult<Option<TypedModelPatch>> {
226 Ok(None)
227 }
228
229 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
233 Ok(tvec!())
234 }
235
236 #[allow(unused_variables)]
240 fn input_roi(
241 &self,
242 model: &TypedModel,
243 node: &TypedNode,
244 ) -> TractResult<Option<TVec<Option<TDim>>>> {
245 Ok(None)
246 }
247
248 #[allow(unused_variables)]
249 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
250 Ok(tvec!())
251 }
252
253 #[allow(unused_variables)]
254 fn change_axes(
255 &self,
256 model: &TypedModel,
257 node: &TypedNode,
258 io: InOut,
259 change: &AxisOp,
260 ) -> TractResult<Option<AxisChangeConsequence>> {
261 Ok(None)
262 }
263
264 #[allow(unused_variables)]
265 #[allow(clippy::too_many_arguments)]
266 fn slice(
267 &self,
268 patch: &mut TypedModelPatch,
269 model: &TypedModel,
270 node: &TypedNode,
271 prefix: &str,
272 inputs: &[OutletId],
273 output_axis: usize,
274 start: &TDim,
275 end: &TDim,
276 ) -> TractResult<Option<TVec<OutletId>>> {
277 Ok(None)
278 }
279
280 #[allow(unused_variables)]
284 fn quantize(
285 &self,
286 model: &TypedModel,
287 node: &TypedNode,
288 dt: DatumType,
289 scale: f32,
290 zero_point: i32,
291 ) -> TractResult<Option<Box<dyn TypedOp>>> {
292 Ok(None)
293 }
294
295 #[allow(unused_variables)]
299 fn set_symbols(
300 &self,
301 source: &TypedModel,
302 node: &TypedNode,
303 target: &mut TypedModel,
304 mapping: &HashMap<OutletId, OutletId>,
305 subs: &HashMap<Symbol, TDim>,
306 ) -> TractResult<TVec<OutletId>> {
307 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
308 target.wire_node(&node.name, node.op.clone(), &inputs)
309 }
310
311 #[allow(unused_variables)]
316 fn codegen(
317 &self,
318 model: &TypedModel,
319 node: &TypedNode,
320 ) -> TractResult<Option<TypedModelPatch>> {
321 Ok(None)
322 }
323
324 #[allow(unused_variables)]
326 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
327 vec![]
328 }
329}
330
331impl_downcast!(TypedOp);
332dyn_clone::clone_trait_object!(TypedOp);
333dyn_eq::eq_trait_object!(TypedOp);
334
335impl<O: Op> From<O> for Box<dyn Op> {
336 fn from(it: O) -> Box<dyn Op> {
337 Box::new(it)
338 }
339}
340
341impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
342 fn from(it: O) -> Box<dyn TypedOp> {
343 Box::new(it)
344 }
345}
346
347impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
348 fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
349 it.clone()
350 }
351}
352
353impl AsRef<dyn Op> for dyn TypedOp {
354 fn as_ref(&self) -> &dyn Op {
355 self.as_op()
356 }
357}
358
359impl AsRef<dyn Op> for Box<dyn TypedOp> {
360 fn as_ref(&self) -> &dyn Op {
361 self.as_op()
362 }
363}
364
365impl AsMut<dyn Op> for dyn TypedOp {
366 fn as_mut(&mut self) -> &mut dyn Op {
367 self.as_op_mut()
368 }
369}
370
371impl AsMut<dyn Op> for Box<dyn TypedOp> {
372 fn as_mut(&mut self) -> &mut dyn Op {
373 self.as_op_mut()
374 }
375}
376
377impl std::fmt::Display for Box<dyn Op> {
378 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
379 write!(fmt, "{}", self.name())
380 }
381}
382
383impl std::fmt::Display for Box<dyn TypedOp> {
384 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
385 write!(fmt, "{}", self.name())
386 }
387}