tract_tensorflow/ops/nn/
mod.rs1use tract_hir::internal::*;
2use tract_hir::ops::cnn::PaddingSpec;
3use tract_hir::ops::nn::{DataFormat, LayerSoftmax};
4
5use crate::model::TfOpRegister;
6use crate::tfpb::tensorflow::NodeDef;
7
8pub mod conv2d;
9pub mod dw_conv2d;
10pub mod fused_batch_norm;
11pub mod pools;
12pub mod s2b;
13
14pub fn register_all_ops(reg: &mut TfOpRegister) {
15 reg.insert("AvgPool", pools::avgpool);
16 reg.insert("Conv2D", conv2d::conv2d);
17 reg.insert("DepthwiseConv2dNative", dw_conv2d::depthwise_conv2d);
18 reg.insert("FusedBatchNorm", fused_batch_norm::fused_batch_norm);
19 reg.insert("MaxPool", pools::maxpool);
20 reg.insert("Relu", |_, _| Ok(expand(tract_hir::ops::activations::Clip::new(Some(0.0), None))));
21 reg.insert("Relu6", |_, _| {
22 Ok(expand(tract_hir::ops::activations::Clip::new(Some(0.0), Some(6.0))))
23 });
24 reg.insert("Sigmoid", |_, _| Ok(tract_hir::ops::nn::sigmoid().into_hir()));
25 reg.insert("Softmax", |_, _| Ok(expand(LayerSoftmax::new(1, true))));
26 reg.insert("SpaceToBatchND", s2b::space_to_batch_nd);
27 reg.insert("BatchToSpaceND", s2b::batch_to_space_nd);
28}
29
30pub fn strides(pb: &NodeDef) -> TractResult<Vec<usize>> {
31 let strides: Vec<usize> = pb.get_attr_list_int("strides")?;
32 if strides.len() != 4 || strides[0] != 1 && strides[3] != 1 {
33 bail!("strides must be of the form [1, h, v, 1], found {:?}", strides)
34 };
35 Ok(strides)
36}
37
38pub fn data_format(pb: &NodeDef) -> TractResult<DataFormat> {
39 let df = if pb.get_attr_opt_raw_str("data_format")?.unwrap_or(b"NHWC") == b"NHWC" {
40 DataFormat::NHWC
41 } else {
42 DataFormat::NCHW
43 };
44 Ok(df)
45}
46
47pub fn padding(pb: &NodeDef) -> TractResult<PaddingSpec> {
48 let padding = pb.get_attr_raw_str("padding")?;
49 match padding {
50 b"VALID" => Ok(PaddingSpec::Valid),
51 b"SAME" => Ok(PaddingSpec::SameUpper),
52 s => bail!("unsupported Padding {}", String::from_utf8_lossy(s)),
53 }
54}