1use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7
8#[macro_use]
9pub mod macros;
10#[macro_use]
11pub mod element_wise;
12#[macro_use]
13pub mod binary;
14
15pub mod array;
16pub mod cast;
17pub mod change_axes;
18pub mod cnn;
19pub mod downsample;
20pub mod dummy;
21pub mod einsum;
22pub mod fft;
23pub mod identity;
24pub mod konst;
25pub mod logic;
26pub mod math;
27pub mod matmul;
28pub mod memory;
29pub mod nn;
30pub mod quant;
31pub mod scan;
32pub mod source;
33pub mod submodel;
34pub mod unimpl;
35
36pub use downsample::Downsample;
37pub use memory::*;
38
39use crate::internal::*;
40use crate::optim::OptimizerSession;
41
/// Validation level an operator reports through `Op::validation`.
///
/// The default implementation of `Op::validation` returns
/// [`Validation::Accurate`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Validation {
    // NOTE(review): presumably the output cannot be compared at all — confirm.
    Random,
    // NOTE(review): presumably results may differ by rounding errors — confirm.
    Rounding,
    /// Level returned by the default `Op::validation`.
    Accurate,
}
52
/// Kind of cost an operator can report from `TypedOp::cost`, where each
/// entry is paired with a `TDim` quantity.
///
/// `Cost::is_compute` classifies `Div` and `FMA` as compute costs and
/// `Buffer` and `Params` as non-compute costs.
///
/// The derived `Ord`/`PartialOrd` follow the declaration order of the
/// variants, so reordering them changes comparisons.
#[derive(Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum Cost {
    // NOTE(review): presumably divisions performed on the given datum type — confirm.
    Div(DatumType),
    // NOTE(review): presumably fused multiply-adds on the given datum type — confirm.
    FMA(DatumType),
    // NOTE(review): presumably buffer elements of the given datum type — confirm.
    Buffer(DatumType),
    // NOTE(review): presumably parameter elements of the given datum type — confirm.
    Params(DatumType),
    /// Free-form cost: the `bool` is what `is_compute` returns for it, the
    /// `String` is what the `Debug` impl prints.
    Custom(bool, String),
}
61
62impl Cost {
63 pub fn is_compute(&self) -> bool {
64 use Cost::*;
65 match self {
66 FMA(_) | Div(_) => true,
67 Buffer(_) | Params(_) => false,
68 Custom(compute, _) => *compute,
69 }
70 }
71}
72
73impl std::fmt::Debug for Cost {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 use Cost::*;
76 match self {
77 Div(dt) => write!(f, "Div({dt:?})"),
78 FMA(dt) => write!(f, "FMA({dt:?})"),
79 Buffer(dt) => write!(f, "Buffer({dt:?})"),
80 Params(dt) => write!(f, "Params({dt:?})"),
81 Custom(_, name) => write!(f, "{name}"),
82 }
83 }
84}
85
/// Counterpart of [`OpState`] produced by [`OpStateFreeze::freeze`].
///
/// Unlike `OpState`, this trait requires `Send + 'static`.
pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
    /// Builds a boxed [`OpState`] from this frozen state.
    fn unfreeze(&self) -> Box<dyn OpState>;
}
89
/// Conversion of a state into its [`FrozenOpState`] form.
///
/// This is a supertrait of [`OpState`], so every op state must provide it.
pub trait OpStateFreeze {
    /// Builds a boxed [`FrozenOpState`] from this state.
    fn freeze(&self) -> Box<dyn FrozenOpState>;
}
93
// Implements `Clone` for `Box<dyn FrozenOpState>` through `dyn_clone`.
dyn_clone::clone_trait_object!(FrozenOpState);
95
96pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
97 fn load_from(&mut self, _: &mut SessionState, _: &mut Vec<TValue>) -> TractResult<()> {
98 Ok(())
99 }
100
101 fn save_to(&self, _: &mut Vec<TValue>) -> TractResult<()> {
102 Ok(())
103 }
104
105 fn init_tensor_fact(&self) -> Option<TypedFact> {
106 None
107 }
108
109 fn resolve_symbols(&mut self, _: &mut SessionState) -> TractResult<()> {
110 Ok(())
111 }
112
113 fn eval(
114 &mut self,
115 session: &mut SessionState,
116 op: &dyn Op,
117 inputs: TVec<TValue>,
118 ) -> TractResult<TVec<TValue>>;
119}
// Implements `Clone` for `Box<dyn OpState>` through `dyn_clone`.
dyn_clone::clone_trait_object!(OpState);
// Adds the `downcast_rs` downcasting helpers to `dyn OpState`.
impl_downcast!(OpState);
122
123pub trait EvalOp {
124 #[allow(unused_variables)]
125 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
126 bail!("stateless evaluation not implemented")
127 }
128
129 #[allow(unused_variables)]
130 fn eval_with_session(
131 &self,
132 node_id: usize,
133 session: &SessionState,
134 inputs: TVec<TValue>,
135 ) -> TractResult<TVec<TValue>> {
136 self.eval(inputs).context("Running legacy eval")
137 }
138
139 #[allow(unused_variables)]
140 fn state(
141 &self,
142 session: &mut SessionState,
143 node_id: usize,
144 ) -> TractResult<Option<Box<dyn OpState>>> {
145 Ok(None)
146 }
147
148 fn is_stateless(&self) -> bool;
149}
150
151pub trait Op: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp {
153 fn name(&self) -> StaticName;
154
155 fn validation(&self) -> Validation {
158 Validation::Accurate
159 }
160
161 fn same_as(&self, _other: &dyn Op) -> bool {
164 false
165 }
166
167 fn info(&self) -> TractResult<Vec<String>> {
170 Ok(vec![])
171 }
172
173 fn as_typed(&self) -> Option<&dyn TypedOp>;
174}
175
176pub trait TypedOp:
177 Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
178{
179 fn as_op(&self) -> &dyn Op;
181
182 fn as_op_mut(&mut self) -> &mut dyn Op;
184
185 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
187
188 #[allow(unused_variables)]
189 fn axes_mapping(
190 &self,
191 inputs: &[&TypedFact],
192 outputs: &[&TypedFact],
193 ) -> TractResult<AxesMapping> {
194 AxesMapping::disconnected(inputs, outputs)
195 }
196
197 fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
199 Ok(None)
200 }
201
202 #[allow(unused_variables)]
204 fn declutter_with_session(
205 &self,
206 session: &mut OptimizerSession,
207 model: &TypedModel,
208 node: &TypedNode,
209 ) -> TractResult<Option<TypedModelPatch>> {
210 self.declutter(model, node)
211 }
212
213 #[allow(unused_variables)]
215 fn declutter(
216 &self,
217 model: &TypedModel,
218 node: &TypedNode,
219 ) -> TractResult<Option<TypedModelPatch>> {
220 Ok(None)
221 }
222
223 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
227 Ok(tvec!())
228 }
229
230 #[allow(unused_variables)]
231 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
232 Ok(tvec!())
233 }
234
235 #[allow(unused_variables)]
236 fn change_axes(
237 &self,
238 model: &TypedModel,
239 node: &TypedNode,
240 io: InOut,
241 change: &AxisOp,
242 ) -> TractResult<Option<AxisChangeConsequence>> {
243 Ok(None)
244 }
245
246 #[allow(unused_variables)]
247 #[allow(clippy::too_many_arguments)]
248 fn slice(
249 &self,
250 patch: &mut TypedModelPatch,
251 model: &TypedModel,
252 node: &TypedNode,
253 prefix: &str,
254 inputs: &[OutletId],
255 output_axis: usize,
256 start: &TDim,
257 end: &TDim,
258 ) -> TractResult<Option<TVec<OutletId>>> {
259 Ok(None)
260 }
261
262 #[allow(unused_variables)]
266 fn quantize(
267 &self,
268 model: &TypedModel,
269 node: &TypedNode,
270 dt: DatumType,
271 scale: f32,
272 zero_point: i32,
273 ) -> TractResult<Option<Box<dyn TypedOp>>> {
274 Ok(None)
275 }
276
277 #[allow(unused_variables)]
279 fn concretize_dims(
280 &self,
281 source: &TypedModel,
282 node: &TypedNode,
283 target: &mut TypedModel,
284 mapping: &HashMap<OutletId, OutletId>,
285 values: &SymbolValues,
286 ) -> TractResult<TVec<OutletId>> {
287 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
288 target.wire_node(&node.name, node.op.clone(), &inputs)
289 }
290
291 #[allow(unused_variables)]
296 fn codegen(
297 &self,
298 model: &TypedModel,
299 node: &TypedNode,
300 ) -> TractResult<Option<TypedModelPatch>> {
301 Ok(None)
302 }
303
304 #[allow(unused_variables)]
306 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
307 vec![]
308 }
309}
310
// Adds the `downcast_rs` downcasting helpers to `dyn Op` and `dyn TypedOp`.
impl_downcast!(Op);
impl_downcast!(TypedOp);

