1use tract_hir::internal::*;
2
3use std::fs;
4
// Protobuf support types included from `prost/google.protobuf.rs` (presumably prost-generated,
// given the `prost/` path). Clippy lints are silenced because that code is not hand-written.
// NOTE(review): kept private; presumably only referenced from the included tensorflow
// bindings — confirm before changing visibility.
#[allow(clippy::all)]
mod google {
    mod protobuf {
        include!("prost/google.protobuf.rs");
    }
}
11
/// TensorFlow protobuf message types (`GraphDef`, `NodeDef`, `AttrValue`, ...), included from
/// `prost/tensorflow.rs`.
// Clippy lints are silenced because the included file is presumably prost-generated, not
// hand-written.
#[allow(clippy::all)]
pub mod tensorflow {
    include!("prost/tensorflow.rs");
}
16
17use self::tensorflow::attr_value::ListValue;
18use self::tensorflow::attr_value::Value;
19use self::tensorflow::{AttrValue, DataType, GraphDef, NodeDef, TensorProto, TensorShapeProto};
20
21use std::convert::TryInto;
22
23pub fn graph() -> GraphDef {
24 GraphDef::default()
25}
26
27pub fn node() -> NodeDef {
28 NodeDef {
29 name: String::new(),
30 op: String::new(),
31 input: vec![],
32 device: String::new(),
33 attr: HashMap::new(),
34 }
35}
36
37impl GraphDef {
38 pub fn node(mut self, n: NodeDef) -> Self {
39 self.node.push(n);
40 self
41 }
42 pub fn write_to_bytes(&self) -> TractResult<Vec<u8>> {
43 use prost::Message;
44 let mut buf = vec![];
45 self.encode(&mut buf)?;
46 Ok(buf)
47 }
48 pub fn save_to<P: AsRef<::std::path::Path>>(self, p: P) -> TractResult<()> {
49 let buf = self.write_to_bytes()?;
50 fs::write(p, buf)?;
51 Ok(())
52 }
53}
54
55impl NodeDef {
56 pub fn name<S: ToString>(mut self, n: S) -> NodeDef {
57 self.name = n.to_string();
58 self
59 }
60 pub fn op<S: ToString>(mut self, n: S) -> NodeDef {
61 self.op = n.to_string();
62 self
63 }
64 pub fn input<S: ToString>(mut self, n: S) -> NodeDef {
65 self.input.push(n.to_string());
66 self
67 }
68 pub fn attr<S: ToString, V: Into<AttrValue>>(mut self, n: S, v: V) -> NodeDef {
69 self.attr.insert(n.to_string(), v.into());
70 self
71 }
72}
73
74impl NodeDef {
75 pub fn get_attr_raw_str(&self, name: &str) -> TractResult<&[u8]> {
76 self.get_attr_opt_raw_str(name)?.with_context(|| {
77 format!("Node {} ({}) expected string attribute '{}'", self.name, self.op, name)
78 })
79 }
80
81 pub fn get_attr_opt_raw_str(&self, name: &str) -> TractResult<Option<&[u8]>> {
82 if let Some(a) = self.attr.get(name) {
83 if let Value::S(bytes) = a.value.as_ref().unwrap() {
84 return Ok(Some(bytes));
85 }
86 };
87 Ok(None)
88 }
89
90 pub fn get_attr_str(&self, name: &str) -> TractResult<String> {
91 self.get_attr_opt_str(name)?.with_context(|| {
92 format!("Node {} ({}) expected UTF-8 string attribute '{}'", self.name, self.op, name)
93 })
94 }
95
96 pub fn get_attr_opt_str(&self, name: &str) -> TractResult<Option<String>> {
97 if let Some(s) = self.get_attr_opt_raw_str(name)? {
98 Ok(Some(String::from_utf8(s.to_vec()).map_err(|_| {
99 format_err!(
100 "Node {} ({}) expected an UTF-8 string for attribute '{}'",
101 self.name,
102 self.op,
103 name
104 )
105 })?))
106 } else {
107 Ok(None)
108 }
109 }
110
111 pub fn get_attr_bool(&self, name: &str) -> TractResult<bool> {
112 self.get_attr_opt_bool(name)?.with_context(|| {
113 format!("Node {} ({}) expected bool attribute '{}'", self.name, self.op, name)
114 })
115 }
116
117 pub fn get_attr_opt_bool(&self, name: &str) -> TractResult<Option<bool>> {
118 if let Some(a) = self.attr.get(name) {
119 if let Value::B(v) = a.value.as_ref().unwrap() {
120 return Ok(Some(*v));
121 }
122 };
123 Ok(None)
124 }
125
126 pub fn get_attr_datum_type(&self, name: &str) -> TractResult<DatumType> {
127 self.get_attr_opt_datum_type(name)?.with_context(|| {
128 format!("Node {} ({}) expected datum_type attribute '{}'", self.name, self.op, name)
129 })
130 }
131
132 pub fn get_attr_opt_datum_type(&self, name: &str) -> TractResult<Option<DatumType>> {
133 if let Some(a) = self.attr.get(name) {
134 if let Value::Type(v) = a.value.as_ref().unwrap() {
135 return Ok(Some(DataType::from_i32(*v).unwrap().try_into()?));
136 }
137 };
138 Ok(None)
139 }
140
141 pub fn get_attr_shape(&self, name: &str) -> TractResult<TVec<isize>> {
142 self.get_attr_opt_shape(name)?.with_context(|| {
143 format!("Node {} ({}) expected shape attribute '{}'", self.name, self.op, name)
144 })
145 }
146
147 pub fn get_attr_opt_shape(&self, name: &str) -> TractResult<Option<TVec<isize>>> {
148 if let Some(a) = self.attr.get(name) {
149 if let Value::Shape(shape) = a.value.as_ref().unwrap() {
150 return Ok(Some(shape.try_into()?));
151 }
152 };
153 Ok(None)
154 }
155
156 pub fn get_attr_tensor(&self, name: &str) -> TractResult<Tensor> {
157 self.get_attr_opt_tensor(name)?.with_context(|| {
158 format!("Node {} ({}) expected tensor attribute '{}'", self.name, self.op, name)
159 })
160 }
161
162 pub fn get_attr_opt_tensor(&self, name: &str) -> TractResult<Option<Tensor>> {
163 if let Some(a) = self.attr.get(name) {
164 if let Value::Tensor(t) = a.value.as_ref().unwrap() {
165 return Ok(Some(t.try_into()?));
166 }
167 };
168 Ok(None)
169 }
170
171 pub fn get_attr_int<T: tract_num_traits::FromPrimitive>(&self, name: &str) -> TractResult<T> {
172 self.get_attr_opt_int(name)?.with_context(|| {
173 format!("Node {} ({}) expected int attribute '{}'", self.name, self.op, name)
174 })
175 }
176
177 pub fn get_attr_opt_int<T: tract_num_traits::FromPrimitive>(
178 &self,
179 name: &str,
180 ) -> TractResult<Option<T>> {
181 if let Some(a) = self.attr.get(name) {
182 if let Value::I(i) = a.value.as_ref().unwrap() {
183 return Ok(Some(T::from_i64(*i).unwrap()));
184 }
185 };
186 Ok(None)
187 }
188
189 pub fn get_attr_float<T: tract_num_traits::FromPrimitive>(&self, name: &str) -> TractResult<T> {
190 self.get_attr_opt_float(name)?.with_context(|| {
191 format!("Node {} ({}) expected int attribute '{}'", self.name, self.op, name)
192 })
193 }
194
195 pub fn get_attr_opt_float<T: tract_num_traits::FromPrimitive>(
196 &self,
197 name: &str,
198 ) -> TractResult<Option<T>> {
199 if let Some(a) = self.attr.get(name) {
200 if let Value::F(i) = a.value.as_ref().unwrap() {
201 return Ok(Some(T::from_f32(*i).unwrap()));
202 }
203 };
204 Ok(None)
205 }
206
207 pub fn get_attr_list_int<T: tract_num_traits::FromPrimitive>(
208 &self,
209 name: &str,
210 ) -> TractResult<Vec<T>> {
211 self.get_attr_opt_list_int(name)?.with_context(|| {
212 format!("Node {} ({}) expected list<int> attribute '{}'", self.name, self.op, name)
213 })
214 }
215
216 pub fn get_attr_opt_list_int<T: tract_num_traits::FromPrimitive>(
217 &self,
218 name: &str,
219 ) -> TractResult<Option<Vec<T>>> {
220 if let Some(a) = self.attr.get(name) {
221 if let Value::List(list) = a.value.as_ref().unwrap() {
222 return Ok(Some(list.i.iter().map(|&i| T::from_i64(i).unwrap()).collect()));
223 }
224 };
225 Ok(None)
226 }
227}
228
229impl From<DataType> for AttrValue {
230 fn from(t: DataType) -> AttrValue {
231 AttrValue { value: Some(Value::Type(t.into())) }
232 }
233}
234
235impl<'a> From<&'a str> for AttrValue {
236 fn from(t: &'a str) -> AttrValue {
237 AttrValue { value: Some(Value::S(t.as_bytes().to_vec())) }
238 }
239}
240
241impl From<i32> for AttrValue {
242 fn from(t: i32) -> AttrValue {
243 AttrValue::from(t as i64)
244 }
245}
246
247impl From<i64> for AttrValue {
248 fn from(t: i64) -> AttrValue {
249 AttrValue { value: Some(Value::I(t)) }
250 }
251}
252
253impl From<f32> for AttrValue {
254 fn from(t: f32) -> AttrValue {
255 AttrValue { value: Some(Value::F(t)) }
256 }
257}
258
259impl From<Vec<i64>> for AttrValue {
260 fn from(t: Vec<i64>) -> AttrValue {
261 AttrValue {
262 value: Some(Value::List(ListValue {
263 s: vec![],
264 i: t,
265 f: vec![],
266 b: vec![],
267 r#type: vec![],
268 shape: vec![],
269 tensor: vec![],
270 func: vec![],
271 })),
272 }
273 }
274}
275
276impl From<TensorProto> for AttrValue {
277 fn from(t: TensorProto) -> AttrValue {
278 AttrValue { value: Some(Value::Tensor(t)) }
279 }
280}
281
282impl From<TensorShapeProto> for AttrValue {
283 fn from(t: TensorShapeProto) -> AttrValue {
284 AttrValue { value: Some(Value::Shape(t)) }
285 }
286}
287
288impl From<bool> for AttrValue {
289 fn from(t: bool) -> AttrValue {
290 AttrValue { value: Some(Value::B(t)) }
291 }
292}