1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
use tract_core::ops as tractops;

use crate::ops::OpRegister;
use crate::tfpb::node_def::NodeDef;
use tract_core::TractResult;

/// Registers the TensorFlow math and comparison op names handled by this
/// module into `reg`.
///
/// Most entries map a TensorFlow op name onto a tract-core op through the
/// `with_T!` macro. `AddN` and `MatMul` read extra node attributes, so they
/// are registered with the dedicated builder functions defined in this file.
pub fn register_all_ops(reg: &mut OpRegister) {
    // Element-wise unary math.
    reg.insert("Abs", with_T!(tractops::math::Abs));
    reg.insert("Ceil", with_T!(tractops::math::Ceil));
    reg.insert("Log", with_T!(tractops::math::Ln));
    reg.insert("Neg", with_T!(tractops::math::Neg));
    reg.insert("Rsqrt", with_T!(tractops::math::Rsqrt));
    reg.insert("Tanh", with_T!(tractops::math::Tanh));

    // Binary arithmetic. `BiasAdd` reuses plain addition, and `Div` and
    // `RealDiv` share a single implementation.
    reg.insert("Add", with_T!(tractops::math::Add::Bin));
    reg.insert("BiasAdd", with_T!(tractops::math::Add::Bin));
    reg.insert("Sub", with_T!(tractops::math::Sub::Bin));
    reg.insert("Mul", with_T!(tractops::math::Mul::Bin));
    reg.insert("Div", with_T!(tractops::math::Div::Bin));
    reg.insert("RealDiv", with_T!(tractops::math::Div::Bin));
    // NOTE(review): TensorFlow's FloorMod rounds the quotient toward negative
    // infinity; confirm `Rem` gives the same result for negative operands.
    reg.insert("FloorMod", with_T!(tractops::math::Rem::Bin));
    reg.insert("Pow", with_T!(tractops::math::Pow::Bin));
    reg.insert("Maximum", with_T!(tractops::math::Max::Bin));
    reg.insert("Minimum", with_T!(tractops::math::Min::Bin));

    // Comparison.
    reg.insert("Less", with_T!(tractops::logic::Lesser::Bin));

    // Ops whose construction depends on node attributes.
    reg.insert("AddN", add_n);
    reg.insert("MatMul", mat_mul);
}

/// Builds a tract `AddN` op from a TensorFlow `AddN` node.
///
/// The datum type comes from the `T` attribute and the integer `N` attribute
/// is handed to `AddN::new` as `Some(..)`. Errors from either attribute
/// lookup are propagated with `?`, and `T` is read first.
pub fn add_n(pb: &NodeDef) -> TractResult<Box<tractops::Op>> {
    let datum_type = pb.get_attr_datum_type("T")?;
    let input_count = pb.get_attr_int("N")?;
    let op = tractops::math::AddN::new(datum_type.into(), Some(input_count));
    Ok(Box::new(op))
}

/// Builds a tract `Gemm` op from a TensorFlow `MatMul` node.
///
/// The boolean `transpose_a` and `transpose_b` attributes are forwarded to
/// `Gemm::new`. Errors from reading either attribute are propagated with `?`.
pub fn mat_mul(pb: &NodeDef) -> TractResult<Box<tractops::Op>> {
    let trans_a = pb.get_attr_bool("transpose_a")?;
    let trans_b = pb.get_attr_bool("transpose_b")?;
    // NOTE(review): the constants presumably mean alpha = 1.0, beta = 0.0 and
    // a third transpose flag turned off (a plain product of the two inputs);
    // confirm against `Gemm::new`'s parameter order.
    Ok(Box::new(tractops::math::Gemm::new(1.0, 0.0, trans_a, trans_b, false)))
}