// Implements `Clone` for `Box<dyn Op>` and `Box<dyn TypedOp>` through `dyn_clone`.
dyn_clone::clone_trait_object!(Op);
dyn_clone::clone_trait_object!(TypedOp);
316
317impl<O: Op> From<O> for Box<dyn Op> {
318 fn from(it: O) -> Box<dyn Op> {
319 Box::new(it)
320 }
321}
322
323impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
324 fn from(it: O) -> Box<dyn TypedOp> {
325 Box::new(it)
326 }
327}
328
329impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
330 fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
331 it.clone()
332 }
333}
334
335impl AsRef<dyn Op> for dyn TypedOp {
336 fn as_ref(&self) -> &dyn Op {
337 self.as_op()
338 }
339}
340
341impl AsRef<dyn Op> for Box<dyn TypedOp> {
342 fn as_ref(&self) -> &dyn Op {
343 self.as_op()
344 }
345}
346
347impl AsMut<dyn Op> for dyn TypedOp {
348 fn as_mut(&mut self) -> &mut dyn Op {
349 self.as_op_mut()
350 }
351}
352
353impl AsMut<dyn Op> for Box<dyn TypedOp> {
354 fn as_mut(&mut self) -> &mut dyn Op {
355 self.as_op_mut()
356 }
357}
358
359impl std::fmt::Display for Box<dyn Op> {
360 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
361 write!(fmt, "{}", self.name())
362 }
363}
364
365impl std::fmt::Display for Box<dyn TypedOp> {
366 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
367 write!(fmt, "{}", self.name())
368 }
369}