1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
use tract_hir::internal::*;
use tract_hir::ops;

use crate::model::ParsingContext;
use crate::model::TfOpRegister;
use crate::tfpb::tensorflow::NodeDef;

mod reduce;

pub fn register_all_ops(reg: &mut TfOpRegister) {
    reg.insert("Abs", |_, _| Ok(Box::new(ops::math::abs())));
    reg.insert("Add", |_, _| Ok(Box::new(ops::math::add::bin())));
    reg.insert("AddN", add_n);
    reg.insert("AddV2", |_, _| Ok(Box::new(ops::math::add::bin())));
    reg.insert("BiasAdd", |_, _| Ok(Box::new(ops::math::add::bin())));
    reg.insert("Ceil", |_, _| Ok(Box::new(ops::math::ceil())));
    reg.insert("Div", |_, _| Ok(Box::new(ops::math::div::bin())));
    reg.insert("FloorMod", |_, _| Ok(Box::new(ops::math::rem::bin())));
    reg.insert("MatMul", mat_mul);
    reg.insert("Max", reduce::max);
    reg.insert("Mean", reduce::mean);
    reg.insert("Min", reduce::min);
    reg.insert("Prod", reduce::prod);
    reg.insert("Sum", reduce::sum);
    reg.insert("Maximum", |_, _| Ok(Box::new(ops::math::max::bin())));
    reg.insert("Minimum", |_, _| Ok(Box::new(ops::math::min::bin())));
    reg.insert("Less", |_, _| Ok(Box::new(ops::logic::lesser::bin())));
    reg.insert("Log", |_, _| Ok(Box::new(ops::math::ln())));
    reg.insert("Mul", |_, _| Ok(Box::new(ops::math::mul::bin())));
    reg.insert("Pow", |_, _| Ok(Box::new(ops::math::pow::bin())));
    reg.insert("Neg", |_, _| Ok(Box::new(ops::math::neg())));
    reg.insert("RealDiv", |_, _| Ok(Box::new(ops::math::div::bin())));
    reg.insert("Rsqrt", |_, _| Ok(Box::new(ops::math::rsqrt())));
    reg.insert("Sub", |_, _| Ok(Box::new(ops::math::sub::bin())));
    reg.insert("Tanh", |_, _| Ok(Box::new(ops::math::tanh())));
}

pub fn add_n(_ctx: &ParsingContext, _pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    Ok(Box::new(ops::binary::Nary(Box::new(ops::math::Add), false)))
}

pub fn mat_mul(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
    let trans_a = pb.get_attr_bool("transpose_a")?;
    let trans_b = pb.get_attr_bool("transpose_b")?;
    Ok(Box::new(ops::matmul::MatMul::default().with_a_trans(trans_a).with_b_trans(trans_b)))
}