Skip to main content

tract_onnx/ops/
math.rs

1use crate::model::OnnxOpRegister;
2use crate::model::ParsingContext;
3use crate::pb::*;
4use tract_hir::internal::*;
5use tract_hir::ops;
6use tract_hir::ops::binary::Nary;
7
8mod clip;
9mod gemm;
10mod mat_mul_integer;
11mod pow;
12mod rem;
13
14pub fn register_all_ops(reg: &mut OnnxOpRegister) {
15    reg.insert("Add", |_, _| Ok((ops::math::Add.into_hir(), vec![])));
16    reg.insert("Sub", |_, _| Ok((ops::math::Sub.into_hir(), vec![])));
17    reg.insert("Mul", |_, _| Ok((ops::math::Mul.into_hir(), vec![])));
18    reg.insert("Div", |_, _| Ok((ops::math::Div.into_hir(), vec![])));
19    reg.insert("Mod", rem::rem);
20
21    reg.insert("BitShift", bitshift);
22    reg.insert("BitwiseAnd", |_, _| Ok((ops::logic::BitAnd.into_hir(), vec![])));
23    reg.insert("BitwiseOr", |_, _| Ok((ops::logic::BitOr.into_hir(), vec![])));
24    reg.insert("BitwiseXor", |_, _| Ok((ops::logic::BitXor.into_hir(), vec![])));
25    reg.insert("BitwiseNot", |_, _| Ok((ops::logic::bitnot().into_hir(), vec![])));
26
27    reg.insert("Sum", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Add), false)), vec![])));
28    reg.insert("Max", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Max), false)), vec![])));
29    reg.insert("Min", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Min), false)), vec![])));
30    reg.insert("Mean", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Add), true)), vec![])));
31
32    reg.insert("Abs", |_, _| Ok((ops::math::abs().into_hir(), vec![])));
33    reg.insert("Ceil", |_, _| Ok((ops::math::ceil().into_hir(), vec![])));
34    reg.insert("Floor", |_, _| Ok((ops::math::floor().into_hir(), vec![])));
35    reg.insert("Round", |_, _| Ok((ops::math::round_half_to_even().into_hir(), vec![])));
36    reg.insert("Clip", clip::clip);
37
38    reg.insert("Cos", |_, _| Ok((ops::math::cos().into_hir(), vec![])));
39    reg.insert("Sin", |_, _| Ok((ops::math::sin().into_hir(), vec![])));
40    reg.insert("Tan", |_, _| Ok((ops::math::tan().into_hir(), vec![])));
41    reg.insert("Acos", |_, _| Ok((ops::math::acos().into_hir(), vec![])));
42    reg.insert("Asin", |_, _| Ok((ops::math::asin().into_hir(), vec![])));
43    reg.insert("Atan", |_, _| Ok((ops::math::atan().into_hir(), vec![])));
44
45    reg.insert("Cosh", |_, _| Ok((ops::math::cosh().into_hir(), vec![])));
46    reg.insert("Sinh", |_, _| Ok((ops::math::sinh().into_hir(), vec![])));
47    reg.insert("Tanh", |_, _| Ok((ops::math::tanh().into_hir(), vec![])));
48    reg.insert("Acosh", |_, _| Ok((ops::math::acosh().into_hir(), vec![])));
49    reg.insert("Asinh", |_, _| Ok((ops::math::asinh().into_hir(), vec![])));
50    reg.insert("Atanh", |_, _| Ok((ops::math::atanh().into_hir(), vec![])));
51
52    reg.insert("Erf", |_, _| Ok((ops::math::erf().into_hir(), vec![])));
53    reg.insert("Exp", |_, _| Ok((ops::math::exp().into_hir(), vec![])));
54    reg.insert("Log", |_, _| Ok((ops::math::ln().into_hir(), vec![])));
55    reg.insert("Sqrt", |_, _| Ok((ops::math::sqrt().into_hir(), vec![])));
56    reg.insert("Rsqrt", |_, _| Ok((ops::math::rsqrt().into_hir(), vec![])));
57
58    reg.insert("IsNaN", |_, _| Ok((tract_onnx_opl::is_nan::is_nan().into_hir(), vec![])));
59    reg.insert("IsInf", isinf);
60    reg.insert("Neg", |_, _| Ok((ops::math::neg().into_hir(), vec![])));
61    reg.insert("Sign", |_, _| Ok((ops::math::sign().into_hir(), vec![])));
62    reg.insert("Reciprocal", |_, _| Ok((ops::math::recip().into_hir(), vec![])));
63
64    reg.insert("Pow", pow::pow);
65
66    reg.insert("MatMul", |_, _| Ok((expand(ops::matmul::MatMulInference::default()), vec![])));
67    reg.insert("MatMulInteger", mat_mul_integer::mat_mul_integer);
68    reg.insert("QLinearMatMul", mat_mul_integer::q_linear_mat_mul);
69    reg.insert("Gemm", gemm::gemm);
70}
71
72fn isinf(
73    _ctx: &ParsingContext,
74    node: &NodeProto,
75) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
76    let detect_positive = node.get_attr_opt("detect_positive")?.unwrap_or(1) != 0;
77    let detect_negative = node.get_attr_opt("detect_negative")?.unwrap_or(1) != 0;
78    Ok((tract_onnx_opl::is_inf::is_inf(detect_positive, detect_negative).into_hir(), vec![]))
79}
80
81fn bitshift(
82    _ctx: &ParsingContext,
83    node: &NodeProto,
84) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
85    let op: Box<dyn InferenceOp> = if node.get_attr_opt("direction")?.unwrap_or("LEFT") == "RIGHT" {
86        ops::math::ShiftRight.into_hir()
87    } else {
88        ops::math::ShiftLeft.into_hir()
89    };
90    Ok((op, vec![]))
91}