tract_onnx/
pb_helpers.rs

1use crate::pb::*;
2use attribute_proto::AttributeType;
3use tract_hir::internal::*;
4
5use tract_num_traits::{AsPrimitive, Bounded};
6
7use std::fmt::{self, Debug, Display};
8use std::str;
9
10use std::convert::TryInto;
11
/// Fallible counterpart of `collect` for iterators that yield `Result`s.
///
/// The `Ok` payloads are appended, in order, to a default-constructed
/// collection. The first `Err` is returned as soon as it is produced; items
/// after it are never pulled from the iterator.
pub trait TryCollect<T, E>: Iterator<Item = Result<T, E>> + Sized {
    /// Gathers every `Ok` value into a `B`, or returns the first error.
    fn try_collect<B: Default + Extend<T>>(mut self) -> Result<B, E> {
        self.try_fold(B::default(), |mut collected, item| {
            item.map(|value| {
                collected.extend(Some(value));
                collected
            })
        })
    }
}

// Blanket implementation: any iterator over `Result<T, E>` gets `try_collect`.
impl<T, E, I> TryCollect<T, E> for I where I: Iterator<Item = Result<T, E>> + Sized {}
23
24pub trait Reason {
25    fn reason(&self) -> StaticName;
26}
27
28impl Reason for &'static str {
29    fn reason(&self) -> StaticName {
30        (*self).into()
31    }
32}
33
34impl<F> Reason for F
35where
36    F: Fn() -> String,
37{
38    fn reason(&self) -> StaticName {
39        self().into()
40    }
41}
42
43pub trait OptionExt {
44    type Item;
45
46    fn and_try<F, T>(self, f: F) -> TractResult<Option<T>>
47    where
48        F: Fn(Self::Item) -> TractResult<T>;
49
50    fn and_ok<F, T>(self, f: F) -> TractResult<Option<T>>
51    where
52        F: Fn(Self::Item) -> T;
53}
54
55impl<A> OptionExt for Option<A> {
56    type Item = A;
57
58    fn and_try<F, T>(self, f: F) -> TractResult<Option<T>>
59    where
60        F: Fn(Self::Item) -> TractResult<T>,
61    {
62        match self {
63            Some(attr) => f(attr).map(Some),
64            None => Ok(None),
65        }
66    }
67
68    fn and_ok<F, T>(self, f: F) -> TractResult<Option<T>>
69    where
70        F: Fn(Self::Item) -> T,
71    {
72        Ok(self.map(f))
73    }
74}
75
76impl Display for attribute_proto::AttributeType {
77    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78        f.write_str(match self {
79            AttributeType::Int => "int",
80            AttributeType::Float => "float",
81            AttributeType::Tensor => "tensor",
82            AttributeType::String => "string",
83            AttributeType::Ints => "list of ints",
84            AttributeType::Floats => "list of floats",
85            AttributeType::Tensors => "list of tensors",
86            AttributeType::Strings => "list of strings",
87            AttributeType::Graph => "graph",
88            AttributeType::Graphs => "graphs",
89            _ => "<undefined>",
90        })
91    }
92}
93
94pub trait AttrScalarType<'a>: 'a + Sized {
95    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>>;
96}
97
98impl<'a> AttrScalarType<'a> for DatumType {
99    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
100        i32::get_attr_opt_scalar(node, name)?
101            .map(tensor_proto::DataType::from_i32)
102            .map(|d| d.unwrap().try_into())
103            .transpose()
104    }
105}
106
107impl<'a> AttrScalarType<'a> for &'a TensorProto {
108    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
109        Ok(node
110            .get_attr_opt_with_type(name, AttributeType::Tensor)?
111            .map(|attr| attr.t.as_ref().unwrap()))
112    }
113}
114
115impl<'a> AttrScalarType<'a> for &'a [u8] {
116    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
117        Ok(node.get_attr_opt_with_type(name, AttributeType::String)?.map(|attr| &*attr.s))
118    }
119}
120
121impl<'a> AttrScalarType<'a> for &'a str {
122    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
123        let bytes: Option<&[u8]> = AttrScalarType::get_attr_opt_scalar(node, name)?;
124        bytes.and_try(|b| str::from_utf8(b).map_err(Into::into))
125    }
126}
127
128impl<'a> AttrScalarType<'a> for String {
129    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
130        let string: Option<&'a str> = AttrScalarType::get_attr_opt_scalar(node, name)?;
131        string.and_ok(Into::into)
132    }
133}
134
135impl<'a> AttrScalarType<'a> for i64 {
136    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
137        node.get_attr_opt_with_type(name, AttributeType::Int)?.and_ok(|a| a.i)
138    }
139}
140
141impl<'a> AttrScalarType<'a> for bool {
142    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
143        let int: Option<i64> = AttrScalarType::get_attr_opt_scalar(node, name)?;
144        int.and_try(|int| {
145            node.expect_attr(name, int == 0 || int == 1, "boolean (0 or 1)")?;
146            Ok(int == 1)
147        })
148    }
149}
150
151impl<'a> AttrScalarType<'a> for usize {
152    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
153        let int: Option<i64> = AttrScalarType::get_attr_opt_scalar(node, name)?;
154        int.and_try(|int| {
155            node.expect_attr(name, int >= 0, "non-negative int")?;
156            Ok(int as _)
157        })
158    }
159}
160
161impl<'a> AttrScalarType<'a> for &'a GraphProto {
162    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
163        node.get_attr_opt_with_type(name, AttributeType::Graph)?.and_ok(|a| a.g.as_ref().unwrap())
164    }
165}
166
167fn check_int<T>(node: &NodeProto, attr: &str, int: i64, is_list: bool) -> TractResult<T>
168where
169    T: AsPrimitive<i64> + Bounded + Display,
170    i64: AsPrimitive<T>,
171{
172    let desc = if is_list { "list of ints" } else { "int" };
173    node.expect_attr(attr, int <= T::max_value().as_(), || {
174        format!("{} <= {}, got {}", desc, T::max_value(), int)
175    })?;
176    node.expect_attr(attr, int >= T::min_value().as_(), || {
177        format!("{} >= {}, got {}", desc, T::min_value(), int)
178    })?;
179    Ok(int.as_())
180}
181
182macro_rules! impl_attr_scalar_type_int {
183    ($ty:ident) => {
184        impl<'a> AttrScalarType<'a> for $ty {
185            fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
186                AttrScalarType::get_attr_opt_scalar(node, name)?
187                    .and_try(|int| check_int(node, name, int, false))
188            }
189        }
190
191        impl<'a> AttrTVecType<'a> for $ty {
192            fn get_attr_opt_tvec(
193                node: &'a NodeProto,
194                name: &str,
195            ) -> TractResult<Option<TVec<Self>>> {
196                AttrTVecType::get_attr_opt_tvec(node, name)?.and_try(|ints| {
197                    ints.into_iter().map(|int| check_int(node, name, int, true)).try_collect()
198                })
199            }
200        }
201    };
202}
203
204impl_attr_scalar_type_int!(i8);
205impl_attr_scalar_type_int!(i16);
206impl_attr_scalar_type_int!(i32);
207impl_attr_scalar_type_int!(isize);
208
209impl<'a> AttrScalarType<'a> for f32 {
210    fn get_attr_opt_scalar(node: &'a NodeProto, name: &str) -> TractResult<Option<Self>> {
211        node.get_attr_opt_with_type(name, AttributeType::Float)?.and_ok(|x| x.f)
212    }
213}
214
215pub trait AttrSliceType<'a>: 'a + Sized {
216    fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>>;
217}
218
219impl<'a> AttrSliceType<'a> for Vec<u8> {
220    fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
221        node.get_attr_opt_with_type(name, AttributeType::Strings)?.and_ok(|x| &*x.strings)
222    }
223}
224
225impl<'a> AttrSliceType<'a> for i64 {
226    fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
227        node.get_attr_opt_with_type(name, AttributeType::Ints)?.and_ok(|a| &*a.ints)
228    }
229}
230
231impl<'a> AttrSliceType<'a> for f32 {
232    fn get_attr_opt_slice(node: &'a NodeProto, name: &str) -> TractResult<Option<&'a [Self]>> {
233        node.get_attr_opt_with_type(name, AttributeType::Floats)?.and_ok(|a| &*a.floats)
234    }
235}
236
237pub trait AttrTVecType<'a>: 'a + Sized {
238    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>>;
239}
240
241impl<'a, T> AttrTVecType<'a> for T
242where
243    T: AttrSliceType<'a> + Clone,
244{
245    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
246        T::get_attr_opt_slice(node, name)?.and_ok(Into::into)
247    }
248}
249
250impl<'a> AttrTVecType<'a> for &'a str {
251    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
252        <Vec<u8>>::get_attr_opt_slice(node, name)?
253            .and_try(|b| b.iter().map(|v| str::from_utf8(v)).try_collect().map_err(Into::into))
254    }
255}
256
257impl<'a> AttrTVecType<'a> for String {
258    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
259        <Vec<u8>>::get_attr_opt_slice(node, name)?.and_try(|b| {
260            b.iter().map(|v| str::from_utf8(v).map(Into::into)).try_collect().map_err(Into::into)
261        })
262    }
263}
264
265impl<'a> AttrTVecType<'a> for bool {
266    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
267        let ints: Option<&[i64]> = AttrSliceType::get_attr_opt_slice(node, name)?;
268        ints.and_try(|ints| {
269            for int in ints.iter() {
270                node.expect_attr(name, *int == 0 || *int == 1, "list of booleans (0 or 1)")?;
271            }
272            Ok(ints.iter().map(|&x| x == 1).collect())
273        })
274    }
275}
276
277impl<'a> AttrTVecType<'a> for usize {
278    fn get_attr_opt_tvec(node: &'a NodeProto, name: &str) -> TractResult<Option<TVec<Self>>> {
279        let ints: Option<&[i64]> = AttrSliceType::get_attr_opt_slice(node, name)?;
280        ints.and_try(|ints| {
281            for int in ints.iter() {
282                node.expect_attr(name, *int >= 0, "list of non-negative ints")?;
283            }
284            Ok(ints.iter().map(|&x| x as _).collect())
285        })
286    }
287}
288
289impl NodeProto {
290    pub fn bail<T>(&self, msg: &str) -> TractResult<T> {
291        bail!("Node {} ({}): {}", self.name, self.op_type, msg)
292    }
293
294    pub fn bail_attr<T>(&self, attr: &str, msg: &str) -> TractResult<T> {
295        bail!("Node {} ({}), attribute '{}': {}", self.name, self.op_type, attr, msg)
296    }
297
298    pub fn expect<R: Reason>(&self, cond: bool, what: R) -> TractResult<()> {
299        if !cond {
300            self.bail(&format!("expected {}", what.reason()))
301        } else {
302            Ok(())
303        }
304    }
305
306    pub fn expect_attr<R: Reason>(&self, attr: &str, cond: bool, what: R) -> TractResult<()> {
307        if !cond {
308            self.bail_attr(attr, &format!("expected {}", what.reason()))
309        } else {
310            Ok(())
311        }
312    }
313
314    pub fn expect_ok_or_else<T, R: Reason>(&self, result: Option<T>, what: R) -> TractResult<T> {
315        match result {
316            Some(v) => Ok(v),
317            None => Err(self.expect(false, what).unwrap_err()),
318        }
319    }
320
321    fn get_attr_opt_with_type(
322        &self,
323        name: &str,
324        ty: AttributeType,
325    ) -> TractResult<Option<&AttributeProto>> {
326        let attr = match self.attribute.iter().find(|a| a.name == name) {
327            Some(attr) => attr,
328            _ => return Ok(None),
329        };
330        self.expect_attr(name, AttributeType::from_i32(attr.r#type).unwrap() == ty, || {
331            format!("{}, got {}", ty, attr.r#type)
332        })?;
333        Ok(Some(attr))
334    }
335
336    pub fn get_attr_opt<'a, T>(&'a self, name: &str) -> TractResult<Option<T>>
337    where
338        T: AttrScalarType<'a>,
339    {
340        T::get_attr_opt_scalar(self, name)
341    }
342
343    pub fn get_attr<'a, T>(&'a self, name: &str) -> TractResult<T>
344    where
345        T: AttrScalarType<'a>,
346    {
347        self.expect_ok_or_else(self.get_attr_opt(name)?, || format!("attribute '{name}'"))
348    }
349
350    pub fn check_value<T, V: Debug>(&self, attr: &str, value: Result<T, V>) -> TractResult<T> {
351        match value {
352            Ok(value) => Ok(value),
353            Err(err) => self.bail_attr(attr, &format!("unexpected value: {err:?}")),
354        }
355    }
356
357    pub fn get_attr_opt_slice<'a, T>(&'a self, name: &str) -> TractResult<Option<&'a [T]>>
358    where
359        T: AttrSliceType<'a>,
360    {
361        T::get_attr_opt_slice(self, name)
362    }
363
364    pub fn get_attr_slice<'a, T>(&'a self, name: &str) -> TractResult<&'a [T]>
365    where
366        T: AttrSliceType<'a>,
367    {
368        self.expect_ok_or_else(self.get_attr_opt_slice(name)?, || format!("attribute '{name}'"))
369    }
370
371    pub fn get_attr_opt_tvec<'a, T>(&'a self, name: &str) -> TractResult<Option<TVec<T>>>
372    where
373        T: AttrTVecType<'a>,
374    {
375        T::get_attr_opt_tvec(self, name)
376    }
377
378    pub fn get_attr_tvec<'a, T>(&'a self, name: &str) -> TractResult<TVec<T>>
379    where
380        T: AttrTVecType<'a>,
381    {
382        self.expect_ok_or_else(self.get_attr_opt_tvec(name)?, || format!("attribute '{name}'"))
383    }
384
385    pub fn get_attr_opt_vec<'a, T>(&'a self, name: &str) -> TractResult<Option<Vec<T>>>
386    where
387        T: AttrTVecType<'a>,
388    {
389        Ok(self.get_attr_opt_tvec(name)?.map(TVec::into_vec))
390    }
391
392    pub fn get_attr_vec<'a, T>(&'a self, name: &str) -> TractResult<Vec<T>>
393    where
394        T: AttrTVecType<'a>,
395    {
396        self.get_attr_tvec(name).map(TVec::into_vec)
397    }
398}