tract_onnx/ops/mod.rs

1use crate::model::{OnnxOpRegister, ParsingContext};
2use crate::pb::*;
3use tract_hir::internal::*;
4
5mod array;
6mod cast;
7pub mod cumsum;
8mod d2s;
9mod einsum;
10mod fft;
11pub mod logic;
12mod math;
13mod ml;
14pub mod multinomial;
15mod nn;
16mod non_max_suppression;
17mod quant;
18mod random;
19pub mod rec;
20mod resize;
21mod s2d;
22
23pub fn register_all_ops(reg: &mut OnnxOpRegister) {
24    reg.insert("Constant", konst);
25    reg.insert("Einsum", einsum::einsum);
26    reg.insert("Identity", |_, _| {
27        Ok((Box::<tract_hir::ops::identity::Identity>::default(), vec![]))
28    });
29    reg.insert("Resize", resize::resize);
30    reg.insert("NonMaxSuppression", non_max_suppression::non_max_suppression);
31    reg.insert("Multinomial", multinomial::multinomial);
32    array::register_all_ops(reg);
33    cast::register_all_ops(reg);
34    cumsum::register_all_ops(reg);
35    d2s::register_all_ops(reg);
36    fft::register_all_ops(reg);
37    logic::register_all_ops(reg);
38    math::register_all_ops(reg);
39    ml::register_all_ops(reg);
40    nn::register_all_ops(reg);
41    quant::register_all_ops(reg);
42    random::register_all_ops(reg);
43    rec::register_all_ops(reg);
44    s2d::register_all_ops(reg);
45}
46
47fn konst(
48    ctx: &ParsingContext,
49    node: &NodeProto,
50) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
51    let value = if let Some(v) = node.get_attr_opt("value")? {
52        ctx.load_tensor(v)?
53    } else if let Some(i) = node.get_attr_opt::<i64>("value_int")? {
54        tensor0(i)
55    } else if let Some(v) = node.get_attr_opt::<f32>("value_float")? {
56        tensor0(v)
57    } else {
58        bail!("Could not extract value out of Constant node")
59    };
60    Ok((Box::new(tract_hir::ops::konst::Const::new(value.into_arc_tensor())?), vec![]))
61}