1mod compress;
2mod nonzero;
3mod one_hot;
4mod pad;
5mod shape;
6mod slice;
7mod split;
8mod squeeze;
9mod topk;
10mod trilu;
11mod unsqueeze;
12
13use tract_hir::internal::*;
14use tract_hir::ops::array;
15
16use crate::model::{OnnxOpRegister, ParsingContext};
17use crate::pb::*;
18
19pub fn register_all_ops(reg: &mut OnnxOpRegister) {
20 reg.insert("ArrayFeatureExtractor", array_feature_extractor);
21 reg.insert("Compress", compress::compress);
22 reg.insert("Concat", concat);
23 reg.insert("ConstantLike", constant_like);
24 reg.insert("ConstantOfShape", constant_of_shape);
25 reg.insert("Expand", |_, _| Ok((expand(array::MultiBroadcastTo), vec![])));
26 reg.insert("EyeLike", eye_like);
27 reg.insert("Flatten", flatten);
28 reg.insert("Gather", gather);
29 reg.insert("GatherElements", gather_elements);
30 reg.insert("GatherND", gather_nd);
31 reg.insert("NonZero", nonzero::non_zero);
32 reg.insert("OneHot", one_hot::one_hot);
33 reg.insert("Range", |_, _| Ok((expand(array::Range), vec![])));
34 reg.insert("Pad", pad::pad);
35 reg.insert("Reshape", |_, _| Ok((expand(array::Reshape::default()), vec![])));
36 reg.insert("Scatter", scatter_elements);
37 reg.insert("ScatterElements", scatter_elements);
38 reg.insert("ScatterND", |_, _| Ok((Box::new(array::ScatterNd), vec![])));
39 reg.insert("Shape", shape::shape);
40 reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::TDim)), vec![])));
41 reg.insert("Slice", slice::slice);
42 reg.insert("Split", split::split);
43 reg.insert("Squeeze", squeeze::squeeze);
44 reg.insert("Tile", |_, _| Ok((expand(array::Tile), vec![])));
45 reg.insert("TopK", topk::topk);
46 reg.insert("Transpose", transpose);
47 reg.insert("Trilu", trilu::trilu);
48 reg.insert("Unsqueeze", unsqueeze::unsqueeze);
49}
50
51pub fn array_feature_extractor(
52 _ctx: &ParsingContext,
53 _node: &NodeProto,
54) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
55 Ok((expand(array::ArrayFeatureExtractor), vec![]))
56}
57
58pub fn concat(
59 _ctx: &ParsingContext,
60 node: &NodeProto,
61) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
62 let axis = node.get_attr("axis")?;
63 Ok((expand(array::Concat::new(axis)), vec![]))
64}
65
66pub fn constant_like(
67 _ctx: &ParsingContext,
68 node: &NodeProto,
69) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
70 let value = node.get_attr_opt("value")?.unwrap_or(0.);
71 if node.input.len() == 0 {
72 let dt = node.get_attr_opt("dtype")?.unwrap_or(DatumType::F32);
73 let shape: Vec<usize> = node.get_attr_vec("shape")?;
74 let tensor =
75 tensor0(value).cast_to_dt(dt)?.broadcast_scalar_to_shape(&shape)?.into_arc_tensor();
76 Ok((Box::new(tract_hir::ops::konst::Const::new(tensor)?), vec![]))
77 } else {
78 Ok((Box::new(array::ConstantLike::new(value)), vec![]))
79 }
80}
81
82pub fn constant_of_shape(
83 ctx: &ParsingContext,
84 node: &NodeProto,
85) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
86 let mut value = match node.get_attr_opt("value")? {
87 Some(val) => ctx.load_tensor(val)?.into_arc_tensor(),
88 None => rctensor0(0.0),
89 };
90 if value.rank() > 0 {
91 if value.len() != 1 {
92 bail!("Expected scalar (or vector of length 1), got {:?}", value);
93 }
94 value = value.nth(0)?.into_arc_tensor();
95 }
96 Ok((expand(array::ConstantOfShape::new(value)), vec![]))
97}
98
99pub fn eye_like(
100 _ctx: &ParsingContext,
101 node: &NodeProto,
102) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
103 let dt = node.get_attr_opt("dtype")?;
104 let k = node.get_attr_opt("k")?.unwrap_or(0);
105 Ok((Box::new(array::EyeLike::new(dt, k)), vec![]))
106}
107
108pub fn flatten(
109 _ctx: &ParsingContext,
110 node: &NodeProto,
111) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
112 let axis: i64 = node.get_attr_opt("axis")?.unwrap_or(1);
113 Ok((expand(array::Flatten::new(axis)), vec![]))
114}
115
116pub fn gather(
117 _ctx: &ParsingContext,
118 node: &NodeProto,
119) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
120 let axis = node.get_attr_opt("axis")?.unwrap_or(0);
121 Ok((expand(array::Gather::new(axis)), vec![]))
122}
123
124pub fn gather_elements(
125 _ctx: &ParsingContext,
126 node: &NodeProto,
127) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
128 let axis = node.get_attr_opt("axis")?.unwrap_or(0);
129 Ok((expand(array::GatherElements::new(axis)), vec![]))
130}
131
132pub fn gather_nd(
133 _ctx: &ParsingContext,
134 node: &NodeProto,
135) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
136 let batch_dims = node.get_attr_opt("batch_dims")?.unwrap_or(0);
137 Ok((Box::new(array::GatherNd::new(batch_dims)), vec![]))
138}
139
140pub fn scatter_elements(
141 _ctx: &ParsingContext,
142 node: &NodeProto,
143) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
144 let axis = node.get_attr_opt("axis")?.unwrap_or(0);
145 Ok((expand(array::ScatterElements::new(axis)), vec![]))
146}
147
148pub fn transpose(
149 _ctx: &ParsingContext,
150 node: &NodeProto,
151) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
152 let perm = node.get_attr_opt_vec("perm")?;
153 Ok((expand(array::PermuteAxes::new(perm.map(|t| t.into()))), vec![]))
154}