1use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7use dyn_eq::DynEq;
8
9#[macro_use]
10pub mod macros;
11#[macro_use]
12pub mod element_wise;
13#[macro_use]
14pub mod binary;
15
16pub mod array;
17pub mod cast;
18pub mod change_axes;
19pub mod cnn;
20pub mod downsample;
21pub mod dummy;
22pub mod einsum;
23pub mod fft;
24pub mod identity;
25pub mod konst;
26pub mod logic;
27pub mod math;
28pub mod matmul;
29pub mod nn;
31pub mod quant;
32pub mod scan;
33pub mod source;
34pub mod submodel;
35pub mod unimpl;
36
37pub use downsample::Downsample;
38pub use memory::*;
39
40use crate::internal::*;
41use crate::optim::OptimizerSession;
42
/// Validation level an operator declares through `Op::validation`.
///
/// `Op::validation` returns `Validation::Accurate` unless an operator
/// overrides it.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum Validation {
    /// NOTE(review): presumably the op's output is random and cannot be
    /// compared against a reference — confirm.
    Random,
    /// NOTE(review): presumably the output may differ from a reference by
    /// rounding errors — confirm.
    Rounding,
    /// Default level reported by `Op::validation`.
    Accurate,
}
53
54#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
55pub enum Cost {
56 Div(DatumType),
57 FMA(DatumType),
58 Buffer(DatumType),
59 Params(DatumType),
60 Custom(bool, String),
61}
62
63impl Cost {
64 pub fn is_compute(&self) -> bool {
65 use Cost::*;
66 match self {
67 FMA(_) | Div(_) => true,
68 Buffer(_) | Params(_) => false,
69 Custom(compute, _) => *compute,
70 }
71 }
72}
73
74impl std::fmt::Debug for Cost {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 use Cost::*;
77 match self {
78 Div(dt) => write!(f, "Div({dt:?})"),
79 FMA(dt) => write!(f, "FMA({dt:?})"),
80 Buffer(dt) => write!(f, "Buffer({dt:?})"),
81 Params(dt) => write!(f, "Params({dt:?})"),
82 Custom(_, name) => write!(f, "{name}"),
83 }
84 }
85}
86
87pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
88 fn unfreeze(&self) -> Box<dyn OpState>;
89}
90
91pub trait OpStateFreeze {
92 fn freeze(&self) -> Box<dyn FrozenOpState>;
93 fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
95 self.freeze()
96 }
97}
98
99dyn_clone::clone_trait_object!(FrozenOpState);
100
101pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
102 fn load_from(
103 &mut self,
104 _: &mut TurnState,
105 _: &mut dyn Iterator<Item = TValue>,
106 ) -> TractResult<()> {
107 Ok(())
108 }
109
110 fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
111 Ok(())
112 }
113
114 fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
115 None
116 }
117
118 fn resolve_symbols(&mut self, _: &mut TurnState) -> TractResult<()> {
119 Ok(())
120 }
121
122 fn eval(
123 &mut self,
124 session: &mut TurnState,
125 op: &dyn Op,
126 inputs: TVec<TValue>,
127 ) -> TractResult<TVec<TValue>>;
128}
129dyn_clone::clone_trait_object!(OpState);
130impl_downcast!(OpState);
131
132pub trait EvalOp {
133 #[allow(unused_variables)]
134 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
135 bail!("stateless evaluation not implemented")
136 }
137
138 #[allow(unused_variables)]
139 fn eval_with_session(
140 &self,
141 node_id: usize,
142 session: &TurnState,
143 inputs: TVec<TValue>,
144 ) -> TractResult<TVec<TValue>> {
145 self.eval(inputs).context("Running legacy eval")
146 }
147
148 #[allow(unused_variables)]
149 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
150 Ok(None)
151 }
152
153 fn is_stateless(&self) -> bool;
154}
155
156pub trait Op:
158 fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast + EvalOp
159{
160 fn name(&self) -> StaticName;
161
162 fn validation(&self) -> Validation {
165 Validation::Accurate
166 }
167
168 fn info(&self) -> TractResult<Vec<String>> {
171 Ok(vec![])
172 }
173
174 fn as_typed(&self) -> Option<&dyn TypedOp>;
175}
176
177impl_downcast!(Op);
178dyn_clone::clone_trait_object!(Op);
179dyn_eq::eq_trait_object!(Op);
180
181pub trait TypedOp:
182 Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
183{
184 fn as_op(&self) -> &dyn Op;
186
187 fn as_op_mut(&mut self) -> &mut dyn Op;
189
190 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
192
193 #[allow(unused_variables)]
194 fn axes_mapping(
195 &self,
196 inputs: &[&TypedFact],
197 outputs: &[&TypedFact],
198 ) -> TractResult<AxesMapping> {
199 AxesMapping::disconnected(inputs, outputs)
200 }
201
202 fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
204 Ok(None)
205 }
206
207 #[allow(unused_variables)]
209 fn declutter_with_session(
210 &self,
211 session: &mut OptimizerSession,
212 model: &TypedModel,
213 node: &TypedNode,
214 ) -> TractResult<Option<TypedModelPatch>> {
215 self.declutter(model, node)
216 }
217
218 #[allow(unused_variables)]
220 fn declutter(
221 &self,
222 model: &TypedModel,
223 node: &TypedNode,
224 ) -> TractResult<Option<TypedModelPatch>> {
225 Ok(None)
226 }
227
228 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
232 Ok(tvec!())
233 }
234
235 #[allow(unused_variables)]
239 fn input_roi(
240 &self,
241 model: &TypedModel,
242 node: &TypedNode,
243 ) -> TractResult<Option<TVec<Option<TDim>>>> {
244 Ok(None)
245 }
246
247 #[allow(unused_variables)]
248 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
249 Ok(tvec!())
250 }
251
252 #[allow(unused_variables)]
253 fn change_axes(
254 &self,
255 model: &TypedModel,
256 node: &TypedNode,
257 io: InOut,
258 change: &AxisOp,
259 ) -> TractResult<Option<AxisChangeConsequence>> {
260 Ok(None)
261 }
262
263 #[allow(unused_variables)]
264 #[allow(clippy::too_many_arguments)]
265 fn slice(
266 &self,
267 patch: &mut TypedModelPatch,
268 model: &TypedModel,
269 node: &TypedNode,
270 prefix: &str,
271 inputs: &[OutletId],
272 output_axis: usize,
273 start: &TDim,
274 end: &TDim,
275 ) -> TractResult<Option<TVec<OutletId>>> {
276 Ok(None)
277 }
278
279 #[allow(unused_variables)]
283 fn quantize(
284 &self,
285 model: &TypedModel,
286 node: &TypedNode,
287 dt: DatumType,
288 scale: f32,
289 zero_point: i32,
290 ) -> TractResult<Option<Box<dyn TypedOp>>> {
291 Ok(None)
292 }
293
294 #[allow(unused_variables)]
296 fn concretize_dims(
297 &self,
298 source: &TypedModel,
299 node: &TypedNode,
300 target: &mut TypedModel,
301 mapping: &HashMap<OutletId, OutletId>,
302 values: &SymbolValues,
303 ) -> TractResult<TVec<OutletId>> {
304 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
305 target.wire_node(&node.name, node.op.clone(), &inputs)
306 }
307
308 #[allow(unused_variables)]
313 fn codegen(
314 &self,
315 model: &TypedModel,
316 node: &TypedNode,
317 ) -> TractResult<Option<TypedModelPatch>> {
318 Ok(None)
319 }
320
321 #[allow(unused_variables)]
323 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
324 vec![]
325 }
326}
327
328impl_downcast!(TypedOp);
329dyn_clone::clone_trait_object!(TypedOp);
330dyn_eq::eq_trait_object!(TypedOp);
331
332impl<O: Op> From<O> for Box<dyn Op> {
333 fn from(it: O) -> Box<dyn Op> {
334 Box::new(it)
335 }
336}
337
338impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
339 fn from(it: O) -> Box<dyn TypedOp> {
340 Box::new(it)
341 }
342}
343
344impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
345 fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
346 it.clone()
347 }
348}
349
350impl AsRef<dyn Op> for dyn TypedOp {
351 fn as_ref(&self) -> &dyn Op {
352 self.as_op()
353 }
354}
355
356impl AsRef<dyn Op> for Box<dyn TypedOp> {
357 fn as_ref(&self) -> &dyn Op {
358 self.as_op()
359 }
360}
361
362impl AsMut<dyn Op> for dyn TypedOp {
363 fn as_mut(&mut self) -> &mut dyn Op {
364 self.as_op_mut()
365 }
366}
367
368impl AsMut<dyn Op> for Box<dyn TypedOp> {
369 fn as_mut(&mut self) -> &mut dyn Op {
370 self.as_op_mut()
371 }
372}
373
374impl std::fmt::Display for Box<dyn Op> {
375 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
376 write!(fmt, "{}", self.name())
377 }
378}
379
380impl std::fmt::Display for Box<dyn TypedOp> {
381 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
382 write!(fmt, "{}", self.name())
383 }
384}