tract_tensorflow/
tfpb.rs

1use tract_hir::internal::*;
2
3use std::fs;
4
5#[allow(clippy::all)]
6mod google {
7    mod protobuf {
8        include!("prost/google.protobuf.rs");
9    }
10}
11
12#[allow(clippy::all)]
13pub mod tensorflow {
14    include!("prost/tensorflow.rs");
15}
16
17use self::tensorflow::attr_value::ListValue;
18use self::tensorflow::attr_value::Value;
19use self::tensorflow::{AttrValue, DataType, GraphDef, NodeDef, TensorProto, TensorShapeProto};
20
21use std::convert::TryInto;
22
23pub fn graph() -> GraphDef {
24    GraphDef::default()
25}
26
27pub fn node() -> NodeDef {
28    NodeDef {
29        name: String::new(),
30        op: String::new(),
31        input: vec![],
32        device: String::new(),
33        attr: HashMap::new(),
34    }
35}
36
37impl GraphDef {
38    pub fn node(mut self, n: NodeDef) -> Self {
39        self.node.push(n);
40        self
41    }
42    pub fn write_to_bytes(&self) -> TractResult<Vec<u8>> {
43        use prost::Message;
44        let mut buf = vec![];
45        self.encode(&mut buf)?;
46        Ok(buf)
47    }
48    pub fn save_to<P: AsRef<::std::path::Path>>(self, p: P) -> TractResult<()> {
49        let buf = self.write_to_bytes()?;
50        fs::write(p, buf)?;
51        Ok(())
52    }
53}
54
55impl NodeDef {
56    pub fn name<S: ToString>(mut self, n: S) -> NodeDef {
57        self.name = n.to_string();
58        self
59    }
60    pub fn op<S: ToString>(mut self, n: S) -> NodeDef {
61        self.op = n.to_string();
62        self
63    }
64    pub fn input<S: ToString>(mut self, n: S) -> NodeDef {
65        self.input.push(n.to_string());
66        self
67    }
68    pub fn attr<S: ToString, V: Into<AttrValue>>(mut self, n: S, v: V) -> NodeDef {
69        self.attr.insert(n.to_string(), v.into());
70        self
71    }
72}
73
74impl NodeDef {
75    pub fn get_attr_raw_str(&self, name: &str) -> TractResult<&[u8]> {
76        self.get_attr_opt_raw_str(name)?.with_context(|| {
77            format!("Node {} ({}) expected string attribute '{}'", self.name, self.op, name)
78        })
79    }
80
81    pub fn get_attr_opt_raw_str(&self, name: &str) -> TractResult<Option<&[u8]>> {
82        if let Some(a) = self.attr.get(name) {
83            if let Value::S(bytes) = a.value.as_ref().unwrap() {
84                return Ok(Some(bytes));
85            }
86        };
87        Ok(None)
88    }
89
90    pub fn get_attr_str(&self, name: &str) -> TractResult<String> {
91        self.get_attr_opt_str(name)?.with_context(|| {
92            format!("Node {} ({}) expected UTF-8 string attribute '{}'", self.name, self.op, name)
93        })
94    }
95
96    pub fn get_attr_opt_str(&self, name: &str) -> TractResult<Option<String>> {
97        if let Some(s) = self.get_attr_opt_raw_str(name)? {
98            Ok(Some(String::from_utf8(s.to_vec()).map_err(|_| {
99                format_err!(
100                    "Node {} ({}) expected an UTF-8 string for attribute '{}'",
101                    self.name,
102                    self.op,
103                    name
104                )
105            })?))
106        } else {
107            Ok(None)
108        }
109    }
110
111    pub fn get_attr_bool(&self, name: &str) -> TractResult<bool> {
112        self.get_attr_opt_bool(name)?.with_context(|| {
113            format!("Node {} ({}) expected bool attribute '{}'", self.name, self.op, name)
114        })
115    }
116
117    pub fn get_attr_opt_bool(&self, name: &str) -> TractResult<Option<bool>> {
118        if let Some(a) = self.attr.get(name) {
119            if let Value::B(v) = a.value.as_ref().unwrap() {
120                return Ok(Some(*v));
121            }
122        };
123        Ok(None)
124    }
125
126    pub fn get_attr_datum_type(&self, name: &str) -> TractResult<DatumType> {
127        self.get_attr_opt_datum_type(name)?.with_context(|| {
128            format!("Node {} ({}) expected datum_type attribute '{}'", self.name, self.op, name)
129        })
130    }
131
132    pub fn get_attr_opt_datum_type(&self, name: &str) -> TractResult<Option<DatumType>> {
133        if let Some(a) = self.attr.get(name) {
134            if let Value::Type(v) = a.value.as_ref().unwrap() {
135                return Ok(Some(DataType::from_i32(*v).unwrap().try_into()?));
136            }
137        };
138        Ok(None)
139    }
140
141    pub fn get_attr_shape(&self, name: &str) -> TractResult<TVec<isize>> {
142        self.get_attr_opt_shape(name)?.with_context(|| {
143            format!("Node {} ({}) expected shape attribute '{}'", self.name, self.op, name)
144        })
145    }
146
147    pub fn get_attr_opt_shape(&self, name: &str) -> TractResult<Option<TVec<isize>>> {
148        if let Some(a) = self.attr.get(name) {
149            if let Value::Shape(shape) = a.value.as_ref().unwrap() {
150                return Ok(Some(shape.try_into()?));
151            }
152        };
153        Ok(None)
154    }
155
156    pub fn get_attr_tensor(&self, name: &str) -> TractResult<Tensor> {
157        self.get_attr_opt_tensor(name)?.with_context(|| {
158            format!("Node {} ({}) expected tensor attribute '{}'", self.name, self.op, name)
159        })
160    }
161
162    pub fn get_attr_opt_tensor(&self, name: &str) -> TractResult<Option<Tensor>> {
163        if let Some(a) = self.attr.get(name) {
164            if let Value::Tensor(t) = a.value.as_ref().unwrap() {
165                return Ok(Some(t.try_into()?));
166            }
167        };
168        Ok(None)
169    }
170
171    pub fn get_attr_int<T: tract_num_traits::FromPrimitive>(&self, name: &str) -> TractResult<T> {
172        self.get_attr_opt_int(name)?.with_context(|| {
173            format!("Node {} ({}) expected int attribute '{}'", self.name, self.op, name)
174        })
175    }
176
177    pub fn get_attr_opt_int<T: tract_num_traits::FromPrimitive>(
178        &self,
179        name: &str,
180    ) -> TractResult<Option<T>> {
181        if let Some(a) = self.attr.get(name) {
182            if let Value::I(i) = a.value.as_ref().unwrap() {
183                return Ok(Some(T::from_i64(*i).unwrap()));
184            }
185        };
186        Ok(None)
187    }
188
189    pub fn get_attr_float<T: tract_num_traits::FromPrimitive>(&self, name: &str) -> TractResult<T> {
190        self.get_attr_opt_float(name)?.with_context(|| {
191            format!("Node {} ({}) expected int attribute '{}'", self.name, self.op, name)
192        })
193    }
194
195    pub fn get_attr_opt_float<T: tract_num_traits::FromPrimitive>(
196        &self,
197        name: &str,
198    ) -> TractResult<Option<T>> {
199        if let Some(a) = self.attr.get(name) {
200            if let Value::F(i) = a.value.as_ref().unwrap() {
201                return Ok(Some(T::from_f32(*i).unwrap()));
202            }
203        };
204        Ok(None)
205    }
206
207    pub fn get_attr_list_int<T: tract_num_traits::FromPrimitive>(
208        &self,
209        name: &str,
210    ) -> TractResult<Vec<T>> {
211        self.get_attr_opt_list_int(name)?.with_context(|| {
212            format!("Node {} ({}) expected list<int> attribute '{}'", self.name, self.op, name)
213        })
214    }
215
216    pub fn get_attr_opt_list_int<T: tract_num_traits::FromPrimitive>(
217        &self,
218        name: &str,
219    ) -> TractResult<Option<Vec<T>>> {
220        if let Some(a) = self.attr.get(name) {
221            if let Value::List(list) = a.value.as_ref().unwrap() {
222                return Ok(Some(list.i.iter().map(|&i| T::from_i64(i).unwrap()).collect()));
223            }
224        };
225        Ok(None)
226    }
227}
228
229impl From<DataType> for AttrValue {
230    fn from(t: DataType) -> AttrValue {
231        AttrValue { value: Some(Value::Type(t.into())) }
232    }
233}
234
235impl<'a> From<&'a str> for AttrValue {
236    fn from(t: &'a str) -> AttrValue {
237        AttrValue { value: Some(Value::S(t.as_bytes().to_vec())) }
238    }
239}
240
241impl From<i32> for AttrValue {
242    fn from(t: i32) -> AttrValue {
243        AttrValue::from(t as i64)
244    }
245}
246
247impl From<i64> for AttrValue {
248    fn from(t: i64) -> AttrValue {
249        AttrValue { value: Some(Value::I(t)) }
250    }
251}
252
253impl From<f32> for AttrValue {
254    fn from(t: f32) -> AttrValue {
255        AttrValue { value: Some(Value::F(t)) }
256    }
257}
258
259impl From<Vec<i64>> for AttrValue {
260    fn from(t: Vec<i64>) -> AttrValue {
261        AttrValue {
262            value: Some(Value::List(ListValue {
263                s: vec![],
264                i: t,
265                f: vec![],
266                b: vec![],
267                r#type: vec![],
268                shape: vec![],
269                tensor: vec![],
270                func: vec![],
271            })),
272        }
273    }
274}
275
276impl From<TensorProto> for AttrValue {
277    fn from(t: TensorProto) -> AttrValue {
278        AttrValue { value: Some(Value::Tensor(t)) }
279    }
280}
281
282impl From<TensorShapeProto> for AttrValue {
283    fn from(t: TensorShapeProto) -> AttrValue {
284        AttrValue { value: Some(Value::Shape(t)) }
285    }
286}
287
288impl From<bool> for AttrValue {
289    fn from(t: bool) -> AttrValue {
290        AttrValue { value: Some(Value::B(t)) }
291    }
292}