1use crate::pb::*;
2use attribute_proto::AttributeType;
3use tract_hir::internal::*;
4
5use tract_num_traits::{AsPrimitive, Bounded};
6
7use std::fmt::{self, Debug, Display};
8use std::str;
9
10use std::convert::TryInto;
11
/// Collects an iterator of `Result`s into any `Default + Extend` container,
/// stopping at (and returning) the first error encountered.
pub trait TryCollect<T, E>: Iterator<Item = Result<T, E>> + Sized {
    fn try_collect<B: Default + Extend<T>>(self) -> Result<B, E> {
        let mut collected = B::default();
        for result in self {
            // `?` short-circuits: items after the first `Err` are never pulled.
            let value = result?;
            collected.extend(std::iter::once(value));
        }
        Ok(collected)
    }
}

impl<T, E, I> TryCollect<T, E> for I where I: Iterator<Item = Result<T, E>> + Sized {}
23
24pub trait Reason {
25 fn reason(&self) -> StaticName;
26}
27
28impl Reason for &'static str {
29 fn reason(&self) -> StaticName {
30 (*self).into()
31 }
32}
33
34impl<F> Reason for F
35where
36 F: Fn() -> String,
37{
38 fn reason(&self) -> StaticName {
39 self().into()
40 }
41}
42
43pub trait OptionExt {
44 type Item;
45
46 fn and_try<F, T>(self, f: F) -> TractResult<Option<T>>
47 where
48 F: Fn(Self::Item) -> TractResult<T>;
49
50 fn and_ok<F, T>(self, f: F) -> TractResult<Option<T>>
51 where
52 F: Fn(Self::Item) -> T;
53}
54
55impl<A> OptionExt for Option<A> {
56 type Item = A;
57
58 fn and_try<F, T>(self, f: F) -> TractResult<Option<T>>
59 where
60 F: Fn(Self::Item) -> TractResult<T>,
61 {
62 match self {
63 Some(attr) => f(attr).map(Some),
64 None => Ok(None),
65 }
66 }
67
68 fn and_ok<F, T>(self, f: F) -> TractResult<Option<T>>
69 where
70 F: Fn(Self::Item) -> T,
71 {
72 Ok(self.map(f))
73 }
74}
75
76impl Display for attribute_proto::AttributeType {
77 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78 f.write_str(match self {
79 AttributeType::Int => "int",
80 AttributeType::Float => "float",
81 AttributeType::Tensor => "tensor",
82 AttributeType::String => "string",
83 AttributeType::Ints => "list of ints",
84 AttributeType::Floats => "list of floats",
85 AttributeType::Tensors => "list of tensors",
86 AttributeType::Strings => "list of strings",
87 AttributeType::Graph => "graph",
88 AttributeType::Graphs => "graphs",
89 _ => "<undefined>",
90 })
91 }
92}
93
94pub trait AttrScalarType<'a>: 'a + Sized {
95 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>>;
96}
97
98impl<'a> AttrScalarType<'a> for DatumType {
99 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
100 i32::get_attr_opt_scalar(node, name)?
101 .map(tensor_proto::DataType::from_i32)
102 .map(|d| d.unwrap().try_into())
103 .transpose()
104 }
105}
106
107impl<'a> AttrScalarType<'a> for &'a TensorProto {
108 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
109 Ok(node
110 .get_attr_opt_with_type(name, AttributeType::Tensor)?
111 .map(|attr| attr.t.as_ref().unwrap()))
112 }
113}
114
115impl<'a> AttrScalarType<'a> for &'a [u8] {
116 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
117 Ok(node.get_attr_opt_with_type(name, AttributeType::String)?.map(|attr| &*attr.s))
118 }
119}
120
121impl<'a> AttrScalarType<'a> for &'a str {
122 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
123 let bytes: Option<&[u8]> = AttrScalarType::get_attr_opt_scalar(node, name)?;
124 bytes.and_try(|b| str::from_utf8(b).map_err(Into::into))
125 }
126}
127
128impl<'a> AttrScalarType<'a> for String {
129 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
130 let string: Option<&'a str> = AttrScalarType::get_attr_opt_scalar(node, name)?;
131 string.and_ok(Into::into)
132 }
133}
134
135impl<'a> AttrScalarType<'a> for i64 {
136 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
137 node.get_attr_opt_with_type(name, AttributeType::Int)?.and_ok(|a| a.i)
138 }
139}
140
141impl<'a> AttrScalarType<'a> for bool {
142 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
143 let int: Option<i64> = AttrScalarType::get_attr_opt_scalar(node, name)?;
144 int.and_try(|int| {
145 node.expect_attr(name, int == 0 || int == 1, "boolean (0 or 1)")?;
146 Ok(int == 1)
147 })
148 }
149}
150
151impl<'a> AttrScalarType<'a> for usize {
152 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
153 let int: Option<i64> = AttrScalarType::get_attr_opt_scalar(node, name)?;
154 int.and_try(|int| {
155 node.expect_attr(name, int >= 0, "non-negative int")?;
156 Ok(int as _)
157 })
158 }
159}
160
161impl<'a> AttrScalarType<'a> for &'a GraphProto {
162 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
163 node.get_attr_opt_with_type(name, AttributeType::Graph)?.and_ok(|a| a.g.as_ref().unwrap())
164 }
165}
166
167fn check_int<T>(node: &NodeProto, attr: &str, int: i64, is_list: bool) -> TractResult<T>
168where
169 T: AsPrimitive<i64> + Bounded + Display,
170 i64: AsPrimitive<T>,
171{
172 let desc = if is_list { "list of ints" } else { "int" };
173 node.expect_attr(attr, int <= T::max_value().as_(), || {
174 format!("{} <= {}, got {}", desc, T::max_value(), int)
175 })?;
176 node.expect_attr(attr, int >= T::min_value().as_(), || {
177 format!("{} >= {}, got {}", desc, T::min_value(), int)
178 })?;
179 Ok(int.as_())
180}
181
182macro_rules! impl_attr_scalar_type_int {
183 ($ty:ident) => {
184 impl<'a> AttrScalarType<'a> for $ty {
185 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
186 AttrScalarType::get_attr_opt_scalar(node, name)?
187 .and_try(|int| check_int(node, name, int, false))
188 }
189 }
190
191 impl<'a> AttrTVecType<'a> for $ty {
192 fn get_attr_opt_tvec(
193 node: &'a NodeProto,
194 name: &str,
195 ) -> TractResult<Option<TVec<Self>>> {
196 AttrTVecType::get_attr_opt_tvec(node, name)?.and_try(|ints| {
197 ints.into_iter().map(|int| check_int(node, name, int, true)).try_collect()
198 })
199 }
200 }
201 };
202}
203
204impl_attr_scalar_type_int!(i8);
205impl_attr_scalar_type_int!(i16);
206impl_attr_scalar_type_int!(i32);
207impl_attr_scalar_type_int!(isize);
208
209impl<'a> AttrScalarType<'a> for f32 {
210 fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
211 node.get_attr_opt_with_type(name, AttributeType::Float)?.and_ok(|x| x.f)
212 }
213}
214
215pub trait AttrSliceType<'a>: 'a + Sized {
216 fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>>;
217}
218
219impl<'a> AttrSliceType<'a> for Vec<u8> {
220 fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
221 node.get_attr_opt_with_type(name, AttributeType::Strings)?.and_ok(|x| &*x.strings)
222 }
223}
224
225impl<'a> AttrSliceType<'a> for i64 {
226 fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
227 node.get_attr_opt_with_type(name, AttributeType::Ints)?.and_ok(|a| &*a.ints)
228 }
229}
230
231impl<'a> AttrSliceType<'a> for f32 {
232 fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
233 node.get_attr_opt_with_type(name, AttributeType::Floats)?.and_ok(|a| &*a.floats)
234 }
235}
236
237pub trait AttrTVecType<'a>: 'a + Sized {
238 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>>;
239}
240
241impl<'a, T> AttrTVecType<'a> for T
242where
243 T: AttrSliceType<'a> + Clone,
244{
245 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
246 T::get_attr_opt_slice(node, name)?.and_ok(Into::into)
247 }
248}
249
250impl<'a> AttrTVecType<'a> for &'a str {
251 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
252 <Vec<u8>>::get_attr_opt_slice(node, name)?
253 .and_try(|b| b.iter().map(|v| str::from_utf8(v)).try_collect().map_err(Into::into))
254 }
255}
256
257impl<'a> AttrTVecType<'a> for String {
258 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
259 <Vec<u8>>::get_attr_opt_slice(node, name)?.and_try(|b| {
260 b.iter().map(|v| str::from_utf8(v).map(Into::into)).try_collect().map_err(Into::into)
261 })
262 }
263}
264
265impl<'a> AttrTVecType<'a> for bool {
266 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
267 let ints: Option<&[i64]> = AttrSliceType::get_attr_opt_slice(node, name)?;
268 ints.and_try(|ints| {
269 for int in ints.iter() {
270 node.expect_attr(name, *int == 0 || *int == 1, "list of booleans (0 or 1)")?;
271 }
272 Ok(ints.iter().map(|&x| x == 1).collect())
273 })
274 }
275}
276
277impl<'a> AttrTVecType<'a> for usize {
278 fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
279 let ints: Option<&[i64]> = AttrSliceType::get_attr_opt_slice(node, name)?;
280 ints.and_try(|ints| {
281 for int in ints.iter() {
282 node.expect_attr(name, *int >= 0, "list of non-negative ints")?;
283 }
284 Ok(ints.iter().map(|&x| x as _).collect())
285 })
286 }
287}
288
289impl NodeProto {
290 pub fn bail<T>(&self, msg: &str) -> TractResult<T> {
291 bail!("Node {} ({}): {}", self.name, self.op_type, msg)
292 }
293
294 pub fn bail_attr<T>(&self, attr: &str, msg: &str) -> TractResult<T> {
295 bail!("Node {} ({}), attribute '{}': {}", self.name, self.op_type, attr, msg)
296 }
297
298 pub fn expect<R: Reason>(&self, cond: bool, what: R) -> TractResult<()> {
299 if !cond {
300 self.bail(&format!("expected {}", what.reason()))
301 } else {
302 Ok(())
303 }
304 }
305
306 pub fn expect_attr<R: Reason>(&self, attr: &str, cond: bool, what: R) -> TractResult<()> {
307 if !cond {
308 self.bail_attr(attr, &format!("expected {}", what.reason()))
309 } else {
310 Ok(())
311 }
312 }
313
314 pub fn expect_ok_or_else<T, R: Reason>(&self, result: Option<T>, what: R) -> TractResult<T> {
315 match result {
316 Some(v) => Ok(v),
317 None => Err(self.expect(false, what).unwrap_err()),
318 }
319 }
320
321 fn get_attr_opt_with_type(
322 &self,
323 name: &str,
324 ty: AttributeType,
325 ) -> TractResult<Option<&AttributeProto>> {
326 let attr = match self.attribute.iter().find(|a| a.name == name) {
327 Some(attr) => attr,
328 _ => return Ok(None),
329 };
330 self.expect_attr(name, AttributeType::from_i32(attr.r#type).unwrap() == ty, || {
331 format!("{}, got {}", ty, attr.r#type)
332 })?;
333 Ok(Some(attr))
334 }
335
336 pub fn get_attr_opt<'a, T>(&'a self, name: &str) -> TractResult<Option<T>>
337 where
338 T: AttrScalarType<'a>,
339 {
340 T::get_attr_opt_scalar(self, name)
341 }
342
343 pub fn get_attr<'a, T>(&'a self, name: &str) -> TractResult<T>
344 where
345 T: AttrScalarType<'a>,
346 {
347 self.expect_ok_or_else(self.get_attr_opt(name)?, || format!("attribute '{name}'"))
348 }
349
350 pub fn check_value<T, V: Debug>(&self, attr: &str, value: Result<T, V>) -> TractResult<T> {
351 match value {
352 Ok(value) => Ok(value),
353 Err(err) => self.bail_attr(attr, &format!("unexpected value: {err:?}")),
354 }
355 }
356
357 pub fn get_attr_opt_slice<'a, T>(&'a self, name: &str) -> TractResult<Option<&'a [T]>>
358 where
359 T: AttrSliceType<'a>,
360 {
361 T::get_attr_opt_slice(self, name)
362 }
363
364 pub fn get_attr_slice<'a, T>(&'a self, name: &str) -> TractResult<&'a [T]>
365 where
366 T: AttrSliceType<'a>,
367 {
368 self.expect_ok_or_else(self.get_attr_opt_slice(name)?, || format!("attribute '{name}'"))
369 }
370
371 pub fn get_attr_opt_tvec<'a, T>(&'a self, name: &str) -> TractResult<Option<TVec<T>>>
372 where
373 T: AttrTVecType<'a>,
374 {
375 T::get_attr_opt_tvec(self, name)
376 }
377
378 pub fn get_attr_tvec<'a, T>(&'a self, name: &str) -> TractResult<TVec<T>>
379 where
380 T: AttrTVecType<'a>,
381 {
382 self.expect_ok_or_else(self.get_attr_opt_tvec(name)?, || format!("attribute '{name}'"))
383 }
384
385 pub fn get_attr_opt_vec<'a, T>(&'a self, name: &str) -> TractResult<Option<Vec<T>>>
386 where
387 T: AttrTVecType<'a>,
388 {
389 Ok(self.get_attr_opt_tvec(name)?.map(TVec::into_vec))
390 }
391
392 pub fn get_attr_vec<'a, T>(&'a self, name: &str) -> TractResult<Vec<T>>
393 where
394 T: AttrTVecType<'a>,
395 {
396 self.get_attr_tvec(name).map(TVec::into_vec)
397 }
398}