1use std::fmt;
3
4use downcast_rs::Downcast;
5
6use dyn_clone;
7
8#[macro_use]
9pub mod macros;
10#[macro_use]
11pub mod element_wise;
12#[macro_use]
13pub mod binary;
14
15pub mod array;
16pub mod cast;
17pub mod change_axes;
18pub mod cnn;
19pub mod downsample;
20pub mod dummy;
21pub mod einsum;
22pub mod fft;
23pub mod identity;
24pub mod konst;
25pub mod logic;
26pub mod math;
27pub mod matmul;
28pub mod memory;
29pub mod nn;
30pub mod quant;
31pub mod scan;
32pub mod source;
33pub mod submodel;
34pub mod unimpl;
35
36pub use downsample::Downsample;
37pub use memory::*;
38
39use crate::internal::*;
40use crate::optim::OptimizerSession;
41
/// Precision level an op's implementation is expected to achieve.
///
/// `Op::validation` returns `Validation::Accurate` unless an op overrides it.
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names — confirm against the code that consumes `Validation`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Validation {
    /// Presumably: outputs are random, so they cannot be compared value by value.
    Random,
    /// Presumably: outputs may deviate from a reference by rounding errors.
    Rounding,
    /// Presumably: outputs are expected to be exact. Default of `Op::validation`.
    Accurate,
}
52
/// A kind of cost reported by `TypedOp::cost`, tagged with the datum type it
/// applies to.
///
/// Variant order matters: `Ord`/`PartialOrd` are derived, so comparisons order
/// values by variant (in declaration order) first, then by `DatumType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum Cost {
    /// Presumably division operations. `Cost::is_compute` treats it as compute.
    Div(DatumType),
    /// Presumably fused multiply-add operations. `Cost::is_compute` treats it as compute.
    FMA(DatumType),
    /// Presumably buffer storage. `Cost::is_compute` treats it as non-compute.
    Buffer(DatumType),
    /// Presumably parameter storage. `Cost::is_compute` treats it as non-compute.
    Params(DatumType),
}
60
61impl Cost {
62 pub fn is_compute(&self) -> bool {
63 use Cost::*;
64 match self {
65 FMA(_) | Div(_) => true,
66 Buffer(_) | Params(_) => false,
67 }
68 }
69}
70
/// The frozen form of an `OpState`, as produced by `OpStateFreeze::freeze`.
///
/// Unlike `OpState`, this trait requires `Send + 'static`, so a frozen state can
/// be moved to another thread and held without borrowing anything. It is
/// cloneable through `DynClone`.
pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
    /// Builds a new boxed `OpState` from this frozen value (borrowed, not consumed).
    fn unfreeze(&self) -> Box<dyn OpState>;
}
74
/// Conversion of a state into its `FrozenOpState` form.
///
/// `OpState` lists this trait as a supertrait, so every op state can be frozen.
pub trait OpStateFreeze {
    /// Returns a boxed frozen version of `self`; `self` is only borrowed.
    fn freeze(&self) -> Box<dyn FrozenOpState>;
}
78
// Implements `Clone` for `Box<dyn FrozenOpState>` using the `DynClone` supertrait.
dyn_clone::clone_trait_object!(FrozenOpState);
80
/// Mutable per-node state for ops that need one.
///
/// Instances are returned by `EvalOp::state`. NOTE(review): presumably the
/// session runner calls `OpState::eval` instead of the op's stateless `eval`
/// whenever a state exists — confirm with the runner.
pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
    /// Evaluates the op once, with mutable access to this state and the session.
    ///
    /// `op` is passed alongside the state (presumably the op that created it —
    /// confirm), and `inputs` are consumed.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>>;
}
// Implements `Clone` for `Box<dyn OpState>` using the `DynClone` supertrait.
dyn_clone::clone_trait_object!(OpState);
// Adds `is` / `downcast_ref` / `downcast_mut` / ... on `dyn OpState` via `Downcast`.
impl_downcast!(OpState);
91
/// Evaluation hooks shared by every op.
///
/// Only `is_stateless` is required. The defaults chain as follows: `eval` fails
/// with an error, `eval_with_session` ignores the session and forwards to
/// `eval`, and `state` returns `Ok(None)` (no state).
pub trait EvalOp {
    /// Evaluates the op from its inputs alone.
    ///
    /// The default bails with "stateless evaluation not implemented".
    // The default body ignores its argument, hence the allow.
    #[allow(unused_variables)]
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        bail!("stateless evaluation not implemented")
    }

    /// Evaluates the op with read-only access to the session.
    ///
    /// The default ignores `session` and delegates to `eval`, attaching the
    /// error context "Running legacy eval".
    #[allow(unused_variables)]
    fn eval_with_session(
        &self,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        self.eval(inputs).context("Running legacy eval")
    }

    /// Creates the mutable state for node `node_id`, if this op needs one.
    ///
    /// The default returns `Ok(None)`.
    #[allow(unused_variables)]
    fn state(
        &self,
        session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }

    /// Whether this op can run without an `OpState`.
    ///
    /// NOTE(review): presumably must agree with what `state` returns — confirm.
    fn is_stateless(&self) -> bool;
}
118
119pub trait Op: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp {
121 fn name(&self) -> Cow<str>;
122
123 fn validation(&self) -> Validation {
126 Validation::Accurate
127 }
128
129 fn same_as(&self, _other: &dyn Op) -> bool {
132 false
133 }
134
135 fn info(&self) -> TractResult<Vec<String>> {
138 Ok(vec![])
139 }
140
141 fn as_typed(&self) -> Option<&dyn TypedOp>;
142}
143
144pub trait TypedOp:
145 Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
146{
147 fn as_op(&self) -> &dyn Op;
149
150 fn as_op_mut(&mut self) -> &mut dyn Op;
152
153 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
155
156 #[allow(unused_variables)]
157 fn axes_mapping(
158 &self,
159 inputs: &[&TypedFact],
160 outputs: &[&TypedFact],
161 ) -> TractResult<AxesMapping> {
162 AxesMapping::disconnected(inputs, outputs)
163 }
164
165 fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
167 Ok(None)
168 }
169
170 #[allow(unused_variables)]
172 fn declutter_with_session(
173 &self,
174 session: &mut OptimizerSession,
175 model: &TypedModel,
176 node: &TypedNode,
177 ) -> TractResult<Option<TypedModelPatch>> {
178 self.declutter(model, node)
179 }
180
181 #[allow(unused_variables)]
183 fn declutter(
184 &self,
185 model: &TypedModel,
186 node: &TypedNode,
187 ) -> TractResult<Option<TypedModelPatch>> {
188 Ok(None)
189 }
190
191 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
195 Ok(tvec!())
196 }
197
198 #[allow(unused_variables)]
199 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
200 Ok(tvec!())
201 }
202
203 #[allow(unused_variables)]
204 fn change_axes(
205 &self,
206 model: &TypedModel,
207 node: &TypedNode,
208 io: InOut,
209 change: &AxisOp,
210 ) -> TractResult<Option<AxisChangeConsequence>> {
211 Ok(None)
212 }
213
214 #[allow(unused_variables)]
215 #[allow(clippy::too_many_arguments)]
216 fn slice(
217 &self,
218 patch: &mut TypedModelPatch,
219 model: &TypedModel,
220 node: &TypedNode,
221 prefix: &str,
222 inputs: &[OutletId],
223 output_axis: usize,
224 start: usize,
225 end: usize,
226 ) -> TractResult<Option<TVec<OutletId>>> {
227 Ok(None)
228 }
229
230 #[allow(unused_variables)]
234 fn quantize(
235 &self,
236 model: &TypedModel,
237 node: &TypedNode,
238 dt: DatumType,
239 scale: f32,
240 zero_point: i32,
241 ) -> TractResult<Option<Box<dyn TypedOp>>> {
242 Ok(None)
243 }
244
245 #[allow(unused_variables)]
247 fn concretize_dims(
248 &self,
249 source: &TypedModel,
250 node: &TypedNode,
251 target: &mut TypedModel,
252 mapping: &HashMap<OutletId, OutletId>,
253 values: &SymbolValues,
254 ) -> TractResult<TVec<OutletId>> {
255 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
256 target.wire_node(&node.name, node.op.clone(), &inputs)
257 }
258
259 #[allow(unused_variables)]
264 fn codegen(
265 &self,
266 model: &TypedModel,
267 node: &TypedNode,
268 ) -> TractResult<Option<TypedModelPatch>> {
269 Ok(None)
270 }
271
272 #[allow(unused_variables)]
274 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(Cow<str>, f64)> {
275 vec![]
276 }
277}
278
// Adds `is` / `downcast_ref` / `downcast_mut` / ... on `dyn Op` and `dyn TypedOp`.
impl_downcast!(Op);
impl_downcast!(TypedOp);

