tract_tensorflow/ops/array/
mod.rs1use crate::model::TfOpRegister;
2use tract_hir::internal::*;
3
4use crate::model::ParsingContext;
5use crate::tfpb::tensorflow::NodeDef;
6use tract_core::ops::array::StridedSlice;
7
8mod concatv2;
9mod expand_dims;
10mod fill;
11mod gather_nd;
12mod gather_v2;
13mod pack;
14mod pad;
15mod squeeze;
16mod transpose;
17
18pub fn register_all_ops(reg: &mut TfOpRegister) {
19 reg.insert("ConcatV2", concatv2::build);
20 reg.insert("ExpandDims", expand_dims::build);
21 reg.insert("Fill", fill::fill);
22 reg.insert("GatherNd", gather_nd::gather_nd);
23 reg.insert("GatherV2", gather_v2::gather_v2);
24 reg.insert("Pack", pack::pack);
25 reg.insert("Pad", pad::pad);
26 reg.insert("Range", |_, _| Ok(expand(tract_hir::ops::array::Range)));
27 reg.insert("Reshape", |_, _| Ok(expand(tract_hir::ops::array::Reshape::new())));
28 reg.insert("Shape", |_, _| Ok(expand(tract_hir::ops::array::Shape::new(DatumType::TDim))));
29 reg.insert("Slice", slice);
30 reg.insert("Squeeze", squeeze::squeeze);
31 reg.insert("StridedSlice", strided_slice);
32 reg.insert("Tile", |_, _| Ok(expand(::tract_hir::ops::array::Tile)));
33 reg.insert("Transpose", transpose::transpose);
34}
35
36fn strided_slice(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
37 let begin_mask = pb.get_attr_opt_int("begin_mask")?.unwrap_or(0);
38 let end_mask = pb.get_attr_opt_int("end_mask")?.unwrap_or(0);
39 let shrink_axis_mask = pb.get_attr_opt_int("shrink_axis_mask")?.unwrap_or(0);
40 Ok(Box::new(StridedSlice {
41 begin_mask,
42 end_mask,
43 shrink_axis_mask,
44 optional_axes_input: None,
45 optional_steps_input: Some(3),
46 }))
47}
48
49fn slice(_ctx: &ParsingContext, _pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
50 Ok(Box::new(StridedSlice {
51 optional_axes_input: None,
52 optional_steps_input: None,
53 begin_mask: 0,
54 end_mask: 0,
55 shrink_axis_mask: 0,
56 }))
57}