tract_tensorflow/ops/
math.rs

1use tract_hir::internal::*;
2use tract_hir::ops;
3
4use crate::model::ParsingContext;
5use crate::model::TfOpRegister;
6use crate::tfpb::tensorflow::NodeDef;
7
8mod reduce;
9
10pub fn register_all_ops(reg: &mut TfOpRegister) {
11    reg.insert("Abs", |_, _| Ok(ops::math::abs().into_hir()));
12    reg.insert("Add", |_, _| Ok(ops::math::Add.into_hir()));
13    reg.insert("AddN", add_n);
14    reg.insert("AddV2", |_, _| Ok(ops::math::Add.into_hir()));
15    reg.insert("BiasAdd", |_, _| Ok(ops::math::Add.into_hir()));
16    reg.insert("Ceil", |_, _| Ok(ops::math::ceil().into_hir()));
17    reg.insert("Div", |_, _| Ok(ops::math::Div.into_hir()));
18    reg.insert("FloorMod", |_, _| Ok(ops::math::Rem.into_hir()));
19    reg.insert("MatMul", mat_mul);
20    reg.insert("Max", reduce::max);
21    reg.insert("Mean", reduce::mean);
22    reg.insert("Min", reduce::min);
23    reg.insert("Prod", reduce::prod);
24    reg.insert("Sum", reduce::sum);
25    reg.insert("Maximum", |_, _| Ok(ops::math::Max.into_hir()));
26    reg.insert("Minimum", |_, _| Ok(ops::math::Min.into_hir()));
27    reg.insert("Log", |_, _| Ok(ops::math::ln().into_hir()));
28    reg.insert("Mul", |_, _| Ok(ops::math::Mul.into_hir()));
29    reg.insert("Pow", |_, _| Ok(ops::math::Pow.into_hir()));
30    reg.insert("Neg", |_, _| Ok(ops::math::neg().into_hir()));
31    reg.insert("RealDiv", |_, _| Ok(ops::math::Div.into_hir()));
32    reg.insert("Rsqrt", |_, _| Ok(ops::math::rsqrt().into_hir()));
33    reg.insert("Sub", |_, _| Ok(ops::math::Sub.into_hir()));
34    reg.insert("Tanh", |_, _| Ok(ops::math::tanh().into_hir()));
35}
36
37pub fn add_n(_ctx: &ParsingContext, _pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
38    Ok(Box::new(ops::binary::Nary(Box::new(ops::math::Add), false)))
39}
40
41pub fn mat_mul(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
42    let trans_a = pb.get_attr_bool("transpose_a")?;
43    let trans_b = pb.get_attr_bool("transpose_b")?;
44    Ok(expand(ops::matmul::MatMulInference::default().with_a_trans(trans_a).with_b_trans(trans_b)))
45}