// Implements `Clone` for `Box<dyn Op>` and `Box<dyn TypedOp>` via `DynClone`.
dyn_clone::clone_trait_object!(Op);
dyn_clone::clone_trait_object!(TypedOp);
284
285impl<O: Op> From<O> for Box<dyn Op> {
286 fn from(it: O) -> Box<dyn Op> {
287 Box::new(it)
288 }
289}
290
291impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
292 fn from(it: O) -> Box<dyn TypedOp> {
293 Box::new(it)
294 }
295}
296
297impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
298 fn from(it: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
299 it.clone()
300 }
301}
302
303impl AsRef<dyn Op> for dyn TypedOp {
304 fn as_ref(&self) -> &dyn Op {
305 self.as_op()
306 }
307}
308
309impl AsRef<dyn Op> for Box<dyn TypedOp> {
310 fn as_ref(&self) -> &dyn Op {
311 self.as_op()
312 }
313}
314
315impl AsMut<dyn Op> for dyn TypedOp {
316 fn as_mut(&mut self) -> &mut dyn Op {
317 self.as_op_mut()
318 }
319}
320
321impl AsMut<dyn Op> for Box<dyn TypedOp> {
322 fn as_mut(&mut self) -> &mut dyn Op {
323 self.as_op_mut()
324 }
325}
326
327impl std::fmt::Display for Box<dyn Op> {
328 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
329 write!(fmt, "{}", self.name())
330 }
331}
332
333impl std::fmt::Display for Box<dyn TypedOp> {
334 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
335 write!(fmt, "{}", self.name())
336 }
337}