1use crate::model::{OnnxOpRegister, ParsingContext};
2use crate::pb::*;
3use tract_hir::internal::*;
4
5mod array;
6mod cast;
7pub mod cumsum;
8mod d2s;
9mod einsum;
10mod fft;
11pub mod logic;
12mod math;
13mod ml;
14pub mod multinomial;
15mod nn;
16mod non_max_suppression;
17mod quant;
18mod random;
19pub mod rec;
20mod resize;
21mod s2d;
22
23pub fn register_all_ops(reg: &mut OnnxOpRegister) {
24 reg.insert("Constant", konst);
25 reg.insert("Einsum", einsum::einsum);
26 reg.insert("Identity", |_, _| {
27 Ok((Box::<tract_hir::ops::identity::Identity>::default(), vec![]))
28 });
29 reg.insert("Resize", resize::resize);
30 reg.insert("NonMaxSuppression", non_max_suppression::non_max_suppression);
31 reg.insert("Multinomial", multinomial::multinomial);
32 array::register_all_ops(reg);
33 cast::register_all_ops(reg);
34 cumsum::register_all_ops(reg);
35 d2s::register_all_ops(reg);
36 fft::register_all_ops(reg);
37 logic::register_all_ops(reg);
38 math::register_all_ops(reg);
39 ml::register_all_ops(reg);
40 nn::register_all_ops(reg);
41 quant::register_all_ops(reg);
42 random::register_all_ops(reg);
43 rec::register_all_ops(reg);
44 s2d::register_all_ops(reg);
45}
46
47fn konst(
48 ctx: &ParsingContext,
49 node: &NodeProto,
50) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
51 let value = if let Some(v) = node.get_attr_opt("value")? {
52 ctx.load_tensor(v)?
53 } else if let Some(i) = node.get_attr_opt::<i64>("value_int")? {
54 tensor0(i)
55 } else if let Some(v) = node.get_attr_opt::<f32>("value_float")? {
56 tensor0(v)
57 } else {
58 bail!("Could not extract value out of Constant node")
59 };
60 Ok((Box::new(tract_hir::ops::konst::Const::new(value.into_arc_tensor())?), vec![]))
61}